Contents
- 1 Preliminaries: installing the Python packages
- 1.1 Installing the required packages
- Case study: temporal modeling of simulated video frames
- ConvLSTM overview
- Loss function
- (Full Python code)
- References
An explanation of the principles behind ConvLSTM can be found in the companion post 【ConvLSTM第一期】ConvLSTM原理 (ConvLSTM series, part 1: ConvLSTM principles).
1 Preliminaries: installing the Python packages
1.1 Installing the required packages
pip install torch torchvision matplotlib numpy
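To verify that the installation succeeded and check whether a CUDA GPU is available, a quick sanity check such as the following can be run (a minimal sketch; the printed version string depends on your environment):

import torch
print(torch.__version__)          # installed PyTorch version
print(torch.cuda.is_available())  # True if a CUDA GPU can be used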
Case study: temporal modeling of simulated video frames
🎯 Goal: given an artificially generated dynamic image sequence (e.g., a moving square), model it with a ConvLSTM, output a predicted result, and inspect the dimensions and feature changes of the output.
ConvLSTM overview
The basic structure of the ConvLSTM consists of two parts (their gate equations are given after this list):
- ConvLSTMCell: implements one time step of the ConvLSTM, i.e. the cell at a single "moment" in time.
- ConvLSTM: stacks multiple ConvLSTM layers and processes an entire time series of video frames.
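For reference, the gate computations that ConvLSTMCell implements in the full code below are the standard ConvLSTM equations (this variant omits the peephole terms of the original formulation), where $*$ denotes convolution, $\odot$ element-wise multiplication, and $[X_t, H_{t-1}]$ channel-wise concatenation:

$$
\begin{aligned}
i_t &= \sigma\left(W_i * [X_t, H_{t-1}] + b_i\right)\\
f_t &= \sigma\left(W_f * [X_t, H_{t-1}] + b_f\right)\\
o_t &= \sigma\left(W_o * [X_t, H_{t-1}] + b_o\right)\\
g_t &= \tanh\left(W_g * [X_t, H_{t-1}] + b_g\right)\\
C_t &= f_t \odot C_{t-1} + i_t \odot g_t\\
H_t &= o_t \odot \tanh(C_t)
\end{aligned}
$$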
Loss function
MSE (mean squared error) measures the average squared difference between the predicted and true values.
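Formally, for $n$ values (here $n$ counts every pixel in the batch, since PyTorch's nn.MSELoss averages over all elements by default):

$$\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(\hat{y}_i - y_i\right)^2$$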
On the stopping criterion for training:
Training can be terminated early once the MSE falls below a chosen threshold (e.g. < 0.001); this is the so-called "early stopping" strategy, as sketched below.
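A minimal sketch of the control flow, using a geometrically decaying stand-in for the real per-epoch MSE (purely illustrative; the full script below performs the same check inside its actual training loop):

max_epochs = 100
mse_threshold = 0.001
mse = 1.0
for epoch in range(max_epochs):
    mse *= 0.9  # stand-in for the MSE produced by one real training epoch
    if mse < mse_threshold:  # stop as soon as the loss is low enough
        print(f"Stopped at epoch {epoch+1}, MSE = {mse:.6f}")
        break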
(Full Python code)
The MSE loss curve is shown below: the MSE decreases throughout training, although with some oscillation.
The animation built from the 9 input frames and the predicted 10th frame is shown below:
The complete Python code is as follows:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Set the plot font
plt.rcParams['font.family'] = 'Times New Roman'

# Create the output directory for figures
os.makedirs("./Figures", exist_ok=True)

# Select the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====================================
# 1. ConvLSTM model architecture
# ====================================
class ConvLSTMCell(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super(ConvLSTMCell, self).__init__()
        padding = kernel_size // 2
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        # One convolution computes all four gate pre-activations at once
        self.conv = nn.Conv2d(input_channels + hidden_channels, 4 * hidden_channels,
                              kernel_size, padding=padding, bias=bias)

    def forward(self, x, h_prev, c_prev):
        # Concatenate input and previous hidden state along the channel axis
        combined = torch.cat([x, h_prev], dim=1)
        conv_output = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.chunk(conv_output, 4, dim=1)
        i = torch.sigmoid(cc_i)  # input gate
        f = torch.sigmoid(cc_f)  # forget gate
        o = torch.sigmoid(cc_o)  # output gate
        g = torch.tanh(cc_g)     # candidate cell state
        c = f * c_prev + i * g
        h = o * torch.tanh(c)
        return h, c


class ConvLSTM(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size, num_layers):
        super(ConvLSTM, self).__init__()
        self.num_layers = num_layers
        layers = []
        for i in range(num_layers):
            in_channels = input_channels if i == 0 else hidden_channels
            layers.append(ConvLSTMCell(in_channels, hidden_channels, kernel_size))
        self.layers = nn.ModuleList(layers)

    def forward(self, input_seq):
        b, t, c, h, w = input_seq.size()
        # Zero-initialize hidden and cell states for every layer
        h_t = [torch.zeros(b, layer.hidden_channels, h, w).to(input_seq.device) for layer in self.layers]
        c_t = [torch.zeros(b, layer.hidden_channels, h, w).to(input_seq.device) for layer in self.layers]
        for time in range(t):
            x = input_seq[:, time]
            for i, layer in enumerate(self.layers):
                h_t[i], c_t[i] = layer(x, h_t[i], c_t[i])
                x = h_t[i]
        return h_t[-1]  # hidden state of the last layer at the final time step

# ====================================
# 2. Generate moving-square sequence data
# ====================================
def generate_moving_square_sequence(batch_size, time_steps, height, width):
    data = torch.zeros((batch_size, time_steps, 1, height, width))
    for b in range(batch_size):
        # Random velocity and starting position for each sequence
        dx = np.random.randint(1, 3)
        dy = np.random.randint(1, 3)
        x = np.random.randint(0, width - 6)
        y = np.random.randint(0, height - 6)
        for t in range(time_steps):
            # Draw a 5x5 white square, then move it with wrap-around
            data[b, t, 0, y:y+5, x:x+5] = 1.0
            x = (x + dx) % (width - 5)
            y = (y + dy) % (height - 5)
    return data
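# Optional shape check (commented out): each sample is (time_steps, 1, height, width),
# so a call like the one below returns a tensor of shape (4, 10, 1, 32, 32).
# sample = generate_moving_square_sequence(4, 10, 32, 32)
# print(sample.shape)  # torch.Size([4, 10, 1, 32, 32])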
# ====================================
# 3. Model, loss, and optimizer
# ====================================
class ConvLSTM_Predictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.convlstm = ConvLSTM(input_channels=1, hidden_channels=16, kernel_size=3, num_layers=1)
        # 1x1 convolution maps the 16 hidden channels to one output frame channel
        self.decoder = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)

    def forward(self, input_seq):
        hidden = self.convlstm(input_seq)  # (B, 16, H, W)
        pred = self.decoder(hidden)        # (B, 1, H, W)
        return pred


model = ConvLSTM_Predictor().to(device)
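# Optional sanity check (commented out): 9 input frames of 64x64 should map to
# one predicted single-channel 64x64 frame per sequence.
# dummy = torch.randn(8, 9, 1, 64, 64).to(device)
# print(model(dummy).shape)  # torch.Size([8, 1, 64, 64])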
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# ====================================
# 4. Training loop
# ====================================
mse_list = []
max_epochs = 100
mse_threshold = 0.001
height, width = 64, 64

for epoch in range(max_epochs):
    model.train()
    # Generate a fresh batch every epoch: 8 sequences of 10 frames each
    seq = generate_moving_square_sequence(8, 10, height, width).to(device)
    input_seq = seq[:, :9]                    # first 9 frames as input
    target_frame = seq[:, 9, 0].unsqueeze(1)  # 10th frame as target, (8, 1, H, W)
    optimizer.zero_grad()
    output = model(input_seq)
    loss = criterion(output, target_frame)
    loss.backward()
    optimizer.step()
    mse = loss.item()
    mse_list.append(mse)
    print(f"Epoch {epoch+1}/{max_epochs}, MSE: {mse:.6f}")
    # Early-stopping check
    if mse < mse_threshold:
        print(f"Early stopping: MSE reached the threshold {mse_threshold}")
        break

# ====================================
# 5. Testing and visualizing the results
# ====================================
model.eval()
with torch.no_grad():
    test_seq = generate_moving_square_sequence(1, 10, height, width).to(device)
    input_seq = test_seq[:, :9]
    true_frame = test_seq[:, 9, 0]
    pred_frame = model(input_seq)[0, 0].cpu().numpy()

# Save the input frames
for t in range(9):
    frame = input_seq[0, t, 0].cpu().numpy()
    plt.imshow(frame, cmap='gray')
    plt.title(f"Input Frame t={t}")
    plt.colorbar()
    plt.savefig(f"./Figures/input_frame_{t}.png")
    plt.close()

# Save the ground-truth frame
plt.imshow(true_frame[0].cpu().numpy(), cmap='gray')
plt.title("Ground Truth Frame t=9")
plt.colorbar()
plt.savefig("./Figures/ground_truth_t9.png")
plt.close()

# Save the predicted frame
plt.imshow(pred_frame, cmap='gray')
plt.title("Predicted Frame t=9")
plt.colorbar()
plt.savefig("./Figures/predicted_t9.png")
plt.close()

# Save the MSE curve
plt.plot(mse_list)
plt.title("Training MSE Loss")
plt.xlabel("Epoch")
plt.ylabel("MSE")
plt.grid(True)
plt.savefig("./Figures/mse_curve.png")
plt.close()

# ---------------- Build the animated GIF ----------------
frames = []
# Append the 9 input frames
for t in range(9):
    img = Image.open(f"./Figures/input_frame_{t}.png")
    frames.append(img.copy())

# Append the predicted frame
img = Image.open("./Figures/predicted_t9.png")
frames.append(img.copy())

# Save the GIF (500 ms per frame, looping forever)
frames[0].save("./Figures/sequence.gif", save_all=True, append_images=frames[1:], duration=500, loop=0)
print("? 所有圖像和動圖已保存至 ./Figures 文件夾")