PyTorch中nn.Module詳解和綜合代碼示例

在 PyTorch 中,nn.Module 是神經網絡中最核心的基類,用于構建所有模型。理解并熟練使用 nn.Module 是掌握 PyTorch 的關鍵。


一、什么是 nn.Module

nn.Module 是 PyTorch 中所有神經網絡模塊的基類。可以把它看作是“神經網絡的容器”,它封裝了以下幾件事:

  1. 網絡層(如 Linear、Conv2d 等)
  2. 前向傳播邏輯(forward 函數)
  3. 模型參數(自動注冊并可訓練)
  4. 可嵌套(可以包含多個子模塊)
  5. 便捷的模型保存 / 加載等工具函數

二、基礎用法

2.1 自定義模型類

import torch
import torch.nn as nnclass MyNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x

2.2 實例化與調用

model = MyNet()
x = torch.randn(32, 784)     # batch_size = 32
output = model(x)            # 自動調用 forward

三、構造方法詳解

3.1 __init__()

  • 定義子模塊、層等結構。
  • 例如 self.conv1 = nn.Conv2d(...) 會被自動注冊為模型參數。

3.2 forward()

  • 定義前向傳播邏輯。
  • 不能手動調用,應使用 model(x) 形式。

四、常見模塊層

模塊名作用示例
nn.Linear全連接層nn.Linear(128, 64)
nn.Conv2d卷積層nn.Conv2d(3, 16, 3)
nn.ReLU激活函數nn.ReLU()
nn.Sigmoid激活函數nn.Sigmoid()
nn.BatchNorm2d批歸一化nn.BatchNorm2d(16)
nn.DropoutDropout 層nn.Dropout(0.5)
nn.LSTMLSTM 層nn.LSTM(10, 20)
nn.Sequential層的順序容器見下文說明

五、模型嵌套結構(子模塊)

你可以將一個 nn.Module 作為另一個模塊的子模塊嵌套:

class Block(nn.Module):def __init__(self):super().__init__()self.layer = nn.Sequential(nn.Linear(64, 64),nn.ReLU())def forward(self, x):return self.layer(x)class Net(nn.Module):def __init__(self):super().__init__()self.block1 = Block()self.block2 = Block()self.output = nn.Linear(64, 10)def forward(self, x):x = self.block1(x)x = self.block2(x)return self.output(x)

六、內置方法和屬性

方法 / 屬性說明
model.parameters()返回所有可訓練參數(用于優化器)
model.named_parameters()返回帶名字的參數迭代器
model.children()返回子模塊迭代器
model.eval()設置為評估模式(Dropout、BN失效)
model.train()設置為訓練模式
model.to(device)將模型轉移到 GPU/CPU
model.state_dict()獲取模型參數字典(保存)
model.load_state_dict()加載模型參數字典

七、使用 nn.Sequential

nn.Sequential 是一個順序容器,可以用來簡化網絡結構定義:

model = nn.Sequential(nn.Linear(784, 128),nn.ReLU(),nn.Linear(128, 10)
)

等價于手寫的自定義 nn.Module。適合前向傳播是線性“流動”的結構。


八、實戰完整示例:MNIST 分類網絡

class MNISTNet(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Flatten(),nn.Linear(28*28, 256),nn.ReLU(),nn.Linear(256, 10))def forward(self, x):return self.net(x)# 實例化模型
model = MNISTNet()
print(model)# 配置訓練
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)# 示例訓練循環
for epoch in range(10):for images, labels in train_loader:output = model(images)loss = criterion(output, labels)optimizer.zero_grad()loss.backward()optimizer.step()

九、常見陷阱和建議

問題說明
forward() 不起作用應該使用 model(x),而不是手動調用 model.forward(x)
忘記 super().__init__()子模塊將不會被注冊
參數未注冊層/模塊必須賦值為 self.xxx = ...
訓練/測試模式混淆注意 model.eval()model.train()

十、總結

項目說明
__init__()定義模型結構(子模塊、層)
forward()定義前向傳播
自動注冊參數所有 self.xxx = nn.XXX(...) 都會被追蹤
嵌套模塊支持遞歸子模塊調用
便捷方法.parameters().to().eval()

十一、綜合示例

以下是基于 PyTorch nn.Module 封裝的三種經典深度學習架構(ResNet18UNetTransformer)的簡潔而完整的實現,適合初學者快速上手。


1、ResNet18 簡潔實現(適合圖像分類)

import torch
import torch.nn as nn
import torch.nn.functional as Fclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_planes, planes, stride=1, downsample=None):super().__init__()self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1   = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)self.bn2   = nn.BatchNorm2d(planes)self.downsample = downsampledef forward(self, x):identity = xif self.downsample:identity = self.downsample(x)out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += identityreturn F.relu(out)class ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000):super().__init__()self.in_planes = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1   = nn.BatchNorm2d(64)self.pool  = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64,  layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc      = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, planes, blocks, stride=1):downsample = Noneif stride != 1 or self.in_planes != planes * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_planes, planes * block.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(planes * block.expansion))layers = [block(self.in_planes, planes, stride, downsample)]self.in_planes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.in_planes, planes))return nn.Sequential(*layers)def forward(self, x):x = self.pool(F.relu(self.bn1(self.conv1(x))))x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x).flatten(1)return self.fc(x)def ResNet18(num_classes=1000):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

2、UNet(適合圖像分割)

class UNetBlock(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.block = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(out_ch, out_ch, 3, padding=1),nn.ReLU(inplace=True))def forward(self, x):return self.block(x)class UNet(nn.Module):def __init__(self, in_channels=1, out_channels=1):super().__init__()self.enc1 = UNetBlock(in_channels, 64)self.enc2 = UNetBlock(64, 128)self.enc3 = UNetBlock(128, 256)self.enc4 = UNetBlock(256, 512)self.pool = nn.MaxPool2d(2)self.bottleneck = UNetBlock(512, 1024)self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)self.dec4 = UNetBlock(1024, 512)self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)self.dec3 = UNetBlock(512, 256)self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)self.dec2 = UNetBlock(256, 128)self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)self.dec1 = UNetBlock(128, 64)self.final = nn.Conv2d(64, out_channels, kernel_size=1)def forward(self, x):e1 = self.enc1(x)e2 = self.enc2(self.pool(e1))e3 = self.enc3(self.pool(e2))e4 = self.enc4(self.pool(e3))b  = self.bottleneck(self.pool(e4))d4 = self.upconv4(b)d4 = self.dec4(torch.cat([d4, e4], dim=1))d3 = self.upconv3(d4)d3 = self.dec3(torch.cat([d3, e3], dim=1))d2 = self.upconv2(d3)d2 = self.dec2(torch.cat([d2, e2], dim=1))d1 = self.upconv1(d2)d1 = self.dec1(torch.cat([d1, e1], dim=1))return self.final(d1)

3、簡化版 Transformer 編碼器(適合序列建模)

class TransformerBlock(nn.Module):def __init__(self, embed_dim, heads, ff_hidden_dim, dropout=0.1):super().__init__()self.attn = nn.MultiheadAttention(embed_dim, heads, dropout=dropout, batch_first=True)self.ff = nn.Sequential(nn.Linear(embed_dim, ff_hidden_dim),nn.ReLU(),nn.Linear(ff_hidden_dim, embed_dim))self.norm1 = nn.LayerNorm(embed_dim)self.norm2 = nn.LayerNorm(embed_dim)self.dropout = nn.Dropout(dropout)def forward(self, x, mask=None):attn_out, _ = self.attn(x, x, x, attn_mask=mask)x = self.norm1(x + self.dropout(attn_out))ff_out = self.ff(x)x = self.norm2(x + self.dropout(ff_out))return xclass TransformerEncoder(nn.Module):def __init__(self, vocab_size, embed_dim=512, n_heads=8, ff_dim=2048, num_layers=6, max_len=512):super().__init__()self.embedding = nn.Embedding(vocab_size, embed_dim)self.pos_encoding = self._generate_positional_encoding(max_len, embed_dim)self.layers = nn.ModuleList([TransformerBlock(embed_dim, n_heads, ff_dim)for _ in range(num_layers)])self.dropout = nn.Dropout(0.1)def _generate_positional_encoding(self, max_len, d_model):pos = torch.arange(0, max_len).unsqueeze(1)i = torch.arange(0, d_model, 2)angle_rates = 1 / torch.pow(10000, (i / d_model))pos_enc = torch.zeros(max_len, d_model)pos_enc[:, 0::2] = torch.sin(pos * angle_rates)pos_enc[:, 1::2] = torch.cos(pos * angle_rates)return pos_enc.unsqueeze(0)def forward(self, x):B, T = x.shapex = self.embedding(x) + self.pos_encoding[:, :T].to(x.device)x = self.dropout(x)for layer in self.layers:x = layer(x)return x

4、 總結對比

模型類型場景特點
ResNet18圖像分類深殘差網絡結構,適合遷移學習
UNet圖像分割對稱結構,編碼 + 解碼 + skip
TransformerNLP / 序列建模全注意力機制,無卷積無循環

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/90562.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/90562.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/90562.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

深入解析三大Web安全威脅:文件上傳漏洞、SQL注入漏洞與WebShell

文章目錄文件上傳漏洞SQL注入漏洞WebShell三者的核心關聯:攻擊鏈閉環文件上傳漏洞 文件上傳漏洞(File Upload Vulnerability) 當Web應用允許用戶上傳文件但未實施充分的安全驗證時,攻擊者可上傳惡意文件(如WebShell、…

【對比】群體智能優化算法 vs 貝葉斯優化

在機器學習、工程優化和科學計算中,優化算法的選擇直接影響問題求解的效率與效果。群體智能優化算法(Swarm Intelligence, SI)和貝葉斯優化(Bayesian Optimization, BO)是兩種截然不同的優化范式,分別以不同…

LLMs之Agent:ChatGPT Agent發布—統一代理系統將研究與行動無縫對接,開啟智能助理新時代

LLMs之Agent:ChatGPT Agent發布—統一代理系統將研究與行動無縫對接,開啟智能助理新時代 目錄 OpenAI重磅發布ChatGPT Agent—統一代理系統將研究與行動無縫對接,開啟智能助理新時代 第一部分:Operator 和深度研究的自然演進 第…

Linux726 raid0,raid1,raid5;raid 創建、保存、停止、刪除

RAID創建 創建raid0 安裝mdadm yum install mdadm mdadm --create /dev/md0 --raid-devices2 /dev/sdb5 /dev/sdb6 [rootsamba caozx26]# mdadm --create /dev/md0 --raid-devices2 /dev/sdb3 /dev/sdb5 --level0 mdadm: Defaulting to version 1.2 metadata mdadm: array /dev…

深入剖析 MetaGPT 中的提示詞工程:WriteCode 動作的提示詞設計

今天,我想和大家分享關于 AI 提示詞工程的文章。提示詞(Prompt)是大型語言模型(LLM)生成高質量輸出的關鍵,而在像 MetaGPT 這樣的 AI 驅動軟件開發框架中,提示詞的設計直接決定了代碼生成的可靠…

關于 ESXi 中 “ExcelnstalledOnly 已禁用“ 的解決方案

第一步:使用ssh登錄esxi esxcli system settings advanced list -o /User/execInstalledOnly可能會得到以下內容 esxcli system settings advanced list -o /User/execInstalledOnlyPath: /User/ExecInstalledOnlyType: integerInt Value: 0Default Int Value: 1Min…

HTML5 Canvas 繪制圓弧效果

HTML5 Canvas 繪制圓弧效果 以下是一個使用HTML5 Canvas繪制圓弧的完整示例&#xff0c;你可以直接在瀏覽器中運行看到效果&#xff1a; <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"view…

智能Agent場景實戰指南 Day 18:Agent決策樹與規劃能力

【智能Agent場景實戰指南 Day 18】Agent決策樹與規劃能力 開篇 歡迎來到"智能Agent場景實戰指南"系列的第18天&#xff01;今天我們將深入探討智能Agent的核心能力之一&#xff1a;決策樹與規劃能力。在現代業務場景中&#xff0c;Agent需要具備類似人類的決策能力…

AI 編程工具 Trae 重要的升級。。。

大家好&#xff0c;我是櫻木。 今天打開 Trae &#xff0c;已經看到它進行圖標升級&#xff0c;之前的圖標&#xff0c;國際和國內版本長得非常像&#xff0c;現在做了很明顯的區分&#xff0c;這點給 Trae 團隊點個贊。 自從 Claude 使出了壓力以來&#xff0c;Cursor 鎖區&…

排序算法,咕咕咕

1.選擇排序void selectsort(vector<int>& v) { for(int i0;i<v.size()-1;i) {int minii;for(int ji1;j<v.size();j){if(v[i]>v[j]){minij;}}if(mini!i)swap(v[i],v[mini]); } }2.堆排序void adjustdown(vector<int>& v,int root,int size) { int …

數據庫查詢系統——pyqt+python實現Excel內查課

一、引言 數據庫查詢系統處處存在&#xff0c;在教育信息化背景下&#xff0c;數據庫查詢技術更已深度融入教務管理場景。本系統采用輕量化架構&#xff0c;結合Excel課表&#xff0c;通過PythonPyQt5實現跨平臺桌面應用&#xff0c;以實現簡單查課效果。 二、GUI界面設計 使用…

base64魔改算法 | jsvmp日志分析并還原

前言 上一篇我們講了標準 base64 算法還原&#xff0c;為了進一步學習 base64 算法特點&#xff0c;本文將結合 jsvmp 日志&#xff0c;實戰還原出 base64 魔改算法。 為了方便大家學習&#xff0c;我將入參和上篇文章一樣&#xff0c;入參為 Hello, World!。 插樁 在js代碼中&…

vue3筆記(2)自用

目錄 一、作用域插槽 二、pinia的使用 一、Pinia 基本概念與用法 1. 安裝與初始化 2. 創建 Store 3. 在組件中使用 Store 4. 高級用法 5、storeToRefs 二、Pinia 與 Vuex 的主要區別 三、為什么選擇 Pinia&#xff1f; 三、定義全局指令 1.封裝通用 DOM 操作&#…

大模型面試回答,介紹項目

1. 模型準備與轉換&#xff08;PC端/服務器&#xff09;你先在PC上下載或訓練好大語言模型&#xff08;如HuggingFace格式&#xff09;。用RKLLM-Toolkit把模型轉換成瑞芯微NPU能用的專用格式&#xff08;.rkllm&#xff09;&#xff0c;并可選擇量化優化。把轉換好的模型文件拷…

Oracle 19.20未知BUG導致oraagent進程內存泄漏

故障現象查詢操作系統進程的使用排序&#xff0c;這里看到oraagent的物理內存達到16G&#xff0c;遠遠超過正常環境&#xff08;正常環境在19.20大概就是100M多一點&#xff09;[rootorastd tmp]# ./hmem|more PID NAME VIRT(kB) SHARED(kB) R…

嘗試幾道算法題,提升python編程思維

一、跳躍游戲題目描述&#xff1a; 給定一個非負整數數組 nums&#xff0c;你最初位于數組的第一個下標。數組中的每個元素代表你在該位置可以跳躍的最大長度。判斷你是否能夠到達最后一個下標。示例&#xff1a;輸入&#xff1a;nums [2,3,1,1,4] → 輸出&#xff1a;True輸入…

【菜狗處理臟數據】對很多個不同時間序列數據的文件聚類—20250722

目錄 具體做法 可視化方法1&#xff1a;PCA降維 可視化方法2、TSNE降維可視化&#xff08;非線性降維&#xff0c;更適合聚類&#xff09; 可視化方法3、輪廓系數評判好壞 每個文件有很多行列的信息&#xff0c;每列是一個駕駛相關的數據&#xff0c;需要對這些文件進行聚類…

Qwen-MT:翻得快,譯得巧

我們再向大家介紹一位新朋友&#xff1a;機器翻譯模型Qwen-MT。開發者朋友們可通過Qwen API&#xff08;qwen-mt-turbo&#xff09;&#xff0c;來直接體驗它又快又準的翻譯技能。 本次更新基于強大的 Qwen3 模型&#xff0c;進一步使用超大規模多語言和翻譯數據對模型進行訓練…

在 OceanBase 中,使用 TO_CHAR 函數 直接轉換日期格式,簡潔高效的解決方案

SQL語句SELECT TO_CHAR(TO_DATE(your_column, DD-MON-YY), YYYY-MM-DD) AS formatted_date FROM your_table;關鍵說明&#xff1a;核心函數&#xff1a;TO_DATE(30-三月-15, DD-MON-YY) → 將字符串轉為日期類型TO_CHAR(..., YYYY-MM-DD) → 格式化為 2015-03-30處理中文月份&a…

pnpm運行electronic項目報錯,npm運行正常。electronic項目打包為exe報錯

pnpm運行electronic項目報錯 使用 pnpm 運行 electronic 項目報錯&#xff0c;npm 運行正常&#xff0c;報錯內容如下 error during start dev server and electron app: Error: Electron uninstallat getElectronPath (file:///E:/project/xxx-vue/node_modules/.pnpm/elect…