【ConvLSTM第二期】模擬視頻幀的時序建模(Python代碼實現)

目錄

  • 1 準備工作:python庫包安裝
    • 1.1 安裝必要庫
  • 案例說明:模擬視頻幀的時序建模
    • ConvLSTM概述
    • 損失函數說明
    • (python全代碼)
  • 參考

ConvLSTM的原理說明可參見另一博客-【ConvLSTM第一期】ConvLSTM原理。

1 準備工作:python庫包安裝

1.1 安裝必要庫

pip install torch torchvision matplotlib numpy

案例說明:模擬視頻幀的時序建模

🎯 目標:給定一個人工生成的動態圖像序列(例如移動的方塊),使用 ConvLSTM 對其進行建模,輸出預測結果,并查看輸出的維度和特征變化。

ConvLSTM概述

ConvLSTM 的基本結構,包括:

  • ConvLSTMCell:實現了一個時間步的 ConvLSTM 單元,類似于一個“時刻”的神經元。
  • ConvLSTM:實現了多層ConvLSTM結構,能夠處理一整個時間序列的視頻幀數據。

損失函數說明

MSE(均方誤差) 衡量預測值和真實值之間的平均平方差。
在這里插入圖片描述

關于訓練終止條件:
可以根據 MSE是否達到某個閾值(如 < 0.001)提前終止訓練,這是所謂的 “Early Stopping(提前停止)策略”。

(python全代碼)

MSE損失函數曲線如下:可知MSE一直在下降,雖然存在振蕩
在這里插入圖片描述

前9幀圖像及預測的第十幀圖像得到的動圖如下:
在這里插入圖片描述

python完整代碼如下:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image# 設置字體
plt.rcParams['font.family'] = 'Times New Roman'# 創建保存圖像目錄
os.makedirs("./Figures", exist_ok=True)# 設置設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# ====================================
# 一、ConvLSTM 模型結構
# ====================================class ConvLSTMCell(nn.Module):def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):super(ConvLSTMCell, self).__init__()padding = kernel_size // 2self.input_channels = input_channelsself.hidden_channels = hidden_channelsself.conv = nn.Conv2d(input_channels + hidden_channels, 4 * hidden_channels, kernel_size, padding=padding, bias=bias)def forward(self, x, h_prev, c_prev):combined = torch.cat([x, h_prev], dim=1)conv_output = self.conv(combined)cc_i, cc_f, cc_o, cc_g = torch.chunk(conv_output, 4, dim=1)i = torch.sigmoid(cc_i)f = torch.sigmoid(cc_f)o = torch.sigmoid(cc_o)g = torch.tanh(cc_g)c = f * c_prev + i * gh = o * torch.tanh(c)return h, cclass ConvLSTM(nn.Module):def __init__(self, input_channels, hidden_channels, kernel_size, num_layers):super(ConvLSTM, self).__init__()self.num_layers = num_layerslayers = []for i in range(num_layers):in_channels = input_channels if i == 0 else hidden_channelslayers.append(ConvLSTMCell(in_channels, hidden_channels, kernel_size))self.layers = nn.ModuleList(layers)def forward(self, input_seq):b, t, c, h, w = input_seq.size()h_t = [torch.zeros(b, layer.hidden_channels, h, w).to(input_seq.device) for layer in self.layers]c_t = [torch.zeros(b, layer.hidden_channels, h, w).to(input_seq.device) for layer in self.layers]for time in range(t):x = input_seq[:, time]for i, layer in enumerate(self.layers):h_t[i], c_t[i] = layer(x, h_t[i], c_t[i])x = h_t[i]return h_t[-1]  # 返回最后一層最后一幀的隱藏狀態# ====================================
# 二、生成移動方塊序列數據
# ====================================def generate_moving_square_sequence(batch_size, time_steps, height, width):data = torch.zeros((batch_size, time_steps, 1, height, width))for b in range(batch_size):dx = np.random.randint(1, 3)dy = np.random.randint(1, 3)x = np.random.randint(0, width - 6)y = np.random.randint(0, height - 6)for t in range(time_steps):data[b, t, 0, y:y+5, x:x+5] = 1.0x = (x + dx) % (width - 5)y = (y + dy) % (height - 5)return data# ====================================
# 三、模型、損失、優化器
# ====================================class ConvLSTM_Predictor(nn.Module):def __init__(self):super().__init__()self.convlstm = ConvLSTM(input_channels=1, hidden_channels=16, kernel_size=3, num_layers=1)self.decoder = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)def forward(self, input_seq):hidden = self.convlstm(input_seq)pred = self.decoder(hidden)return predmodel = ConvLSTM_Predictor().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)# ====================================
# 四、訓練過程
# ====================================mse_list = []
max_epochs = 100
mse_threshold = 0.001
height, width = 64, 64for epoch in range(max_epochs):model.train()seq = generate_moving_square_sequence(8, 10, height, width).to(device)input_seq = seq[:, :9]target_frame = seq[:, 9, 0].unsqueeze(1)optimizer.zero_grad()output = model(input_seq)loss = criterion(output, target_frame)loss.backward()optimizer.step()mse = loss.item()mse_list.append(mse)print(f"Epoch {epoch+1}/{max_epochs}, MSE: {mse:.6f}")# 提前停止條件if mse < mse_threshold:print(f"? 提前停止:MSE 已達到閾值 {mse_threshold}")break# ====================================
# 五、測試與可視化結果
# ====================================model.eval()
with torch.no_grad():test_seq = generate_moving_square_sequence(1, 10, height, width).to(device)input_seq = test_seq[:, :9]true_frame = test_seq[:, 9, 0]pred_frame = model(input_seq)[0, 0].cpu().numpy()# 保存輸入幀
for t in range(9):frame = input_seq[0, t, 0].cpu().numpy()plt.imshow(frame, cmap='gray')plt.title(f"Input Frame t={t}")plt.colorbar()plt.savefig(f"./Figures/input_frame_{t}.png")plt.close()# 保存 Ground Truth
plt.imshow(true_frame[0].cpu().numpy(), cmap='gray')
plt.title("Ground Truth Frame t=9")
plt.colorbar()
plt.savefig("./Figures/ground_truth_t9.png")
plt.close()# 保存預測幀
plt.imshow(pred_frame, cmap='gray')
plt.title("Predicted Frame t=9")
plt.colorbar()
plt.savefig("./Figures/predicted_t9.png")
plt.close()# 保存 MSE 曲線圖
plt.plot(mse_list)
plt.title("Training MSE Loss")
plt.xlabel("Epoch")
plt.ylabel("MSE")
plt.grid(True)
plt.savefig("./Figures/mse_curve.png")
plt.close()# ---------------- 生成動圖 ----------------frames = []# 添加前9幀輸入
for t in range(9):img = Image.open(f"./Figures/input_frame_{t}.png")frames.append(img.copy())# 添加預測幀
img = Image.open("./Figures/predicted_t9.png")
frames.append(img.copy())# 保存動圖
frames[0].save("./Figures/sequence.gif", save_all=True, append_images=frames[1:], duration=500, loop=0)
print("? 所有圖像和動圖已保存至 ./Figures 文件夾")

參考

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/85624.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/85624.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/85624.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

MySQL DDL操作全解析:從入門到精通,包含索引視圖分區表等全操作解析

目錄 一、DDL 基礎概述 1.1 DDL 定義與作用 1.2 DDL 語句分類 1.3 數據類型與存儲引擎 1.3.1 數據類型 1.3.2 存儲引擎差異 二、基礎 DDL 語句詳解 2.1 創建數據庫與表 2.1.1 創建數據庫 2.1.2 創建表 2.2 修改表結構 2.2.1 添加列 2.2.2 修改列屬性 2.2.3 刪除列…

設計模式——抽象工廠設計模式(創建型)

摘要 抽象工廠設計模式是一種創建型設計模式&#xff0c;旨在提供一個接口&#xff0c;用于創建一系列相關或依賴的對象&#xff0c;無需指定具體類。它通過抽象工廠、具體工廠、抽象產品和具體產品等組件構建&#xff0c;相比工廠方法模式&#xff0c;能創建一個產品族。該模…

Express教程【006】:使用Express寫接口

文章目錄 8、使用Express寫接口8.1 創建API路由模塊8.2 編寫GET接口8.3 編寫POST接口 8、使用Express寫接口 8.1 創建API路由模塊 1??新建routes/apiRouter.js路由模塊&#xff1a; /*** 路由模塊*/ // 1-導入express const express require(express); // 2-創建路由對象…

【iOS(swift)筆記-14】App版本不升級時本地數據庫sqlite更新邏輯二

App版本不升級時&#xff0c;又想即時更新本地數據庫怎么辦&#xff1f; 辦法二&#xff1a;從服務器下載最新的sqlite數據替換掉本地的數據&#xff08;注意是數據不是文件&#xff09; 稍加調整&#xff0c; // &#xff01;&#xff01;&#xff01;注意&#xff01;&…

Mac電腦_鑰匙串操作選項變灰的情況下如何刪除?

Mac電腦_鑰匙串操作選項變灰的情況下如何刪除&#xff1f; 這時候 可以使用相關的終端命令進行操作。 下面附加文章《Mac電腦_鑰匙串操作的終端命令》。 《Mac電腦_鑰匙串操作的終端命令》 &#xff08;來源&#xff1a;百度~百度AI 發布時間&#xff1a;2025-06&#xff09;…

對接系統外部服務組件技術方案

概述 當前系統需與多個外部系統對接,然而外部系統穩定性存在不確定性。對接過程中若出現異常,需依靠雙方的日志信息來定位問題,但若日志信息不夠完整,會極大降低問題定位效率。此外,問題發生后,很大程度上依賴第三方的重試機制,若第三方缺乏完善的重試機制,就需要手動…

WAF繞過,網絡層面后門分析,Windows/linux/數據庫提權實驗

一、WAF繞過文件上傳漏洞 win7&#xff1a;10.0.0.168 思路&#xff1a;要想要繞過WAF&#xff0c;第一步是要根據上傳的內容找出來被攔截的原因。對于文件上傳有三個可以考慮的點&#xff1a;文件后綴名&#xff0c;文件內容&#xff0c;文件類型。 第二步是根據找出來的攔截原…

一文學會c++中的內存管理知識點

文章目錄 c/c內存管理c語言動態內存管理c動態內存管理new/delete自定義類型妙用operator new和operator delete malloc/new&#xff0c;free/delete區別 c/c內存管理 int globalVar 1;static int staticGlobalVar 1;void Test(){static int staticVar 1;int localVar 1;in…

深入解析Linux死鎖:原理、原因及解決方案

Linux死鎖是系統資源管理的致命陷阱&#xff0c;平均每年導致全球數據中心約??3.7億小時??的服務中斷。本文深度剖析死鎖形成的??四個必要條件??和六種典型死鎖場景&#xff0c;結合Linux內核源碼層級的資源管理機制&#xff0c;揭示文件系統鎖、內存分配、多線程同步等…

SKUA-GOCAD入門教程-第八節 線的創建與編輯2

8.1.3根據線創建曲線 (1)從線生成線 這個命令可以將一組曲線合并為一條曲線。每個輸入曲線都會成為新曲線內的一個部分。 1、選擇 Curve commands > New > Curves 打開對話框。 圖1 根據曲線創建曲線 在“name”框中

『uniapp』把接口的內容下載為txt本地保存 / 讀取本地保存的txt文件內容(詳細圖文注釋)

目錄 預覽效果思路分析downloadTxt 方法readTxt 方法 完整代碼總結 歡迎關注 『uniapp』 專欄&#xff0c;持續更新中 歡迎關注 『uniapp』 專欄&#xff0c;持續更新中 預覽效果 思路分析 downloadTxt 方法 該方法主要完成兩個任務&#xff1a; 下載 txt 文件&#xff1a;通…

攻防世界-unseping

進入環境 在獲得的場景中發現PHP代碼并進行分析 編寫PHP編碼 得到 Tzo0OiJlYXNlIjoyOntzOjEyOiIAZWFzZQBtZXRob2QiO3M6NDoicGluZyI7czoxMDoiAGVhc2UAYXJncyI7YToxOntpOjA7czozOiJwd2QiO319 將其傳入 想執行ls&#xff0c;但是發現被過濾掉了 使用環境變量進行繞過 $a new…

IP查詢與網絡風險的關系

網絡風險場景與IP查詢的關聯 網絡攻擊、惡意行為、數據泄露等風險事件頻發&#xff0c;而IP地址作為網絡設備的唯一標識&#xff0c;承載著關鍵線索。例如&#xff0c;在DDoS惡意行為中&#xff0c;攻擊者利用大量IP地址發起流量洪泛&#xff1b;惡意行為通過變換IP地址繞過封…

pikachu通關教程-XSS

XSS XSS漏洞原理 XSS被稱為跨站腳本攻擊&#xff08;Cross Site Scripting&#xff09;&#xff0c;由于和層疊樣式表&#xff08;Cascading Style Sheets&#xff0c;CSS&#xff09;重名&#xff0c;改為XSS。主要基于JavaScript語言進行惡意攻擊&#xff0c;因為js非常靈活…

【時時三省】(C語言基礎)數組作為函數參數

山不在高&#xff0c;有仙則名。水不在深&#xff0c;有龍則靈。 ----CSDN 時時三省 調用有參函數時&#xff0c;需要提供實參。例如sin ( x )&#xff0c;sqrt ( 2&#xff0c;0 )&#xff0c;max ( a&#xff0c;b )等。實參可以是常量、變量或表達式。數組元素的作用與變量…

硬件工程師筆記——555定時器應用Multisim電路仿真實驗匯總

目錄 一 555定時器基礎知識 二、引腳功能 三、工作模式 1. 單穩態模式&#xff1a; 2. 雙穩態模式&#xff08;需要外部電路輔助&#xff09;&#xff1a; 3. 無穩態模式&#xff08;多諧振蕩器&#xff09;&#xff1a; 4. 可控脈沖寬度調制&#xff08;PWM&#xff09…

C++11特性:enum class(強枚舉類型)詳解

C11引入的 enum class&#xff08;強枚舉類型&#xff09;解決了傳統枚舉的多個問題&#xff1a; 防止枚舉值泄漏到外部作用域&#xff1b;禁止不同枚舉間的隱式轉換&#xff1b;允許指定底層數據類型優化內存&#xff1b;避免命名空間污染。 其基本語法為 enum class Name{.…

【QT】QString 與QString區別

在C中&#xff0c;QString 和 QString& 有本質區別&#xff0c;尤其是在參數傳遞和內存管理方面&#xff1a; 1. QString&#xff08;按值傳遞&#xff09; 創建副本&#xff1a;傳遞時會創建完整的字符串副本內存開銷&#xff1a;可能涉及深拷貝&#xff08;特別是大字符…

提升四級閱讀速度方法

以下是針對四級英語閱讀速度提升的系統性解決方案&#xff0c;結合最新考試規律和高效訓練方法&#xff0c;分五個核心模塊整理&#xff1a; &#x1f680; ??一、基礎提速訓練&#xff08;消除生理障礙&#xff09;?? ??擴大視幅范圍?? 從逐詞閱讀升級為 ??意群閱讀…

6.4 note

構造矩陣 class Solution { private: vector<int> empty {}; // 返回每個數字(-1)所在的序號&#xff0c;可以是行或列, 如果為空則無效 vector<int> topoSort(int k, vector<vector<int>>& conditions) { // 構建一個圖…