Reproducing the OpenSTL PredRNNv2 Model and Training It on a Custom Dataset
Overview
This article walks through reproducing the PredRNNv2 model from OpenSTL and training and predicting with a custom dataset stored as NPY files. Starting from environment setup, we cover data preprocessing, model construction, the training loop, and inference, with the end goal of feeding in several consecutive 500×500 frames of a time series and getting the same number of predicted frames back out.
Contents
- Environment Setup and Dependencies
- Dataset Preparation and Preprocessing
- PredRNNv2: Principles and Architecture
- Data Loader Implementation
- Model Training Pipeline
- Prediction and Visualization
- Model Evaluation and Optimization
- Full Code
- Common Problems and Solutions
- Summary and Outlook
1. Environment Setup and Dependencies
First, create a suitable Python environment and install the required packages.
# Create a conda environment
conda create -n openstl python=3.8
conda activate openstl

# Install PyTorch (pick the build matching your CUDA version)
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

# Install the remaining dependencies
pip install numpy==1.21.6
pip install opencv-python==4.7.0.72
pip install matplotlib==3.5.3
pip install tensorboard==2.11.2
pip install scikit-learn==1.0.2
pip install tqdm==4.64.1
pip install nni==2.8
pip install timm==0.6.12
pip install einops==0.6.0
Next, clone the OpenSTL repository and install it in editable mode:
git clone https://github.com/chengtan9907/OpenSTL.git
cd OpenSTL
git checkout OpenSTL-Lightning
pip install -e .
2. Dataset Preparation and Preprocessing
Our dataset consists of NPY files, one 500×500 image per file, with consecutive files being consecutive in time. The expected directory layout:
dataset/
├── train/
│ ├── sequence_001/
│ │ ├── frame_001.npy
│ │ ├── frame_002.npy
│ │ └── ...
│ ├── sequence_002/
│ └── ...
├── valid/
└── test/
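If you want to smoke-test the pipeline before wiring up real data, a small script can generate a matching directory tree of random frames. This is purely illustrative scaffolding (the helper name `make_dummy_dataset` and the use of small 64×64 frames to keep the demo fast are my own choices, not part of the project):

```python
import os
import tempfile

import numpy as np


def make_dummy_dataset(root, splits=('train', 'valid', 'test'),
                       num_sequences=2, frames_per_seq=25, size=(500, 500)):
    """Create a directory tree of random NPY frames matching the layout above."""
    for split in splits:
        for s in range(1, num_sequences + 1):
            seq_dir = os.path.join(root, split, f'sequence_{s:03d}')
            os.makedirs(seq_dir, exist_ok=True)
            for f in range(1, frames_per_seq + 1):
                frame = np.random.rand(*size).astype(np.float32)
                np.save(os.path.join(seq_dir, f'frame_{f:03d}.npy'), frame)


root = tempfile.mkdtemp()
make_dummy_dataset(root, size=(64, 64))  # small frames keep the demo fast
sample = np.load(os.path.join(root, 'train', 'sequence_001', 'frame_001.npy'))
print(sample.shape)  # (64, 64)
```

Pointing `--data_root` at such a tree lets you verify shapes and the training loop end to end before committing to hours of real training.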
2.1 Data Preprocessing Class
We create a dataset class that turns the NPY files into model-ready tensors:
import os

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader


class NPYDataset(Dataset):
    def __init__(self, data_root, mode='train', input_frames=10, output_frames=10,
                 future_frames=10, transform=None, preprocess=True):
        """NPY sequence dataset.

        Args:
            data_root: dataset root directory
            mode: one of 'train', 'valid', 'test'
            input_frames: number of input frames
            output_frames: number of output frames
            future_frames: number of future frames to predict
            transform: optional data-augmentation transform
            preprocess: whether to standardize the data
        """
        self.data_root = os.path.join(data_root, mode)
        self.mode = mode
        self.input_frames = input_frames
        self.output_frames = output_frames
        self.future_frames = future_frames
        self.transform = transform
        self.preprocess = preprocess

        # Collect every sequence that has enough frames
        self.sequences = []
        for seq_name in sorted(os.listdir(self.data_root)):
            seq_path = os.path.join(self.data_root, seq_name)
            if os.path.isdir(seq_path):
                frames = sorted([f for f in os.listdir(seq_path) if f.endswith('.npy')])
                if len(frames) >= input_frames + future_frames:
                    self.sequences.append((seq_path, frames))

        # Data standardizer
        self.scaler = None
        if preprocess:
            self._init_scaler()

    def _init_scaler(self):
        """Fit the standardizer on a subset of the data."""
        print(f"Initializing scaler for {self.mode} mode...")
        all_data = []
        for seq_path, frames in self.sequences:
            # Use at most the first 100 frames per sequence for the statistics
            for frame_name in frames[:min(100, len(frames))]:
                data = np.load(os.path.join(seq_path, frame_name))
                all_data.append(data.flatten())
        all_data = np.concatenate(all_data).reshape(-1, 1)
        self.scaler = StandardScaler()
        self.scaler.fit(all_data)
        print("Scaler initialized.")

    def _preprocess_data(self, data):
        """Standardize a frame."""
        if self.preprocess and self.scaler is not None:
            original_shape = data.shape
            data = self.scaler.transform(data.flatten().reshape(-1, 1))
            data = data.reshape(original_shape)
        return data

    def _postprocess_data(self, data):
        """Undo the standardization."""
        if self.preprocess and self.scaler is not None:
            original_shape = data.shape
            data = self.scaler.inverse_transform(data.flatten().reshape(-1, 1))
            data = data.reshape(original_shape)
        return data

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq_path, frames = self.sequences[idx]

        # Pick a random starting frame when training, otherwise start at 0
        total_frames = len(frames)
        max_start = total_frames - self.input_frames - self.future_frames
        start_idx = np.random.randint(0, max_start + 1) if self.mode == 'train' else 0

        # Load the input frames
        input_frames = []
        for i in range(start_idx, start_idx + self.input_frames):
            frame_data = np.load(os.path.join(seq_path, frames[i]))
            input_frames.append(self._preprocess_data(frame_data))

        # Load the target frames
        target_frames = []
        for i in range(start_idx + self.input_frames,
                       start_idx + self.input_frames + self.future_frames):
            frame_data = np.load(os.path.join(seq_path, frames[i]))
            target_frames.append(self._preprocess_data(frame_data))

        # Stack into [T, H, W] and add a channel dimension -> [T, 1, H, W]
        input_seq = np.expand_dims(np.stack(input_frames, axis=0), axis=1)
        target_seq = np.expand_dims(np.stack(target_frames, axis=0), axis=1)

        # Convert to tensors
        input_seq = torch.FloatTensor(input_seq)
        target_seq = torch.FloatTensor(target_seq)

        if self.transform:
            # Apply the same random augmentation to inputs and targets,
            # otherwise they would receive different rotations/flips
            seq = self.transform(torch.cat([input_seq, target_seq], dim=0))
            input_seq, target_seq = seq[:self.input_frames], seq[self.input_frames:]

        return input_seq, target_seq


# Data-augmentation transforms
class RandomRotate:
    def __init__(self, angles=(0, 90, 180, 270)):
        self.angles = angles

    def __call__(self, x):
        angle = np.random.choice(self.angles)
        if angle == 0:
            return x
        # Rotate the spatial dimensions of the whole [T, C, H, W] tensor;
        # torch.rot90 handles all four angles, unlike a fixed cv2.rotate call
        return torch.rot90(x, angle // 90, dims=(-2, -1))


class RandomFlip:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, x):
        if np.random.random() < self.p:
            # Horizontal flip
            return x.flip(-1)
        return x
3. PredRNNv2: Principles and Architecture
PredRNNv2 is an improved recurrent network designed for video prediction. It introduces a spatiotemporal memory (STM) that flows through the layer stack to better capture spatiotemporal dynamics.
3.1 Core Components
import torch
import torch.nn as nn
class SpatioTemporalLSTMCell(nn.Module):
    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
        super(SpatioTemporalLSTMCell, self).__init__()
        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        self._forget_bias = 1.0

        # Convolutional gates
        self.conv_x = nn.Sequential(
            nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden * 7, height, width]))
        self.conv_h = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden * 4, height, width]))
        self.conv_m = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden * 3, height, width]))
        self.conv_o = nn.Sequential(
            nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden, height, width]))
        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
                                   stride=1, padding=0, bias=False)

    def forward(self, x_t, h_t, c_t, m_t):
        # Gate activations
        x_concat = self.conv_x(x_t)
        h_concat = self.conv_h(h_t)
        m_concat = self.conv_m(m_t)
        i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1)
        i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
        i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)

        # Standard temporal cell state
        i_t = torch.sigmoid(i_x + i_h)
        f_t = torch.sigmoid(f_x + f_h + self._forget_bias)
        g_t = torch.tanh(g_x + g_h)
        c_new = f_t * c_t + i_t * g_t

        # Spatiotemporal memory state
        i_t_prime = torch.sigmoid(i_x_prime + i_m)
        f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)
        g_t_prime = torch.tanh(g_x_prime + g_m)
        m_new = f_t_prime * m_t + i_t_prime * g_t_prime

        # Output gate fuses both memories
        mem = torch.cat((c_new, m_new), 1)
        o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))
        h_new = o_t * torch.tanh(self.conv_last(mem))

        return h_new, c_new, m_new


class PredRNNv2(nn.Module):
    def __init__(self, configs):
        super(PredRNNv2, self).__init__()
        self.configs = configs
        self.frame_channel = configs.patch_size * configs.patch_size * configs.img_channel
        self.num_layers = len(configs.num_hidden)
        self.num_hidden = configs.num_hidden
        self.device = configs.device

        # Build the cell stack
        cell_list = []
        height = configs.img_height // configs.patch_size
        width = configs.img_width // configs.patch_size
        for i in range(self.num_layers):
            in_channel = self.frame_channel if i == 0 else self.num_hidden[i - 1]
            cell_list.append(SpatioTemporalLSTMCell(
                in_channel, self.num_hidden[i], height, width,
                configs.filter_size, configs.stride, configs.layer_norm))
        self.cell_list = nn.ModuleList(cell_list)

        # Output projection
        self.conv_last = nn.Conv2d(self.num_hidden[self.num_layers - 1], self.frame_channel,
                                   kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, frames_tensor, mask_true=None):
        # frames_tensor: [batch, length, channel, height, width]
        batch = frames_tensor.shape[0]
        height = frames_tensor.shape[3]
        width = frames_tensor.shape[4]

        # Initialize hidden and cell states for every layer
        next_frames = []
        h_t = []
        c_t = []
        for i in range(self.num_layers):
            zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.device)
            h_t.append(zeros)
            c_t.append(zeros)

        # The spatiotemporal memory is shared across layers and zigzags
        # through the stack at every time step
        memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.device)

        # total_length = input frames + predicted frames
        seq_length = self.configs.total_length
        x_gen = None
        for t in range(seq_length - 1):
            if t < self.configs.input_length:
                # Warm-up phase: always feed the ground truth
                net = frames_tensor[:, t]
            elif mask_true is not None:
                # Scheduled sampling: per-step mask mixes ground truth
                # with the previous prediction
                mask = mask_true[:, t - self.configs.input_length]
                net = mask * frames_tensor[:, t] + (1 - mask) * x_gen
            else:
                # Inference: feed the previous prediction back in
                net = x_gen

            # First layer
            h_t[0], c_t[0], memory = self.cell_list[0](net, h_t[0], c_t[0], memory)
            # Remaining layers; the memory flows upward through the stack
            for i in range(1, self.num_layers):
                h_t[i], c_t[i], memory = self.cell_list[i](h_t[i - 1], h_t[i], c_t[i], memory)

            # Generate the next frame
            x_gen = self.conv_last(h_t[self.num_layers - 1])
            next_frames.append(x_gen)

        # [length, batch, channel, height, width] -> [batch, length, channel, height, width]
        next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 2, 3, 4)
        return next_frames
4. Data Loader Implementation
Next, wrap the datasets in DataLoaders:
from torchvision.transforms import Compose


def create_data_loaders(configs):
    """Create the train, validation, and test data loaders."""
    # Data augmentation (training set only). nn.Sequential cannot hold these
    # plain callables, so we chain them with torchvision's Compose instead.
    if configs.data_augmentation:
        train_transform = Compose([RandomRotate(), RandomFlip()])
    else:
        train_transform = None

    # Datasets
    train_dataset = NPYDataset(
        data_root=configs.data_root,
        mode='train',
        input_frames=configs.input_length,
        output_frames=configs.total_length - configs.input_length,
        future_frames=configs.total_length - configs.input_length,
        transform=train_transform,
        preprocess=configs.preprocess_data)
    valid_dataset = NPYDataset(
        data_root=configs.data_root,
        mode='valid',
        input_frames=configs.input_length,
        output_frames=configs.total_length - configs.input_length,
        future_frames=configs.total_length - configs.input_length,
        transform=None,
        preprocess=configs.preprocess_data)
    test_dataset = NPYDataset(
        data_root=configs.data_root,
        mode='test',
        input_frames=configs.input_length,
        output_frames=configs.total_length - configs.input_length,
        future_frames=configs.total_length - configs.input_length,
        transform=None,
        preprocess=configs.preprocess_data)

    # Loaders
    train_loader = DataLoader(train_dataset, batch_size=configs.batch_size, shuffle=True,
                              num_workers=configs.num_workers, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=configs.batch_size, shuffle=False,
                              num_workers=configs.num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=configs.batch_size, shuffle=False,
                             num_workers=configs.num_workers, pin_memory=True)

    return train_loader, valid_loader, test_loader
5. Model Training Pipeline
Now we implement the full training loop, including the loss function, optimizer, and learning-rate scheduler:
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm


class Trainer:
    def __init__(self, configs, model, train_loader, valid_loader, test_loader):
        self.configs = configs
        self.model = model
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.device = configs.device

        # Loss
        self.criterion = nn.MSELoss()
        # Optimizer
        self.optimizer = optim.Adam(model.parameters(), lr=configs.lr,
                                    weight_decay=configs.weight_decay)
        # LR scheduler
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5,
                                           patience=5, verbose=True)

        # Training history
        self.train_losses = []
        self.valid_losses = []
        self.best_loss = float('inf')

        # Checkpoint directory
        os.makedirs(configs.save_dir, exist_ok=True)

    def _step(self, inputs, targets):
        """Run the model on a full sequence and return the loss on the predicted frames.

        The model consumes the whole sequence (inputs + targets) and emits a
        prediction for every step after the first; only the predictions that
        correspond to the future frames enter the loss.
        """
        frames = torch.cat([inputs, targets], dim=1)  # [B, total_length, C, H, W]
        outputs = self.model(frames, mask_true=None)
        preds = outputs[:, self.configs.input_length - 1:]
        return self.criterion(preds, targets), preds

    def train_epoch(self, epoch):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0

        progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Forward and backward pass
            self.optimizer.zero_grad()
            loss, _ = self._step(inputs, targets)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})

        avg_loss = total_loss / len(self.train_loader)
        self.train_losses.append(avg_loss)
        return avg_loss

    def validate(self):
        """Evaluate on the validation set."""
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for inputs, targets in self.valid_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                loss, _ = self._step(inputs, targets)
                total_loss += loss.item()
        avg_loss = total_loss / len(self.valid_loader)
        self.valid_losses.append(avg_loss)
        return avg_loss

    def test(self):
        """Evaluate on the test set and collect the predictions."""
        self.model.eval()
        total_loss = 0
        all_outputs = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in self.test_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                loss, preds = self._step(inputs, targets)
                total_loss += loss.item()
                # Keep the results for later analysis
                all_outputs.append(preds.cpu().numpy())
                all_targets.append(targets.cpu().numpy())
        avg_loss = total_loss / len(self.test_loader)
        return avg_loss, np.concatenate(all_outputs, axis=0), np.concatenate(all_targets, axis=0)

    def save_checkpoint(self, epoch, is_best=False):
        """Save a checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_losses': self.train_losses,
            'valid_losses': self.valid_losses,
            'best_loss': self.best_loss
        }
        # Latest checkpoint
        torch.save(checkpoint, os.path.join(self.configs.save_dir, 'latest_checkpoint.pth'))
        # Best checkpoint
        if is_best:
            torch.save(checkpoint, os.path.join(self.configs.save_dir, 'best_checkpoint.pth'))

    def load_checkpoint(self, checkpoint_path):
        """Load a checkpoint."""
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.valid_losses = checkpoint['valid_losses']
        self.best_loss = checkpoint['best_loss']
        return checkpoint['epoch']

    def train(self, num_epochs):
        """Full training loop."""
        start_epoch = 0

        # Resume from the latest checkpoint if requested
        latest = os.path.join(self.configs.save_dir, 'latest_checkpoint.pth')
        if self.configs.resume and os.path.exists(latest):
            print("Loading checkpoint...")
            start_epoch = self.load_checkpoint(latest)
            print(f"Resumed from epoch {start_epoch}")

        for epoch in range(start_epoch, num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")

            # Train
            train_loss = self.train_epoch(epoch)
            print(f"Train Loss: {train_loss:.6f}")

            # Validate
            valid_loss = self.validate()
            print(f"Valid Loss: {valid_loss:.6f}")

            # Adjust the learning rate
            self.scheduler.step(valid_loss)

            # Checkpointing
            is_best = valid_loss < self.best_loss
            if is_best:
                self.best_loss = valid_loss
            self.save_checkpoint(epoch, is_best)

            # Run the test set every 5 epochs
            if (epoch + 1) % 5 == 0:
                test_loss, _, _ = self.test()
                print(f"Test Loss: {test_loss:.6f}")

        # Final evaluation
        print("\nFinal Testing...")
        test_loss, outputs, targets = self.test()
        print(f"Final Test Loss: {test_loss:.6f}")

        return test_loss, outputs, targets
6. Prediction and Visualization
Implement inference and result visualization:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid


class Predictor:
    def __init__(self, configs, model):
        self.configs = configs
        self.model = model
        self.device = configs.device
        if self.model is not None:
            self.model.eval()

    def predict(self, input_seq):
        """Predict the future frames for a batch of input sequences."""
        with torch.no_grad():
            input_seq = input_seq.to(self.device)
            # Pad the sequence with zeros up to total_length; after
            # input_length steps the model feeds its own predictions back in,
            # so the padded frames are never actually read
            b, t, c, h, w = input_seq.shape
            pad = torch.zeros(b, self.configs.total_length - t, c, h, w, device=self.device)
            frames = torch.cat([input_seq, pad], dim=1)
            output_seq = self.model(frames, mask_true=None)
            # Keep only the predicted future frames
            return output_seq[:, t - 1:].cpu()

    def visualize_results(self, inputs, targets, predictions, save_path=None):
        """Visualize inputs, targets, and predictions for the first batch item."""
        inputs = inputs[0].squeeze()            # [T_in, H, W]
        targets = targets[0].squeeze()          # [T_out, H, W]
        predictions = predictions[0].squeeze()  # [T_out, H, W]

        # One row each for inputs, targets, and predictions
        n_cols = max(inputs.shape[0], targets.shape[0], predictions.shape[0])
        fig = plt.figure(figsize=(2 * n_cols, 6))
        grid = ImageGrid(fig, 111, nrows_ncols=(3, n_cols), axes_pad=0.1)
        for ax in grid:
            ax.axis('off')

        # Row 0: input frames
        for i in range(inputs.shape[0]):
            ax = grid[i]
            ax.imshow(inputs[i], cmap='viridis')
            ax.set_title(f'Input {i + 1}')

        # Row 1: target frames
        for i in range(targets.shape[0]):
            ax = grid[n_cols + i]
            ax.imshow(targets[i], cmap='viridis')
            ax.set_title(f'Target {i + 1}')

        # Row 2: predicted frames
        for i in range(predictions.shape[0]):
            ax = grid[2 * n_cols + i]
            ax.imshow(predictions[i], cmap='viridis')
            ax.set_title(f'Pred {i + 1}')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def save_predictions(self, predictions, save_dir):
        """Save the predictions as NPY files."""
        os.makedirs(save_dir, exist_ok=True)
        for i, pred_seq in enumerate(predictions):
            for j, frame in enumerate(pred_seq):
                frame_path = os.path.join(save_dir, f'batch_{i}_frame_{j}.npy')
                np.save(frame_path, frame.squeeze())

    def evaluate_metrics(self, targets, predictions):
        """Compute evaluation metrics for the predictions."""
        from sklearn.metrics import mean_squared_error, mean_absolute_error

        # Flatten the arrays
        targets_flat = targets.flatten()
        predictions_flat = predictions.flatten()

        # Pixel-wise errors
        mse = mean_squared_error(targets_flat, predictions_flat)
        mae = mean_absolute_error(targets_flat, predictions_flat)
        rmse = np.sqrt(mse)

        # PSNR
        max_val = np.max(targets_flat)
        psnr = 20 * np.log10(max_val / rmse) if rmse > 0 else float('inf')

        # SSIM, averaged over individual 2D frames (requires scikit-image)
        try:
            from skimage.metrics import structural_similarity as ssim_func
            frames_t = targets.reshape(-1, targets.shape[-2], targets.shape[-1])
            frames_p = predictions.reshape(-1, predictions.shape[-2], predictions.shape[-1])
            data_range = float(targets.max() - targets.min())
            ssim = np.mean([ssim_func(t, p, data_range=data_range)
                            for t, p in zip(frames_t, frames_p)])
        except ImportError:
            ssim = 0
            print("SSIM calculation requires scikit-image. Install with: pip install scikit-image")

        return {'MSE': mse, 'MAE': mae, 'RMSE': rmse, 'PSNR': psnr, 'SSIM': ssim}
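As a quick sanity check of the PSNR formula used in `evaluate_metrics`, a numpy-only worked example with a known noise level (the arrays here are synthetic, chosen so the RMSE is exactly 0.01):

```python
import numpy as np

# Reference frame and a "prediction" with a constant offset
target = np.zeros((500, 500), dtype=np.float64)
target[200:300, 200:300] = 1.0   # max value is 1.0
pred = target + 0.01             # constant error -> RMSE = 0.01

rmse = np.sqrt(np.mean((target - pred) ** 2))
psnr = 20 * np.log10(target.max() / rmse)
print(rmse, psnr)  # 0.01 and 40.0 dB
```

A 1% error against a unit-range signal gives 40 dB, which is a handy reference point when reading the metric tables later.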
7. Model Evaluation and Optimization
Implement model evaluation and hyperparameter optimization:
def hyperparameter_optimization(configs):
    """Hyperparameter search driven by NNI."""
    import nni

    # Fetch the next trial's hyperparameters from NNI
    optimized_params = nni.get_next_parameter()
    configs.lr = optimized_params.get('lr', configs.lr)
    configs.batch_size = optimized_params.get('batch_size', configs.batch_size)
    configs.num_hidden = optimized_params.get('num_hidden', configs.num_hidden)

    # Build the model and data loaders
    model = PredRNNv2(configs).to(configs.device)
    train_loader, valid_loader, test_loader = create_data_loaders(configs)

    # Train
    trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)
    test_loss, _, _ = trainer.train(configs.epoch)

    # Report the final result back to NNI
    nni.report_final_result(test_loss)
    return test_loss


def analyze_results(configs, outputs, targets):
    """Analyze the prediction results."""
    predictor = Predictor(configs, None)
    metrics = predictor.evaluate_metrics(targets, outputs)

    print("Evaluation Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")

    # Plot flattened predicted pixel values against the targets
    plt.figure(figsize=(10, 6))
    plt.plot(outputs.flatten(), label='Predictions', alpha=0.7)
    plt.plot(targets.flatten(), label='Targets', alpha=0.7)
    plt.xlabel('Sample Index')
    plt.ylabel('Value')
    plt.title('Predictions vs Targets')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(configs.save_dir, 'predictions_vs_targets.png'), dpi=300)
    plt.show()

    return metrics
8. Full Code
Finally, we tie all the components together in a single entry script:
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from models import PredRNNv2
from data_loader import NPYDataset, create_data_loaders
from trainer import Trainer
from predictor import Predictor
from utils import analyze_results


def str2bool(v):
    """argparse's type=bool treats any non-empty string as True; parse explicitly."""
    return str(v).lower() in ('true', '1', 'yes')


def parse_args():
    parser = argparse.ArgumentParser(description='PredRNNv2 for NPY dataset')

    # Data
    parser.add_argument('--data_root', type=str, default='./dataset', help='dataset root directory')
    parser.add_argument('--input_length', type=int, default=10, help='number of input frames')
    parser.add_argument('--total_length', type=int, default=20, help='total frames (input + predicted)')
    parser.add_argument('--img_width', type=int, default=500, help='image width')
    parser.add_argument('--img_height', type=int, default=500, help='image height')
    parser.add_argument('--img_channel', type=int, default=1, help='number of image channels')
    parser.add_argument('--preprocess_data', type=str2bool, default=True, help='standardize the data')
    parser.add_argument('--data_augmentation', type=str2bool, default=True, help='use data augmentation')

    # Model
    parser.add_argument('--num_hidden', type=lambda s: [int(x) for x in s.split(',')],
                        default=[64, 64, 64, 64], help='hidden units per layer, comma-separated')
    parser.add_argument('--filter_size', type=int, default=5, help='convolution kernel size')
    parser.add_argument('--stride', type=int, default=1, help='convolution stride')
    parser.add_argument('--patch_size', type=int, default=1, help='patch size')
    parser.add_argument('--layer_norm', type=str2bool, default=True, help='use layer normalization')
    parser.add_argument('--reverse_scheduled_sampling', type=int, default=0, help='reverse scheduled sampling')

    # Training
    parser.add_argument('--batch_size', type=int, default=4, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0, help='weight decay')
    parser.add_argument('--epoch', type=int, default=100, help='number of training epochs')
    parser.add_argument('--num_workers', type=int, default=4, help='data-loading workers')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='device')
    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='checkpoint directory')
    parser.add_argument('--resume', type=str2bool, default=False, help='resume training')

    # Misc
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test', 'predict'], help='run mode')
    parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint path')

    return parser.parse_args()


def main():
    configs = parse_args()
    os.makedirs(configs.save_dir, exist_ok=True)

    # Build the model
    model = PredRNNv2(configs).to(configs.device)
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters()):,}")

    if configs.mode == 'train':
        # Train, then analyze the final test predictions
        train_loader, valid_loader, test_loader = create_data_loaders(configs)
        trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)
        test_loss, outputs, targets = trainer.train(configs.epoch)
        analyze_results(configs, outputs, targets)

    elif configs.mode == 'test':
        # Load the checkpoint
        if configs.checkpoint_path:
            checkpoint = torch.load(configs.checkpoint_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {configs.checkpoint_path}")

        # Evaluate on the test set
        _, _, test_loader = create_data_loaders(configs)
        trainer = Trainer(configs, model, None, None, test_loader)
        test_loss, outputs, targets = trainer.test()
        print(f"Test Loss: {test_loss:.6f}")

        # Analyze and save the results
        metrics = analyze_results(configs, outputs, targets)
        np.save(os.path.join(configs.save_dir, 'test_outputs.npy'), outputs)
        np.save(os.path.join(configs.save_dir, 'test_targets.npy'), targets)

    elif configs.mode == 'predict':
        # Load the checkpoint
        if configs.checkpoint_path:
            checkpoint = torch.load(configs.checkpoint_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {configs.checkpoint_path}")

        predictor = Predictor(configs, model)

        # This assumes a separate 'predict' split under data_root
        predict_dataset = NPYDataset(
            data_root=configs.data_root,
            mode='predict',
            input_frames=configs.input_length,
            output_frames=configs.total_length - configs.input_length,
            future_frames=configs.total_length - configs.input_length,
            transform=None,
            preprocess=configs.preprocess_data)
        predict_loader = DataLoader(predict_dataset, batch_size=configs.batch_size,
                                    shuffle=False, num_workers=configs.num_workers,
                                    pin_memory=True)

        all_predictions = []
        all_inputs = []
        with torch.no_grad():
            for inputs, _ in predict_loader:
                predictions = predictor.predict(inputs.to(configs.device))
                all_predictions.append(predictions.numpy())
                all_inputs.append(inputs.cpu().numpy())
        all_predictions = np.concatenate(all_predictions, axis=0)
        all_inputs = np.concatenate(all_inputs, axis=0)

        # Save the predictions as NPY files
        output_dir = os.path.join(configs.save_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)
        for i, (input_seq, pred_seq) in enumerate(zip(all_inputs, all_predictions)):
            # Input sequence
            for j, frame in enumerate(input_seq):
                np.save(os.path.join(output_dir, f'sequence_{i:03d}_input_{j:03d}.npy'), frame.squeeze())
            # Predicted sequence
            for j, frame in enumerate(pred_seq):
                np.save(os.path.join(output_dir, f'sequence_{i:03d}_pred_{j:03d}.npy'), frame.squeeze())
        print(f"Predictions saved to {output_dir}")

        # Visualize one sample; with no ground truth in predict mode,
        # the prediction is shown in the target row as well
        if len(all_inputs) > 0:
            sample_idx = 0
            predictor.visualize_results(
                all_inputs[sample_idx:sample_idx + 1],
                all_predictions[sample_idx:sample_idx + 1],
                all_predictions[sample_idx:sample_idx + 1],
                save_path=os.path.join(output_dir, 'sample_prediction.png'))


if __name__ == '__main__':
    main()
9. Common Problems and Solutions
9.1 Out-of-Memory Errors
Working with large 500×500 frames can exhaust GPU memory. Options:
- Tile the data: split each large frame into smaller patches and process them separately
- Reduce the batch size: process fewer samples per step
- Use mixed-precision training: half-precision floats roughly halve activation memory
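As a sketch of the tiling option, a 500×500 frame can be split into patches and reassembled losslessly with numpy reshapes alone (the helper names are hypothetical, and this assumes the frame size is divisible by the patch size; pad the frame first if it is not):

```python
import numpy as np


def split_into_patches(frame, patch):
    """Tile a [H, W] frame into [N, patch, patch] blocks, row by row."""
    h, w = frame.shape
    return (frame.reshape(h // patch, patch, w // patch, patch)
                 .transpose(0, 2, 1, 3)
                 .reshape(-1, patch, patch))


def merge_patches(patches, h, w):
    """Inverse of split_into_patches."""
    patch = patches.shape[-1]
    return (patches.reshape(h // patch, w // patch, patch, patch)
                   .transpose(0, 2, 1, 3)
                   .reshape(h, w))


frame = np.random.rand(500, 500).astype(np.float32)
patches = split_into_patches(frame, 100)   # 25 patches of 100x100
restored = merge_patches(patches, 500, 500)
print(patches.shape)  # (25, 100, 100)
```

Each 100×100 patch then trains as an independent (much smaller) sample; note that this trades memory for losing cross-patch context at the tile borders.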
# Mixed-precision training example
from torch.cuda.amp import autocast, GradScaler


def train_epoch_with_amp(self, epoch):
    """Train one epoch with automatic mixed precision."""
    self.model.train()
    total_loss = 0
    scaler = GradScaler()

    progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
    for batch_idx, (inputs, targets) in enumerate(progress_bar):
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)

        # Forward pass under autocast
        with autocast():
            outputs = self.model(inputs, mask_true=None)
            loss = self.criterion(outputs, targets)

        # Scale the loss, backprop, and step the optimizer
        self.optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(self.optimizer)
        scaler.update()

        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

    avg_loss = total_loss / len(self.train_loader)
    self.train_losses.append(avg_loss)
    return avg_loss
9.2 Training Instability
PredRNNv2 training can be unstable. Things to try:
- Gradient clipping: prevents exploding gradients
- Learning-rate scheduling: adjusts the learning rate dynamically
- Weight initialization: use a suitable initialization scheme
# Gradient clipping example
def train_epoch_with_gradient_clipping(self, epoch, clip_value=1.0):
    """Train one epoch with gradient clipping."""
    self.model.train()
    total_loss = 0

    progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
    for batch_idx, (inputs, targets) in enumerate(progress_bar):
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)

        self.optimizer.zero_grad()
        outputs = self.model(inputs, mask_true=None)
        loss = self.criterion(outputs, targets)
        loss.backward()

        # Clip the global gradient norm
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_value)

        self.optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

    avg_loss = total_loss / len(self.train_loader)
    self.train_losses.append(avg_loss)
    return avg_loss
9.3 Overfitting
If the model performs well on the training set but poorly on the validation set, it is likely overfitting:
- Data augmentation: increase data diversity
- Regularization: use Dropout or weight decay
- Early stopping: stop training when the validation loss stops improving
# Early stopping implementation
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return self.early_stop


# Using early stopping in the training loop
early_stopping = EarlyStopping(patience=10)
for epoch in range(num_epochs):
    # train and validate ...
    if early_stopping(valid_loss):
        print("Early stopping triggered")
        break
10. Summary and Outlook
This article walked through reproducing the PredRNNv2 model from OpenSTL and training and predicting with a custom NPY-format dataset, covering the full pipeline from environment setup, data preprocessing, and model construction to training and evaluation.
10.1 Key Results
- Complete data pipeline: loading, preprocessing, and augmentation for NPY-format data
- PredRNNv2 reproduction: the model's core components and full architecture
- Training framework: a complete train/validate/test loop with loss function, optimizer, and learning-rate scheduling
- Prediction and visualization: inference and result visualization for analyzing model performance
- Troubleshooting: workarounds for common problems (out-of-memory errors, training instability, overfitting)
10.2 Future Directions
- Model upgrades: try more recent video-prediction models such as SimVP or PhyDNet
- Multimodal fusion: incorporate other sensor data (e.g. weather or geographic information) to improve accuracy
- Real-time prediction: optimize inference speed for real-time use
- Uncertainty quantification: estimate the uncertainty of the predictions
- Deployment: move the model to production and support large-scale data processing
With the walkthrough and code above, you should be able to reproduce PredRNNv2 and train and predict on your own dataset. We hope this serves as a useful reference for research and applications in video prediction.