使用PyTorch實現ResNet:從殘差塊到完整模型訓練

ResNet(殘差網絡)是深度學習中的經典模型,通過引入殘差連接解決了深層網絡訓練中的梯度消失問題。本文將從殘差塊的定義開始,逐步實現一個ResNet模型,并在Fashion MNIST數據集上進行訓練和測試。


1. 殘差塊(Residual Block)實現

殘差塊通過跳躍連接(Shortcut Connection)將輸入直接傳遞到輸出,緩解了深層網絡的訓練難題。以下是殘差塊的PyTorch實現:

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = Noneself.relu = nn.ReLU(inplace=True)def forward(self, x):y = F.relu(self.bn1(self.conv1(x)))y = self.bn2(self.conv2(y))if self.conv3:x = self.conv3(x)y += xreturn F.relu(y)

代碼解析

  • use_1x1conv:當輸入和輸出通道數不一致時,使用1x1卷積調整通道數。

  • strides:控制特征圖下采樣的步長。

  • 殘差相加后再次使用ReLU激活,增強非線性表達能力。


2. 構建ResNet模型

ResNet由多個殘差塊堆疊而成,以下代碼構建了一個簡化版ResNet-18:

# 初始卷積層
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:  # 第一個塊需下采樣blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blk# 堆疊殘差塊
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))# 完整網絡結構
net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(),nn.Linear(512, 10)
)

模型結構說明

  • AdaptiveAvgPool2d:自適應平均池化,將特征圖尺寸統一為1x1。

  • Flatten:展平特征用于全連接層分類。


3. 數據加載與預處理

使用Fashion MNIST數據集,批量大小為256:

train_data, test_data = d2l.load_data_fashion_mnist(batch_size=256)

4. 模型訓練與測試

設置訓練參數:10個epoch,學習率0.05,并使用GPU加速:

d2l.train_ch6(net, train_data, test_data, num_epochs=10, lr=0.05, device=d2l.try_gpu())

訓練結果

loss 0.124, train acc 0.952, test acc 0.860
4921.4 examples/sec on cuda:0

5. 結果可視化

訓練過程中損失和準確率變化如下圖所示:

分析

  • 訓練準確率(紫色虛線)迅速上升并穩定在95%以上。

  • 測試準確率(綠色點線)達到86%,表明模型具有良好的泛化能力。

  • 損失值(藍色實線)持續下降,未出現過擬合。


6. 完整代碼

整合所有代碼片段(需安裝d2l庫):

# 殘差塊定義、模型構建、訓練代碼見上文

7. 總結

本文實現了ResNet的核心組件——殘差塊,并構建了一個簡化版ResNet模型。通過實驗驗證,模型在Fashion MNIST數據集上表現良好。讀者可嘗試調整網絡深度或超參數以進一步提升性能。

改進方向

  • 增加殘差塊數量構建更深的ResNet(如ResNet-34/50)。

  • 使用數據增強策略提升泛化能力。

  • 嘗試不同的優化器和學習率調度策略。


注意事項

  • 確保已安裝PyTorch和d2l庫。

  • GPU環境可顯著加速訓練,若使用CPU需調整批量大小。


希望本文能幫助您理解ResNet的實現細節!如有疑問,歡迎在評論區留言討論。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/900093.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/900093.shtml
英文地址,請注明出處:http://en.pswp.cn/news/900093.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Transformer架構詳解:從Encoder到Decoder的完整旅程

引言:從Self-Attention到完整架構 在上一篇文章中,我們深入剖析了Self-Attention機制的核心原理。然而,Transformer的魅力遠不止于此——其Encoder-Decoder架構通過巧妙的模塊化設計,實現了從機器翻譯到文本生成的廣泛能力。本文…

Docker學習--容器生命周期管理相關命令--docker create 命令

docker create 命令作用: 會根據指定的鏡像和參數創建一個容器實例,但容器只會在創建時進行初始化,并不會執行任何進程。 語法: docker create[參數] IMAGE(要執行的鏡像) [COMMAND](在容器內部…

【C++11】異步編程

異步編程的概念 什么是異步? 異步編程是一種編程范式,允許程序在等待某些操作時繼續執行其它任務,而不是阻塞或等待這些操作完成。 異步編程vs同步編程? 在傳統的同步編程中,代碼按順序同步執行,每個操作需…

FastAPI與ASGI深度整合實戰指南

一、ASGI技術體系解析 1. ASGI協議棧全景圖 #mermaid-svg-a5XPEshAsf64SBkw {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-a5XPEshAsf64SBkw .error-icon{fill:#552222;}#mermaid-svg-a5XPEshAsf64SBkw .error-te…

數組與特殊壓縮矩陣

一、數組的基本特性 定義: int arr[3][3]; // 3x3二維數組 存儲方式: 行優先存儲(C語言默認):元素按行連續存儲。 列優先存儲:需手動實現(如科學計算中的Fortran風格)。 訪問元素…

Word 插入無頁眉頁碼的空白頁(即插入奇數頁)

遇到問題 例如,我的第5章的頁碼是58,偶數頁,我想改成奇數頁59,需要在57頁和58頁之間插入奇數頁。 解決辦法 單擊上一頁(57頁),打開“視圖-大綱”,找到要插入奇數頁的位置&#x…

OpenCV 從入門到精通(day_05)

1. 模板匹配 1.1 什么是模板匹配 模板匹配就是用模板圖(通常是一個小圖)在目標圖像(通常是一個比模板圖大的圖片)中不斷的滑動比較,通過某種比較方法來判斷是否匹配成功。 1.2 匹配方法 rescv2.matchTemplate(image, …

【目標檢測】【深度學習】【Pytorch版本】YOLOV3模型算法詳解

【目標檢測】【深度學習】【Pytorch版本】YOLOV3模型算法詳解 文章目錄 【目標檢測】【深度學習】【Pytorch版本】YOLOV3模型算法詳解前言YOLOV3的模型結構YOLOV3模型的基本執行流程YOLOV3模型的網絡參數 YOLOV3的核心思想前向傳播階段反向傳播階段 總結 前言 YOLOV3是由華盛頓…

LN2220 2A 高效率升壓 DC/DC 電壓調整器

1、產品概述 LN2220 是一款微小型、高效率、升壓型 DC/DC 調整器。 電路由電流模 PWM 控制環路,誤差放大器,斜波補償電路, 比較器和功率開關等模塊組成。該芯片可在較寬負載范圍內 高效穩定的工作,內置一個 4A 的功率開關和…

【大模型基礎_毛玉仁】6.3 知識檢索

目錄 6.3 知識檢索6.3.1 知識庫構建1)數據采集及預處理2)知識庫增強 6.3.2 查詢增強1)查詢語義增強2)查詢內容增強 6.3.3 檢索器1)判別式檢索器2)生成式檢索器 6.3.4 檢索效率增強1)相似度索引算…

靜態方法和實例方法

在 Java 中,?靜態方法(static method)?和?實例方法(instance method)?是兩種不同類型的方法,它們在調用方式、內存分配和訪問權限上有顯著區別。以下是詳細對比: ?1. 靜態方法(…

Lua環境搭建+Lua基本語法

前期準備: 搜索并下載安裝LuaForWindows,例: 安裝完成后開啟cmd窗口,輸入lua 出現版本號證明成功下載安裝 使用Sublime Text編輯器編寫Lua 使用瀏覽器或CSDN搜索Sublime Text下載并安裝,安裝成功后打開編輯器,編輯…

FFmpeg錄制屏幕和音頻

一、FFmpeg命令行實現錄制屏幕和音頻 1、Windows 示例 #include <cstdlib> #include <string> #include <iostream>int main() {// FFmpeg 命令行&#xff08;錄制屏幕 麥克風音頻&#xff09;std::string command "ffmpeg -f gdigrab -framerate 3…

【數據集】多視圖文本數據集

多視圖文本數據集指的是包含多個不同類型或來源的信息的文本數據集。不同視圖可以來源于不同的數據模式&#xff08;如原始文本、元數據、網絡結構等&#xff09;&#xff0c;或者不同的文本表示方法&#xff08;如 TF-IDF、詞嵌入、主題分布等&#xff09;。這些數據集常用于多…

C++ 繼承方式使用場景(極簡版)

1. 公有繼承&#xff08;public&#xff09; 什么時候用&#xff1f; “是一個”&#xff08;is-a&#xff09;關系&#xff1a;派生類 是 基類的一種。 例&#xff1a;class Dog : public Animal&#xff08;狗是動物&#xff09; 最常見&#xff0c;90%的繼承都用它。 2. 保…

Ubuntu 系統 Docker 中搭建 CUDA cuDNN 開發環境

CUDA 是 NVIDIA 推出的并行計算平臺和編程模型&#xff0c;利用 GPU 多核心架構加速計算任務&#xff0c;廣泛應用于深度學習、科學計算等領域。cuDNN 是基于 CUDA 的深度神經網絡加速庫&#xff0c;為深度學習框架提供高效卷積、池化等操作的優化實現&#xff0c;提升模型訓練…

高密度任務下的挑戰與破局:數字樣機助力火箭發射提效提質

2025年4月1日12時&#xff0c;在酒泉衛星發射中心&#xff0c;長征二號丁運載火箭順利升空&#xff0c;成功將一顆衛星互聯網技術試驗衛星送入預定軌道&#xff0c;發射任務圓滿完成。這是長征二號丁火箭的第97次發射&#xff0c;也是長征系列火箭的第567次發射。 執行本次任務…

關于SQL子查詢的使用策略

在 SQL 優化中&#xff0c;一般遵循**“非必要不使用子查詢”**的原則&#xff0c;因為子查詢可能會帶來額外的計算開銷&#xff0c;影響查詢效率。但是&#xff0c;并不是所有子查詢都需要避免&#xff0c;有時子查詢是最優解&#xff0c;具體要根據實際場景選擇合適的優化方式…

JavaEE初階復習(JVM篇)

JVM Java虛擬機 jdk java開發工具包 jre java運行時環境 jvm java虛擬機(解釋執行 java 字節碼) java作為一個半解釋,半編譯的語言,可以做到跨平臺. java 通過javac把.java文件>.class文件(字節碼文件) 字節碼文件, 包含的就是java字節碼, jvm把字節碼進行翻譯轉化為…

2.pycharm保姆級安裝教程

一、pycharm安裝 1.官網上下載好好軟&#xff0c;雙擊打開 2.下一步 3.修改路徑地址 (默認也可以) 4.打勾 5.安裝 不用重啟電腦 二、添加解釋器 1.雙擊軟件&#xff0c;打開 2.projects – new project 3.指定項目名字&#xff0c;項目保存地址&#xff0c;解釋器 4.右擊 – …