OpenSTL PredRNNv2 模型復現與自定義數據集訓練

OpenSTL PredRNNv2 模型復現與自定義數據集訓練

概述

本文將詳細介紹如何復現 OpenSTL 中的 PredRNNv2 模型,并使用自定義的 NPY 格式數據集進行訓練和預測。我們將從環境配置開始,逐步講解數據預處理、模型構建、訓練過程和預測實現,最終實現輸入多張連續時間序列的 500×500 圖像并輸出相應數量預測圖像的目標。

目錄

  1. 環境配置與依賴安裝
  2. 數據集準備與預處理
  3. PredRNNv2 模型原理與架構
  4. 數據加載器實現
  5. 模型訓練流程
  6. 預測與結果可視化
  7. 模型評估與優化
  8. 完整代碼實現
  9. 常見問題與解決方案
  10. 總結與展望

1. 環境配置與依賴安裝

首先,我們需要創建一個合適的 Python 環境并安裝所有必要的依賴包。

# 創建conda環境
conda create -n openstl python=3.8
conda activate openstl# 安裝PyTorch (根據CUDA版本選擇)
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116# 安裝其他依賴
pip install numpy==1.21.6
pip install opencv-python==4.7.0.72
pip install matplotlib==3.5.3
pip install tensorboard==2.11.2
pip install scikit-learn==1.0.2
pip install tqdm==4.64.1
pip install nni==2.8
pip install timm==0.6.12
pip install einops==0.6.0

接下來,我們需要克隆 OpenSTL 倉庫并安裝相關依賴:

git clone https://github.com/chengtan9907/OpenSTL.git
cd OpenSTL
git checkout OpenSTL-Lightning
pip install -e .

2. 數據集準備與預處理

我們的數據集是 NPY 格式的文件,每張圖像尺寸為 500×500,且文件之間在時間上是連續的。首先,我們需要了解數據集的目錄結構:

dataset/
├── train/
│   ├── sequence_001/
│   │   ├── frame_001.npy
│   │   ├── frame_002.npy
│   │   └── ...
│   ├── sequence_002/
│   └── ...
├── valid/
└── test/

2.1 數據預處理類實現

我們需要創建一個數據預處理類,將 NPY 文件轉換為模型可用的格式:

import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
import cv2class NPYDataset(Dataset):def __init__(self, data_root, mode='train', input_frames=10, output_frames=10, future_frames=10, transform=None, preprocess=True):"""初始化NPY數據集參數:data_root: 數據根目錄mode: 模式 ('train', 'valid', 'test')input_frames: 輸入幀數output_frames: 輸出幀數future_frames: 未來幀數 (預測幀數)transform: 數據轉換函數preprocess: 是否進行預處理"""self.data_root = os.path.join(data_root, mode)self.mode = modeself.input_frames = input_framesself.output_frames = output_framesself.future_frames = future_framesself.transform = transformself.preprocess = preprocess# 獲取所有序列self.sequences = []for seq_name in os.listdir(self.data_root):seq_path = os.path.join(self.data_root, seq_name)if os.path.isdir(seq_path):frames = sorted([f for f in os.listdir(seq_path) if f.endswith('.npy')])if len(frames) >= input_frames + future_frames:self.sequences.append((seq_path, frames))# 數據標準化器self.scaler = Noneif preprocess:self._init_scaler()def _init_scaler(self):"""初始化數據標準化器"""print(f"Initializing scaler for {self.mode} mode...")all_data = []for seq_path, frames in self.sequences:for frame_name in frames[:min(100, len(frames))]:  # 使用前100幀計算統計量frame_path = os.path.join(seq_path, frame_name)data = np.load(frame_path)all_data.append(data.flatten())all_data = np.concatenate(all_data).reshape(-1, 1)self.scaler = StandardScaler()self.scaler.fit(all_data)print("Scaler initialized.")def _preprocess_data(self, data):"""預處理數據"""if self.preprocess and self.scaler is not None:original_shape = data.shapedata = data.flatten().reshape(-1, 1)data = self.scaler.transform(data)data = data.reshape(original_shape)return datadef _postprocess_data(self, data):"""后處理數據"""if self.preprocess and self.scaler is not None:original_shape = data.shapedata = data.flatten().reshape(-1, 1)data = self.scaler.inverse_transform(data)data = data.reshape(original_shape)return datadef __len__(self):return len(self.sequences)def __getitem__(self, idx):seq_path, frames = self.sequences[idx]# 隨機選擇起始幀total_frames = len(frames)max_start = total_frames - self.input_frames - self.future_framesstart_idx = np.random.randint(0, max_start + 1) if self.mode == 'train' else 0# 加載輸入幀input_frames = []for i in range(start_idx, start_idx + self.input_frames):frame_path = os.path.join(seq_path, frames[i])frame_data = np.load(frame_path)frame_data = self._preprocess_data(frame_data)input_frames.append(frame_data)# 加載目標幀target_frames = []for i in range(start_idx + self.input_frames, start_idx + self.input_frames + self.future_frames):frame_path = os.path.join(seq_path, frames[i])frame_data = np.load(frame_path)frame_data = self._preprocess_data(frame_data)target_frames.append(frame_data)# 轉換為numpy數組input_seq = np.stack(input_frames, axis=0)target_seq = np.stack(target_frames, axis=0)# 添加通道維度input_seq = np.expand_dims(input_seq, axis=1)  # [T, 1, H, W]target_seq = np.expand_dims(target_seq, axis=1)  # [T, 1, H, W]# 轉換為張量input_seq = torch.FloatTensor(input_seq)target_seq = torch.FloatTensor(target_seq)if self.transform:input_seq = self.transform(input_seq)target_seq = self.transform(target_seq)return input_seq, target_seq# 數據增強轉換
class RandomRotate:def __init__(self, angles=[0, 90, 180, 270]):self.angles = anglesdef __call__(self, x):angle = np.random.choice(self.angles)if angle == 0:return x# 旋轉每個幀rotated = []for i in range(x.shape[0]):frame = x[i].numpy()# 對于3D數據,我們需要分別旋轉每個通道if len(frame.shape) == 3:frame_rotated = np.stack([cv2.rotate(frame[c], cv2.ROTATE_90_CLOCKWISE) for c in range(frame.shape[0])], axis=0)else:frame_rotated = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)rotated.append(frame_rotated)return torch.FloatTensor(np.stack(rotated, axis=0))class RandomFlip:def __init__(self, p=0.5):self.p = pdef __call__(self, x):if np.random.random() < self.p:# 水平翻轉return x.flip(-1)return x

3. PredRNNv2 模型原理與架構

PredRNNv2 是一種改進的循環神經網絡,專門用于視頻預測任務。它通過引入時空記憶(STM)單元來更好地捕捉時空動態。

3.1 核心組件

import torch
import torch.nn as nn
from einops import rearrangeclass SpatioTemporalLSTMCell(nn.Module):def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):super(SpatioTemporalLSTMCell, self).__init__()self.num_hidden = num_hiddenself.padding = filter_size // 2self._forget_bias = 1.0# 卷積層self.conv_x = nn.Sequential(nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,stride=stride, padding=self.padding, bias=False),nn.LayerNorm([num_hidden * 7, height, width]))self.conv_h = nn.Sequential(nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,stride=stride, padding=self.padding, bias=False),nn.LayerNorm([num_hidden * 4, height, width]))self.conv_m = nn.Sequential(nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,stride=stride, padding=self.padding, bias=False),nn.LayerNorm([num_hidden * 3, height, width]))self.conv_o = nn.Sequential(nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,stride=stride, padding=self.padding, bias=False),nn.LayerNorm([num_hidden, height, width]))self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,stride=1, padding=0, bias=False)def forward(self, x_t, h_t, c_t, m_t):# 計算門控信號x_concat = self.conv_x(x_t)h_concat = self.conv_h(h_t)m_concat = self.conv_m(m_t)i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1)i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)i_t = torch.sigmoid(i_x + i_h)f_t = torch.sigmoid(f_x + f_h + self._forget_bias)g_t = torch.tanh(g_x + g_h)c_new = f_t * c_t + i_t * g_ti_t_prime = torch.sigmoid(i_x_prime + i_m)f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)g_t_prime = torch.tanh(g_x_prime + g_m)m_new = f_t_prime * m_t + i_t_prime * g_t_primemem = torch.cat((c_new, m_new), 1)o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))h_new = o_t * torch.tanh(self.conv_last(mem))return h_new, c_new, m_newclass PredRNNv2(nn.Module):def __init__(self, configs):super(PredRNNv2, self).__init__()self.configs = configsself.frame_channel = configs.patch_size * configs.patch_size * configs.img_channelself.num_layers = len(configs.num_hidden)self.num_hidden = configs.num_hiddenself.device = configs.device# 構建網絡cell_list = []height = configs.img_height // configs.patch_sizewidth = configs.img_width // configs.patch_sizefor i in range(self.num_layers):in_channel = self.frame_channel if i == 0 else self.num_hidden[i-1]cell_list.append(SpatioTemporalLSTMCell(in_channel, self.num_hidden[i], height, width,configs.filter_size, configs.stride, configs.layer_norm))self.cell_list = nn.ModuleList(cell_list)# 輸出層self.conv_last = nn.Conv2d(self.num_hidden[self.num_layers-1], self.frame_channel,kernel_size=1, stride=1, padding=0, bias=False)def forward(self, frames_tensor, mask_true):#  frames_tensor: [batch, length, channel, height, width]batch = frames_tensor.shape[0]height = frames_tensor.shape[3]width = frames_tensor.shape[4]# 初始化隱藏狀態和記憶狀態next_frames = []h_t = []c_t = []m_t = []for i in range(self.num_layers):zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.device)h_t.append(zeros)c_t.append(zeros)m_t.append(zeros)# 記憶狀態memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.device)# 序列長度seq_length = self.configs.input_length + self.configs.total_lengthfor t in range(seq_length - 1):# 反向調度采樣if self.configs.reverse_scheduled_sampling == 1:if t == 0:net = frames_tensor[:, t]else:# 從真實數據或預測數據中采樣net = frames_tensor[:, t] if mask_true[:, t] == 1 else x_genelse:# 常規訓練if t < self.configs.input_length:net = frames_tensor[:, t]else:# 從真實數據或預測數據中采樣net = frames_tensor[:, t] if mask_true[:, t] == 1 else x_gen# 第一層h_t[0], c_t[0], m_t[0] = self.cell_list[0](net, h_t[0], c_t[0], m_t[0])# 后續層for i in range(1, self.num_layers):h_t[i], c_t[i], m_t[i] = self.cell_list[i](h_t[i-1], h_t[i], c_t[i], m_t[i])# 生成預測x_gen = self.conv_last(h_t[self.num_layers-1])next_frames.append(x_gen)# [length, batch, channel, height, width] -> [batch, length, channel, height, width]next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 2, 3, 4)return next_frames

4. 數據加載器實現

接下來,我們需要實現數據加載器,將數據集轉換為模型可用的格式:

def create_data_loaders(configs):"""創建訓練、驗證和測試數據加載器"""# 數據轉換if configs.data_augmentation:train_transform = nn.Sequential(RandomRotate(),RandomFlip())else:train_transform = None# 創建數據集train_dataset = NPYDataset(data_root=configs.data_root,mode='train',input_frames=configs.input_length,output_frames=configs.total_length - configs.input_length,future_frames=configs.total_length - configs.input_length,transform=train_transform,preprocess=configs.preprocess_data)valid_dataset = NPYDataset(data_root=configs.data_root,mode='valid',input_frames=configs.input_length,output_frames=configs.total_length - configs.input_length,future_frames=configs.total_length - configs.input_length,transform=None,preprocess=configs.preprocess_data)test_dataset = NPYDataset(data_root=configs.data_root,mode='test',input_frames=configs.input_length,output_frames=configs.total_length - configs.input_length,future_frames=configs.total_length - configs.input_length,transform=None,preprocess=configs.preprocess_data)# 創建數據加載器train_loader = DataLoader(train_dataset,batch_size=configs.batch_size,shuffle=True,num_workers=configs.num_workers,pin_memory=True)valid_loader = DataLoader(valid_dataset,batch_size=configs.batch_size,shuffle=False,num_workers=configs.num_workers,pin_memory=True)test_loader = DataLoader(test_dataset,batch_size=configs.batch_size,shuffle=False,num_workers=configs.num_workers,pin_memory=True)return train_loader, valid_loader, test_loader

5. 模型訓練流程

現在,我們實現完整的訓練流程,包括損失函數、優化器和學習率調度器:

import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
from tqdm import tqdmclass Trainer:def __init__(self, configs, model, train_loader, valid_loader, test_loader):self.configs = configsself.model = modelself.train_loader = train_loaderself.valid_loader = valid_loaderself.test_loader = test_loaderself.device = configs.device# 損失函數self.criterion = nn.MSELoss()# 優化器self.optimizer = optim.Adam(model.parameters(),lr=configs.lr,weight_decay=configs.weight_decay)# 學習率調度器self.scheduler = ReduceLROnPlateau(self.optimizer,mode='min',factor=0.5,patience=5,verbose=True)# 記錄訓練歷史self.train_losses = []self.valid_losses = []self.best_loss = float('inf')# 創建檢查點目錄os.makedirs(configs.save_dir, exist_ok=True)def train_epoch(self, epoch):"""訓練一個epoch"""self.model.train()total_loss = 0progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')for batch_idx, (inputs, targets) in enumerate(progress_bar):inputs = inputs.to(self.device)targets = targets.to(self.device)# 前向傳播self.optimizer.zero_grad()outputs = self.model(inputs, mask_true=None)# 計算損失loss = self.criterion(outputs, targets)# 反向傳播loss.backward()self.optimizer.step()total_loss += loss.item()progress_bar.set_postfix({'loss': loss.item()})avg_loss = total_loss / len(self.train_loader)self.train_losses.append(avg_loss)return avg_lossdef validate(self):"""驗證模型"""self.model.eval()total_loss = 0with torch.no_grad():for inputs, targets in self.valid_loader:inputs = inputs.to(self.device)targets = targets.to(self.device)outputs = self.model(inputs, mask_true=None)loss = self.criterion(outputs, targets)total_loss += loss.item()avg_loss = total_loss / len(self.valid_loader)self.valid_losses.append(avg_loss)return avg_lossdef test(self):"""測試模型"""self.model.eval()total_loss = 0all_outputs = []all_targets = []with torch.no_grad():for inputs, targets in self.test_loader:inputs = inputs.to(self.device)targets = targets.to(self.device)outputs = self.model(inputs, mask_true=None)loss = self.criterion(outputs, targets)total_loss += loss.item()# 保存結果用于后續分析all_outputs.append(outputs.cpu().numpy())all_targets.append(targets.cpu().numpy())avg_loss = total_loss / len(self.test_loader)return avg_loss, np.concatenate(all_outputs, axis=0), np.concatenate(all_targets, axis=0)def save_checkpoint(self, epoch, is_best=False):"""保存檢查點"""checkpoint = {'epoch': epoch,'model_state_dict': self.model.state_dict(),'optimizer_state_dict': self.optimizer.state_dict(),'scheduler_state_dict': self.scheduler.state_dict(),'train_losses': self.train_losses,'valid_losses': self.valid_losses,'best_loss': self.best_loss}# 保存最新檢查點torch.save(checkpoint, os.path.join(self.configs.save_dir, 'latest_checkpoint.pth'))# 如果是最佳模型,保存為最佳檢查點if is_best:torch.save(checkpoint, os.path.join(self.configs.save_dir, 'best_checkpoint.pth'))def load_checkpoint(self, checkpoint_path):"""加載檢查點"""checkpoint = torch.load(checkpoint_path)self.model.load_state_dict(checkpoint['model_state_dict'])self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])self.train_losses = checkpoint['train_losses']self.valid_losses = checkpoint['valid_losses']self.best_loss = checkpoint['best_loss']return checkpoint['epoch']def train(self, num_epochs):"""完整訓練過程"""start_epoch = 0# 如果存在檢查點,加載檢查點if self.configs.resume and os.path.exists(os.path.join(self.configs.save_dir, 'latest_checkpoint.pth')):print("Loading checkpoint...")start_epoch = self.load_checkpoint(os.path.join(self.configs.save_dir, 'latest_checkpoint.pth'))print(f"Resumed from epoch {start_epoch}")for epoch in range(start_epoch, num_epochs):print(f"\nEpoch {epoch+1}/{num_epochs}")# 訓練train_loss = self.train_epoch(epoch)print(f"Train Loss: {train_loss:.6f}")# 驗證valid_loss = self.validate()print(f"Valid Loss: {valid_loss:.6f}")# 更新學習率self.scheduler.step(valid_loss)# 保存檢查點is_best = valid_loss < self.best_lossif is_best:self.best_loss = valid_lossself.save_checkpoint(epoch, is_best)# 每5個epoch測試一次if (epoch + 1) % 5 == 0:test_loss, _, _ = self.test()print(f"Test Loss: {test_loss:.6f}")# 最終測試print("\nFinal Testing...")test_loss, outputs, targets = self.test()print(f"Final Test Loss: {test_loss:.6f}")return test_loss, outputs, targets

6. 預測與結果可視化

實現預測功能和結果可視化:

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGridclass Predictor:def __init__(self, configs, model):self.configs = configsself.model = modelself.device = configs.deviceself.model.eval()def predict(self, input_seq):"""預測未來幀"""with torch.no_grad():input_seq = input_seq.to(self.device)output_seq = self.model(input_seq, mask_true=None)return output_seq.cpu()def visualize_results(self, inputs, targets, predictions, save_path=None):"""可視化輸入、目標和預測結果"""# 選擇第一個批次進行可視化inputs = inputs[0].squeeze()  # [T, H, W]targets = targets[0].squeeze()  # [T, H, W]predictions = predictions[0].squeeze()  # [T, H, W]# 創建子圖total_frames = inputs.shape[0] + targets.shape[0]fig = plt.figure(figsize=(20, 10))grid = ImageGrid(fig, 111, nrows_ncols=(3, total_frames), axes_pad=0.1)# 繪制輸入幀for i in range(inputs.shape[0]):ax = grid[i]ax.imshow(inputs[i], cmap='viridis')ax.set_title(f'Input {i+1}')ax.axis('off')# 繪制目標幀for i in range(targets.shape[0]):ax = grid[inputs.shape[0] + i]ax.imshow(targets[i], cmap='viridis')ax.set_title(f'Target {i+1}')ax.axis('off')# 繪制預測幀for i in range(predictions.shape[0]):ax = grid[inputs.shape[0] + targets.shape[0] + i]ax.imshow(predictions[i], cmap='viridis')ax.set_title(f'Pred {i+1}')ax.axis('off')plt.tight_layout()if save_path:plt.savefig(save_path, dpi=300, bbox_inches='tight')plt.show()def save_predictions(self, predictions, save_dir):"""保存預測結果為NPY文件"""os.makedirs(save_dir, exist_ok=True)for i, pred_seq in enumerate(predictions):for j, frame in enumerate(pred_seq):frame_path = os.path.join(save_dir, f'batch_{i}_frame_{j}.npy')np.save(frame_path, frame.squeeze())def evaluate_metrics(self, targets, predictions):"""評估預測性能"""from sklearn.metrics import mean_squared_error, mean_absolute_error# 展平數據targets_flat = targets.flatten()predictions_flat = predictions.flatten()# 計算指標mse = mean_squared_error(targets_flat, predictions_flat)mae = mean_absolute_error(targets_flat, predictions_flat)rmse = np.sqrt(mse)# 計算PSNRmax_val = np.max(targets_flat)psnr = 20 * np.log10(max_val / rmse) if rmse > 0 else float('inf')# 計算SSIM (需要安裝skimage)try:from skimage.metrics import structural_similarity as ssim_funcssim = ssim_func(targets_flat.reshape(targets.shape), predictions_flat.reshape(targets.shape),data_range=max_val)except ImportError:ssim = 0print("SSIM calculation requires skimage. Install with: pip install scikit-image")return {'MSE': mse,'MAE': mae,'RMSE': rmse,'PSNR': psnr,'SSIM': ssim}

7. 模型評估與優化

實現模型評估和超參數優化功能:

def hyperparameter_optimization(configs):"""超參數優化"""import nni# 獲取NNI超參數optimized_params = nni.get_next_parameter()configs.lr = optimized_params.get('lr', configs.lr)configs.batch_size = optimized_params.get('batch_size', configs.batch_size)configs.num_hidden = optimized_params.get('num_hidden', configs.num_hidden)# 創建模型和數據加載器model = PredRNNv2(configs).to(configs.device)train_loader, valid_loader, test_loader = create_data_loaders(configs)# 訓練模型trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)test_loss, _, _ = trainer.train(configs.epoch)# 報告最終結果nni.report_final_result(test_loss)return test_lossdef analyze_results(configs, outputs, targets):"""分析預測結果"""predictor = Predictor(configs, None)metrics = predictor.evaluate_metrics(targets, outputs)print("Evaluation Metrics:")for metric, value in metrics.items():print(f"{metric}: {value:.4f}")# 繪制損失曲線plt.figure(figsize=(10, 6))plt.plot(range(len(outputs)), outputs.flatten(), label='Predictions', alpha=0.7)plt.plot(range(len(targets)), targets.flatten(), label='Targets', alpha=0.7)plt.xlabel('Sample Index')plt.ylabel('Value')plt.title('Predictions vs Targets')plt.legend()plt.grid(True)plt.savefig(os.path.join(configs.save_dir, 'predictions_vs_targets.png'), dpi=300)plt.show()return metrics

8. 完整代碼實現

現在,我們將所有組件整合到一個完整的腳本中:

import argparse
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from models import PredRNNv2
from data_loader import NPYDataset, create_data_loaders
from trainer import Trainer
from predictor import Predictor
from utils import analyze_resultsdef parse_args():parser = argparse.ArgumentParser(description='PredRNNv2 for NPY dataset')# 數據參數parser.add_argument('--data_root', type=str, default='./dataset', help='數據集根目錄')parser.add_argument('--input_length', type=int, default=10, help='輸入幀數')parser.add_argument('--total_length', type=int, default=20, help='總幀數(輸入+預測)')parser.add_argument('--img_width', type=int, default=500, help='圖像寬度')parser.add_argument('--img_height', type=int, default=500, help='圖像高度')parser.add_argument('--img_channel', type=int, default=1, help='圖像通道數')parser.add_argument('--preprocess_data', type=bool, default=True, help='是否預處理數據')parser.add_argument('--data_augmentation', type=bool, default=True, help='是否使用數據增強')# 模型參數parser.add_argument('--num_hidden', type=list, default=[64, 64, 64, 64], help='每層隱藏單元數')parser.add_argument('--filter_size', type=int, default=5, help='濾波器大小')parser.add_argument('--stride', type=int, default=1, help='步長')parser.add_argument('--patch_size', type=int, default=1, help='補丁大小')parser.add_argument('--layer_norm', type=bool, default=True, help='是否使用層歸一化')parser.add_argument('--reverse_scheduled_sampling', type=int, default=0, help='反向調度采樣')# 訓練參數parser.add_argument('--batch_size', type=int, default=4, help='批次大小')parser.add_argument('--lr', type=float, default=1e-3, help='學習率')parser.add_argument('--weight_decay', type=float, default=0, help='權重衰減')parser.add_argument('--epoch', type=int, default=100, help='訓練輪數')parser.add_argument('--num_workers', type=int, default=4, help='數據加載工作線程數')parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='設備')parser.add_argument('--save_dir', type=str, default='./checkpoints', help='保存目錄')parser.add_argument('--resume', type=bool, default=False, help='是否恢復訓練')# 其他參數parser.add_argument('--mode', type=str, default='train', choices=['train', 'test', 'predict'], help='運行模式')parser.add_argument('--checkpoint_path', type=str, default='', help='檢查點路徑')return parser.parse_args()def main():# 解析參數configs = parse_args()# 創建保存目錄os.makedirs(configs.save_dir, exist_ok=True)# 創建模型model = PredRNNv2(configs).to(configs.device)print(f"模型參數量: {sum(p.numel() for p in model.parameters()):,}")if configs.mode == 'train':# 創建數據加載器train_loader, valid_loader, test_loader = create_data_loaders(configs)# 創建訓練器并開始訓練trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)test_loss, outputs, targets = trainer.train(configs.epoch)# 分析結果analyze_results(configs, outputs, targets)elif configs.mode == 'test':# 加載檢查點if configs.checkpoint_path:checkpoint = torch.load(configs.checkpoint_path)model.load_state_dict(checkpoint['model_state_dict'])print(f"Loaded checkpoint from {configs.checkpoint_path}")# 創建數據加載器_, _, test_loader = create_data_loaders(configs)# 測試模型trainer = Trainer(configs, model, None, None, test_loader)test_loss, outputs, targets = trainer.test()print(f"Test Loss: {test_loss:.6f}")# 分析結果metrics = analyze_results(configs, outputs, targets)# 保存結果np.save(os.path.join(configs.save_dir, 'test_outputs.npy'), outputs)np.save(os.path.join(configs.save_dir, 'test_targets.npy'), targets)elif configs.mode == 'predict':# 加載檢查點if configs.checkpoint_path:checkpoint = torch.load(configs.checkpoint_path)model.load_state_dict(checkpoint['model_state_dict'])print(f"Loaded checkpoint from {configs.checkpoint_path}")# 創建預測器predictor = Predictor(configs, model)# 加載要預測的數據# 這里假設有一個單獨的預測數據集predict_dataset = NPYDataset(data_root=configs.data_root,mode='predict',input_frames=configs.input_length,output_frames=configs.total_length - configs.input_length,future_frames=configs.total_length - configs.input_length,transform=None,preprocess=configs.preprocess_data)predict_loader = DataLoader(predict_dataset,batch_size=configs.batch_size,shuffle=False,num_workers=configs.num_workers,pin_memory=True)all_predictions = []all_inputs = []with torch.no_grad():for inputs, _ in predict_loader:inputs = inputs.to(configs.device)predictions = predictor.predict(inputs)all_predictions.append(predictions.numpy())all_inputs.append(inputs.cpu().numpy())all_predictions = np.concatenate(all_predictions, axis=0)all_inputs = np.concatenate(all_inputs, axis=0)# 保存預測結果output_dir = os.path.join(configs.save_dir, 'predictions')os.makedirs(output_dir, exist_ok=True)for i, (input_seq, pred_seq) in enumerate(zip(all_inputs, all_predictions)):# 保存輸入序列for j, frame in enumerate(input_seq):frame_path = os.path.join(output_dir, f'sequence_{i:03d}_input_{j:03d}.npy')np.save(frame_path, frame.squeeze())# 保存預測序列for j, frame in enumerate(pred_seq):frame_path = os.path.join(output_dir, f'sequence_{i:03d}_pred_{j:03d}.npy')np.save(frame_path, frame.squeeze())print(f"Predictions saved to {output_dir}")# 可視化一些結果if len(all_inputs) > 0:sample_idx = 0predictor.visualize_results(all_inputs[sample_idx:sample_idx+1],all_predictions[sample_idx:sample_idx+1],all_predictions[sample_idx:sample_idx+1],save_path=os.path.join(output_dir, 'sample_prediction.png'))if __name__ == '__main__':main()

9. 常見問題與解決方案

9.1 內存不足問題

當處理 500×500 的大尺寸圖像時,可能會遇到內存不足的問題。解決方案:

  1. 使用數據分塊:將大圖像分割成小塊進行處理
  2. 降低批次大小:減少每次處理的樣本數量
  3. 使用混合精度訓練:使用半精度浮點數減少內存占用
# 混合精度訓練示例
from torch.cuda.amp import autocast, GradScalerdef train_epoch_with_amp(self, epoch):"""使用混合精度訓練一個epoch"""self.model.train()total_loss = 0scaler = GradScaler()progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')for batch_idx, (inputs, targets) in enumerate(progress_bar):inputs = inputs.to(self.device)targets = targets.to(self.device)# 使用自動混合精度with autocast():outputs = self.model(inputs, mask_true=None)loss = self.criterion(outputs, targets)# 縮放損失并反向傳播self.optimizer.zero_grad()scaler.scale(loss).backward()scaler.step(self.optimizer)scaler.update()total_loss += loss.item()progress_bar.set_postfix({'loss': loss.item()})avg_loss = total_loss / len(self.train_loader)self.train_losses.append(avg_loss)return avg_loss

9.2 訓練不穩定問題

PredRNNv2 模型訓練可能會不穩定,可以嘗試以下方法:

  1. 梯度裁剪:防止梯度爆炸
  2. 學習率調度:動態調整學習率
  3. 權重初始化:使用合適的初始化方法
# 梯度裁剪示例
def train_epoch_with_gradient_clipping(self, epoch, clip_value=1.0):"""帶梯度裁剪的訓練"""self.model.train()total_loss = 0progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')for batch_idx, (inputs, targets) in enumerate(progress_bar):inputs = inputs.to(self.device)targets = targets.to(self.device)self.optimizer.zero_grad()outputs = self.model(inputs, mask_true=None)loss = self.criterion(outputs, targets)loss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_value)self.optimizer.step()total_loss += loss.item()progress_bar.set_postfix({'loss': loss.item()})avg_loss = total_loss / len(self.train_loader)self.train_losses.append(avg_loss)return avg_loss

9.3 過擬合問題

當模型在訓練集上表現良好但在驗證集上表現不佳時,可能存在過擬合問題:

  1. 數據增強:增加數據多樣性
  2. 正則化:使用 Dropout 或權重衰減
  3. 早停:在驗證損失不再改善時停止訓練
# 早停實現
class EarlyStopping:def __init__(self, patience=10, min_delta=0):self.patience = patienceself.min_delta = min_deltaself.counter = 0self.best_loss = Noneself.early_stop = Falsedef __call__(self, val_loss):if self.best_loss is None:self.best_loss = val_losselif val_loss > self.best_loss - self.min_delta:self.counter += 1if self.counter >= self.patience:self.early_stop = Trueelse:self.best_loss = val_lossself.counter = 0return self.early_stop# 在訓練循環中使用早停
early_stopping = EarlyStopping(patience=10)for epoch in range(num_epochs):# 訓練和驗證...if early_stopping(valid_loss):print("Early stopping triggered")break

10. 總結與展望

本文詳細介紹了如何復現 OpenSTL 中的 PredRNNv2 模型,并使用自定義的 NPY 格式數據集進行訓練和預測。我們涵蓋了從環境配置、數據預處理、模型構建到訓練和評估的完整流程。

10.1 主要成果

  1. 完整的數據處理流程:實現了針對 NPY 格式數據的加載、預處理和增強功能
  2. PredRNNv2 模型復現:成功實現了 PredRNNv2 模型的核心組件和完整架構
  3. 訓練框架:構建了完整的訓練、驗證和測試流程,包括損失函數、優化器和學習率調度
  4. 預測與可視化:實現了預測功能和結果可視化,便于分析模型性能
  5. 問題解決方案:提供了針對常見問題(內存不足、訓練不穩定、過擬合)的解決方案

10.2 未來工作方向

  1. 模型優化:嘗試更先進的視頻預測模型,如 SimVP、PhyDNet 等
  2. 多模態融合:結合其他傳感器數據(如氣象數據、地理信息)提高預測精度
  3. 實時預測:優化模型推理速度,實現實時預測功能
  4. 不確定性量化:增加對預測結果不確定性的估計
  5. 部署優化:將模型部署到生產環境,支持大規模數據處理

通過本文的指導和代碼實現,讀者應該能夠成功復現 PredRNNv2 模型,并在自己的數據集上進行訓練和預測。希望這項工作能夠為視頻預測任務的研究和應用提供有價值的參考。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/99239.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/99239.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/99239.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Linux內核IPv4隧道模式封裝機制剖析

概述 在Linux網絡棧中,XFRM(Transform)子系統負責實現IPsec等安全協議的功能。其中,xfrm4_mode_tunnel.c是實現IPv4隧道模式封裝的核心模塊,為IPv4數據包提供隧道模式的封裝和解封裝能力。本文將深入分析這一模塊的實現機制。 模塊架構與功能 該模塊通過注冊到XFRM框架…

OPC Client第10講:實現主界面;獲取初始界面傳來的所有配置信息config【C++讀寫Excel:xlnx;ODBC;緩沖區】

接前面代碼內容&#xff1a; OPC Client第6講&#xff08;wxwidgets&#xff09;&#xff1a;Logger.h日志記錄文件&#xff08;單例模式&#xff09;&#xff1b;登錄后的主界面_wx.logger-CSDN博客 OPC Client第8講&#xff1a;OPC UA&#xff1b;KEPServerEX創建OPC服務器…

快速入門HarmonyOS應用開發(一)

目錄 前言 一、準備工作 二、實戰開發 2.1、Navigation簡介 2.2、頁面路由開發 2.2.1、創建常量 2.2.2、創建字符串資源 2.2.3、創建float資源 2.2.4、創建color資源 2.2.5、創建數據實體 2.2.6、創建頁面路由表 2.2.7、創建Navigation根容器 2.2.8、創建NavDesti…

AI 進課堂 - 語文教學流程重塑

AI 進課堂 - 語文教學流程重塑執教語文十余年&#xff0c;備課案頭的參考書堆得比學生作業本還高&#xff0c;批改作文時紅筆芯換得比粉筆還勤。 直到去年把 JBoltAI 請進課堂&#xff0c;那些重復機械的工作突然有了新解法&#xff0c;連課堂上孩子們的眼神都亮了許多 —— 這…

用戶是否可以同時使用快照和備份來保護云服務器數據安全?

在云計算環境中&#xff0c;云服務器已成為企業和個人數據存儲、應用部署和業務運營的重要平臺。隨著業務數據量的不斷增長&#xff0c;數據安全和業務連續性成為用戶關注的核心問題。云服務器提供的快照和備份功能為用戶提供了有效的數據保護手段&#xff0c;但很多人會疑問&a…

RDS-MYSQL,這個RDS是什么?和mysql有什么區別?

好的&#xff0c;這是一個非常常見且重要的問題。我用最通俗易懂的方式給你解釋清楚。 一、大白話解釋 你可以把 MySQL 和 RDS MySQL 的關系&#xff0c;想象成&#xff1a;MySQL&#xff1a;就像是你自己買零件組裝的一臺電腦。 你需要自己挑選CPU、內存、硬盤、主板&#xff…

arcgis中實現四色/五色法制圖

四色定理是圖論中的一個著名定理&#xff0c;它指出在任何地圖上&#xff0c;只需四種顏色就足以使任何相鄰的區域&#xff08;擁有共同邊界線段&#xff0c;而非單個點&#xff09;顏色不同。五色定理則是另一個更早被證明的、較弱但更易證的定理。在地圖制圖中&#xff0c;這…

Spring如何巧妙解決循環依賴問題

什么是循環依賴&#xff1f;循環依賴是指兩個或多個Bean之間相互依賴&#xff0c;形成閉環的情況。例如&#xff1a;AService依賴BService&#xff0c;而BService又依賴AService。這種場景下&#xff0c;傳統的創建順序無法滿足依賴注入的要求。Spring的三級緩存機制Spring通過…

CUDA 中Thrust exclusive_scan使用詳解

1. 基本概念Thrust 是 NVIDIA CUDA 提供的類似 C STL 的并行算法庫。Scan (前綴和)&#xff1a;給定數組 [a0, a1, a2, ...]&#xff0c;產生前綴和序列。Exclusive Scan (排他前綴和)&#xff1a; 輸出位置 i 存放的是輸入數組中 0 到 i-1 的累積結果。換句話說&#xff0c;結…

Linux -- 信號【上】

目錄 一、信號的引入 1、信號概念 2、signal函數 普通標準信號詳解表 3、前臺/后臺進程 3.1 概念 3.2 查看后臺進程 3.3 后臺進程拉回前臺 3.4 終止后臺進程 3.5 暫停前臺進程 3.6 回復運行后臺進程 4、發信號的本質 二、信號的產生 1、終端按鍵 2、系統調用 2…

Altium Designer(AD)自定義PCB外觀顏色

目錄 1視圖設置界面介紹 2PCB阻焊層顏色設置 2.1進入視圖設置界面 2.2阻焊層顏色設置 2.3頂層和底層阻焊層顏色設置 2.4頂層阻焊層試圖效果 2.5底層阻焊層試圖效果 3設置PCB絲印顏色設置 3.1找到絲印設置選項 3.2設置頂層和底層絲印顏色 3.3頂層絲印 3.4底層絲印 4…

5天改造,節能50%!冷能改造如何實現“不停產節能”?

你有沒有發現一個現象&#xff1f;很多工廠老板一提到節能改造&#xff0c;第一反應就是搖頭。不是不想省電費&#xff0c;而是怕停產。停產一天損失幾十萬&#xff0c;改造周期動輒幾個月&#xff0c;這賬怎么算都不劃算。但如果我告訴你&#xff0c;有一種改造方式&#xff0…

【Flink】窗口

目錄窗口窗口的概念窗口的分類滾動窗口&#xff08;Tumbling Windows&#xff09;滑動窗口&#xff08;Sliding Windows&#xff09;會話窗口&#xff08;Session Windows&#xff09;全局窗口&#xff08;Global Windows&#xff09;窗口API概覽窗口函數增量聚合函數ReduceFun…

攻擊路徑(4):API安全風險導致敏感數據泄漏

本文是《攻防演練 | JS泄露到主機失陷[1]》的學習筆記&#xff0c;歡迎大家閱讀原文。攻擊路徑通過未授權訪問攻擊獲取敏感數據通過SQL注入攻擊獲取服務器權限通過憑據訪問攻擊獲取數據庫權限和敏感數據和應用權限安全風險與加固措施通過未授權訪問攻擊獲取敏感數據、通過SQL注…

機器學習面試題:請介紹一下你理解的集成學習算法

集成學習&#xff08;Ensemble Learning&#xff09;的核心思想是“集思廣益”&#xff0c;它通過構建并結合多個基學習器&#xff08;Base Learner&#xff09;來完成學習任務&#xff0c;從而獲得比單一學習器更顯著優越的泛化性能。俗話說&#xff0c;“三個臭皮匠&#xff…

Invalid bound statement (not found): com.XXX.XXx.service.xxx無法執行service

org.apache.ibatis.binding.BindingException: Invalid bound statement (not found): com.xxx.xxx.service.CitytownService.selectCitytown 出現無法加載sevice層的時候&#xff0c;如下圖所示1&#xff0c;處理方法是&#xff0c;先看下注解MapperScan內的包地址&#xff0c…

泛型(Generics)what why when【前端TS】

我總是提醒自己一定要嚴謹嚴謹嚴謹 目錄TypeScript 泛型 (Generics)1. 什么是泛型&#xff1f;2. 為什么需要泛型&#xff1f;3. 泛型常見用法3.1 函數泛型3.2 接口泛型3.3 類泛型3.4 泛型約束3.5 泛型默認值3.6 多個泛型參數4. 泛型應用場景TypeScript 泛型 (Generics) 1. 什…

分布式協議與算法實戰-協議和算法篇

05丨Paxos算法&#xff08;一&#xff09;&#xff1a;如何在多個節點間確定某變量的值? 提到分布式算法&#xff0c;就不得不提 Paxos 算法&#xff0c;在過去幾十年里&#xff0c;它基本上是分布式共識的代名詞&#xff0c;因為當前最常用的一批共識算法都是基于它改進的。比…

9.13 9.15 JavaWeb(事務管理、AOP P172-P182)

事務管理事務概念事務是一組操作的集合&#xff0c;是一個不可分割的工作單位&#xff0c;這些操作要么同時成功&#xff0c;要么同時失敗操作開啟事務&#xff08;一組操作開始前&#xff0c;開啟事務&#xff09;&#xff1a;start transaction / begin提交事務&#xff08;這…

檢索融合方法- Distribution-Based Score Fusion (DBSF)

在信息檢索&#xff08;IR&#xff09;、推薦系統和多模態檢索中&#xff0c;我們常常需要融合來自多個檢索器或模型的結果。不同檢索器可能對同一文檔打出的分數差異很大&#xff0c;如果直接簡單加權&#xff0c;很容易出現某個檢索器“主導融合結果”的情況。 Distribution…