pytorch小記(十三):pytorch中`nn.ModuleList` 詳解

pytorch小記(十三):pytorch中`nn.ModuleList` 詳解

  • PyTorch 中的 `nn.ModuleList` 詳解
    • 1. 什么是 `nn.ModuleList`?
    • 2. 為什么不直接使用普通的 Python 列表?
    • 3. `nn.ModuleList` 的基本用法
      • 示例:構建一個包含兩層全連接網絡的模型
    • 4. 使用 `nn.ModuleList` 計算參數總數(與普通列表對比)
      • 示例代碼
    • 5. `nn.ModuleList` 的其他應用
      • 示例:構建動態 MLP 模型
      • Transformers中的多頭注意力機制
    • 6. 總結


PyTorch 中的 nn.ModuleList 詳解

在構建深度學習模型時,經常需要管理多個網絡層(例如多個 nn.Linearnn.Conv2d 等)。在 PyTorch 中,nn.ModuleList 是一個非常有用的容器,可以幫助我們存儲多個子模塊,并自動注冊它們的參數。這對于確保所有參數能夠參與訓練非常重要。本文將詳細介紹 nn.ModuleList 的作用、使用方法及與普通 Python 列表的區別,并給出清晰的代碼示例。


1. 什么是 nn.ModuleList

nn.ModuleList 是一個類似于 Python 列表的容器,但專門用來存儲 PyTorch 的子模塊(也就是繼承自 nn.Module 的對象)。其主要特點是:

  • 自動注冊子模塊:將 nn.Module 存儲在 ModuleList 中后,這些模塊的參數會自動被添加到父模塊的參數列表中。這意味著當你調用 model.parameters() 時,這些子模塊的參數也會被包含進去,從而參與梯度計算和優化。

  • 靈活管理:它可以像普通列表一樣進行索引、迭代和切片操作,方便構建動態網絡結構。

注意nn.ModuleList 不會像 nn.Sequential 那樣自動定義前向傳播(forward)流程。你需要在模型的 forward() 方法中手動遍歷 ModuleList 并調用各個子模塊。


2. 為什么不直接使用普通的 Python 列表?

雖然可以將 nn.Module 對象存儲在普通列表中,但這樣做有一個主要問題:
普通列表中的模塊不會自動注冊為父模塊的子模塊
這會導致:

  • 調用 model.parameters() 時無法獲取這些模塊的參數;
  • 優化器無法更新這些參數,從而影響模型訓練。

而使用 nn.ModuleList 可以避免這個問題,因為它會自動將內部所有的模塊注冊到父模塊中。


3. nn.ModuleList 的基本用法

下面通過一個簡單的示例來說明如何使用 nn.ModuleList 構建一個簡單的神經網絡模型。

示例:構建一個包含兩層全連接網絡的模型

import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 創建一個 ModuleList 來存儲各層self.layers = nn.ModuleList([nn.Linear(10, 20),  # 第 1 層:輸入 10 個特征,輸出 20 個特征nn.ReLU(),          # 激活層nn.Linear(20, 5)    # 第 2 層:輸入 20 個特征,輸出 5 個特征])def forward(self, x):# 手動遍歷 ModuleList 中的每個模塊,并依次調用 forwardfor layer in self.layers:x = layer(x)return x# 創建模型實例
model = MyModel()# 打印模型結構
print("模型結構:")
print(model)# 生成一組示例輸入
input_tensor = torch.randn(3, 10)  # 3 個樣本,每個樣本 10 個特征# 得到模型輸出
output = model(input_tensor)
print("\n模型輸出:")
print(output)
模型結構:
MyModel((layers): ModuleList((0): Linear(in_features=10, out_features=20, bias=True)(1): ReLU()(2): Linear(in_features=20, out_features=5, bias=True))
)模型輸出:
tensor([[ 0.3741,  0.0883,  0.3550, -0.3930,  0.5173],[ 0.2171, -0.0978, -0.0585, -0.4568,  0.3331],[ 0.1232, -0.1491,  0.2026, -0.0978,  0.5478]],grad_fn=<AddmmBackward0>)

說明

  • __init__() 方法中,我們將各個層放在了 nn.ModuleList 中。
  • forward() 方法中,我們使用了一個簡單的 for 循環,依次調用 self.layers 中的每個子模塊。

4. 使用 nn.ModuleList 計算參數總數(與普通列表對比)

為了進一步說明 nn.ModuleList 與普通列表的區別,我們分別計算一下兩種方式下模型的參數總數。

示例代碼

import torch.nn as nn# 使用 ModuleList 存儲模型層
layers_ml = nn.ModuleList([nn.Linear(10, 20),nn.Linear(20, 5)
])# 計算 ModuleList 中的參數總數
ml_params = 0
for p in layers_ml.parameters():ml_params += p.numel()# 使用普通 Python 列表存儲模型層
layers_list = [nn.Linear(10, 20),nn.Linear(20, 5)
]# 計算普通列表中的參數總數
list_params = 0
# 先遍歷列表中的每個層
for layer in layers_list:# 再遍歷每個層的參數for p in layer.parameters():list_params += p.numel()print("ModuleList 參數總數:", ml_params)
print("普通列表參數總數:", list_params)
ModuleList 參數總數: 325
普通列表參數總數: 325

說明

  • 第一個 for 循環遍歷 layers_ml.parameters(),直接累加所有參數的元素數。
  • 第二部分中,我們先遍歷普通列表中的每個 layer,再單獨遍歷每個層的參數。這樣做使每一步都清晰易懂。

5. nn.ModuleList 的其他應用

示例:構建動態 MLP 模型

當網絡結構比較復雜或層數不固定時,可以利用列表生成器動態構建 ModuleList

class DynamicMLP(nn.Module):def __init__(self, layer_sizes):super(DynamicMLP, self).__init__()# 使用 for 循環構造每一層,存儲在 ModuleList 中layers = []  # 先用普通列表保存層for i in range(len(layer_sizes) - 1):linear_layer = nn.Linear(layer_sizes[i], layer_sizes[i + 1])layers.append(linear_layer)# 將普通列表轉換為 ModuleListself.layers = nn.ModuleList(layers)def forward(self, x):# 遍歷每一層(沒有嵌套循環,逐個執行)for layer in self.layers:x = torch.relu(layer(x))return x# 創建一個動態 MLP:輸入 10,隱藏層 20, 30,輸出 5
dynamic_model = DynamicMLP([10, 20, 30, 5])
print("動態 MLP 模型:")
print(dynamic_model)# 測試模型
input_tensor = torch.randn(4, 10)  # 4 個樣本,每個樣本 10 個特征
output = dynamic_model(input_tensor)
print("\n動態 MLP 模型輸出:")
print(output)

說明

  • __init__() 方法中,我們使用一個普通列表 layers 存儲每個 nn.Linear 層,然后再將它轉換為 nn.ModuleList
  • forward() 方法中,使用單獨的 for 循環逐個調用每一層,并對輸出應用 ReLU 激活函數。
  • 這種寫法適用于層數動態變化的網絡(例如 MLP、RNN、Transformer 中部分模塊)。

Transformers中的多頭注意力機制

class SingleHeadAttention(nn.Module):def __init__(self, embed_dim, head_dim):super().__init__()self.query = nn.Linear(embed_dim, head_dim)self.key = nn.Linear(embed_dim, head_dim)self.value = nn.Linear(embed_dim, head_dim)def forward(self, x):# 實現注意力計算邏輯...return attended_valuesclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.head_dim = embed_dim // num_heads# 顯式創建每個注意力頭self.head1 = SingleHeadAttention(embed_dim, self.head_dim)self.head2 = SingleHeadAttention(embed_dim, self.head_dim)self.head3 = SingleHeadAttention(embed_dim, self.head_dim)# 使用ModuleList管理多個頭self.heads = nn.ModuleList([self.head1,self.head2,self.head3])self.output_proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):# 分別處理每個頭head1_out = self.head1(x)head2_out = self.head2(x) head3_out = self.head3(x)# 拼接結果combined = torch.cat([head1_out, head2_out, head3_out], dim=-1)return self.output_proj(combined)

關鍵點解析:

  • 顯式聲明每個注意力頭(避免循環)

  • 使用ModuleList統一管理注意力頭

  • 在forward中分別調用每個頭

  • 保持各頭獨立性,便于后續調試


6. 總結

  • nn.ModuleList 是專門用于存儲多個子模塊的容器,它會自動注冊子模塊,確保所有參數能參與訓練。
  • 與普通 Python 列表相比,ModuleList 可以直接通過 model.parameters() 獲取其中所有參數,從而方便地進行優化。
  • 使用 ModuleList 時,前向傳播需要手動遍歷其中的模塊,這提供了更大的靈活性,但也要求開發者理解循環過程。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/72746.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/72746.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/72746.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Excel導出工具類--復雜的excel功能導出(使用自定義注解導出)

Excel導出工具類 前言: 簡單的excel導出,可以用easy-excel, fast-excel, auto-poi,在導出實體類上加上對應的注解,用封裝好的工具類直接導出,但對于復雜的場景, 封裝的工具類解決不了,要用原生的excel導出(easy-excel, fast-excel, auto-poi都支持原生的) 業務場景: 根據…

批量測試IP和域名聯通性2

在前面批量測試IP和域名聯通性-CSDN博客的基礎上&#xff0c;由于IP和域名多樣性&#xff0c;比如帶端口號的192.168.1.17:17&#xff0c;實際上應該ping 192.168.1.17。如果封禁http://www.abc.com/a.exe&#xff0c;實際可ping www.abc.com。所以又完善了代碼。 echo off se…

國產編輯器EverEdit - 語法著色文件的語法

1 語法著色定義(官方文檔) 1.1 概述 EverEdit有著優異的語法著色引擎&#xff0c;可以高亮現存的絕大多數的編程語言。在EverEdit的語法著色中有Region和Item兩個概念&#xff0c;Region表示著不同的區塊。而Item則代表著這些區塊中不同的部分。一般情況下&#xff0c;Region…

Excel處理控件Aspose.Cells教程:如何自動將 HTML 轉換為 Excel

在處理 HTML 表中呈現的結構化數據時&#xff0c;將 HTML 轉換為 Excel 是一種常見需求。無論您是從網站、報告還是任何其他來源提取數據&#xff0c;將其轉換為 Excel 都可以更好地進行分析、操作和共享。 開發人員通常更喜歡使用編程方法將 HTML 轉換為 Excel&#xff0c;因…

基于springbo校園安全管理系統(源碼+lw+部署文檔+講解),源碼可白嫖!

摘要 隨著信息時代的來臨&#xff0c;過去信息校園安全管理方式的缺點逐漸暴露&#xff0c;本次對過去的校園安全管理方式的缺點進行分析&#xff0c;采取計算機方式構建校園安全管理系統。本文通過閱讀相關文獻&#xff0c;研究國內外相關技術&#xff0c;提出了一種集進出校…

vim在連續多行行首插入相同的字符

工作中經常需要用vim注釋掉一段代碼或者json文件中的一部分&#xff0c;需要在多行前面插入//或者#符號。在 Vim 中&#xff0c;在連續多行行首插入相同字符主要有以下兩種方法&#xff1a; Visual Block 模式插入 將光標移到要插入相同內容的第一行的行首24。按下Ctrl v進入…

Git 實戰指南:本地客戶端連接 Gitee 全流程

本文將以 Gitee(碼云)、系統Windows 11 為例,詳細介紹從本地倉庫初始化到遠程協作的全流程操作 目錄 1. 前期準備1.1 注冊與配置 Gitee1.2 下載、安裝、配置客戶端1.3 配置公鑰到 Gitee2. 本地倉庫操作(PowerShell/Git Bash)2.1 初始化本地倉庫2.2 關聯 Gitee 遠程倉庫3. …

Pytest項目_day01(HTTP接口)

HTTP HTTP是一個協議&#xff08;服務器傳輸超文本到瀏覽器的傳送協議&#xff09;&#xff0c;是基于TCP/IP通信協議來傳輸數據&#xff08;HTML文件&#xff0c;圖片文件&#xff0c;查詢結果等&#xff09;。 訪問域名 例如www.baidu.com就是百度的域名&#xff0c;我們想…

MySQL超詳細介紹(近2萬字)

1. 簡單概述 MySQL安裝后默認有4個庫不可以刪除&#xff0c;存儲的是服務運行時加載的不同功能的程序和數據 information_schema&#xff1a;是MySQL數據庫提供的一個虛擬的數據庫&#xff0c;存儲了MySQL數據庫中的相關信息&#xff0c;比如數據庫、表、列、索引、權限、角色等…

SQLMesh宏操作符深度解析:掌握@star與@GENERATE_SURROGATE_KEY實戰技巧

引言&#xff1a;解鎖SQLMesh的動態查詢能力 在復雜的數據處理場景中&#xff0c;手動編寫重復性SQL代碼不僅效率低下&#xff0c;還難以維護。SQLMesh作為新一代數據庫中間件&#xff0c;通過其強大的宏系統賦予開發者編程式構建查詢的能力。本文將重點解析兩個核心操作符——…

超詳細kubernetes部署k8s----一臺master和兩臺node

一、部署說明 1、主機操作系統說明 2、主機硬件配置說明 二、主機準備&#xff08;沒有特別說明都是三臺都要配置&#xff09; 1、配置主機名和IP 2、配置hosts解析 3、防火墻和SELinux 4、時間同步配置 5、配置內核轉發及網橋過濾 6、關閉swap 7、啟用ipvs 8、句柄…

高光譜相機在水果分類與品質檢測中的應用

一、核心應用領域 ?外部品質檢測? ?表面缺陷識別&#xff1a;通過400-1000nm波段的高光譜成像&#xff0c;可檢測蘋果表皮損傷、碰傷等細微缺陷&#xff0c;結合圖像分割技術實現快速分類?。 ?損傷程度評估&#xff1a;例如青香蕉的碰撞損傷會導致光譜反射率變化&#…

【藍橋杯每日一題】3.17

&#x1f3dd;?專欄&#xff1a; 【藍橋杯備篇】 &#x1f305;主頁&#xff1a; f狐o貍x 他們說內存泄漏是bug&#xff0c;我說這是系統在逼我進化成SSR級程序員 OK來吧&#xff0c;不多廢話&#xff0c;今天來點有難度的&#xff1a;二進制枚舉 二進制枚舉&#xff0c;就是…

Windows11 新機開荒(二)電腦優化設置

目錄 前言&#xff1a; 一、注冊微軟賬號綁定權益 二、此電腦 桌面圖標 三、系統分盤及默認存儲位置更改 3.1 系統分盤 3.2 默認存儲位置更改 四、精簡任務欄 總結&#xff1a; 前言&#xff1a; 本文承接上一篇 新機開荒&#xff08;一&#xff09; 上一篇文章地址&…

aws(學習筆記第三十三課) 深入使用cdk 練習aws athena

文章目錄 aws(學習筆記第三十三課) 深入使用cdk學習內容&#xff1a;1. 使用aws athena1.1 什么是aws athena1.2 什么是aws glue1.2 為什么aws athena和aws glue一起使用 2. 開始練習aws athena2.1 代碼鏈接2.2 整體架構2.3 代碼解析2.3.1 創建測試數據的S3 bucket2.3.2 創建保…

每日學習Java之一萬個為什么(待補充)

Git分支操作 git branch 分支名 git branch -v git checkout -b 分支名 git checkout 分支名 git merge 分支名 git branch -d | -D 分支名Git沖突 git同名文件合并的最基本單位是行。同名文件同一行不同就會發生沖突。 解決辦法&#xff1a;及時溝通&#xff0c;手動更改&…

C++ 多生產者單消費者(MPSC)模式

根據你的需求,多生產者單消費者(MPSC)模式的日志任務隊列需要調整設計。以下是改進后的代碼實現,重點在于多線程安全入隊、單線程消費任務,并確保停止時隊列任務全部處理完畢: 多生產者單消費者(MPSC)任務隊列實現 #include <iostream> #include <queue> …

OpenCV基礎【圖像和視頻的加載與顯示】

目錄 一.創建一個窗口&#xff0c;顯示圖片 二.顯示攝像頭/多媒體文件 三.把攝像頭錄取到的視頻存儲在本地 四.鼠標回調事件 五.TrackBar滑動條 一.創建一個窗口&#xff0c;顯示圖片 import cv2img_path "src/fengjing.jpg" # 自己的圖片路徑 img cv2.imre…

c++--vector

1.定義vector vector的定義分為四種 (1)vector() ——————無參構造 (2)vector(size_t n,const value_type& val value_type()) ——————構造并初始化n個val (3)vector(const vector& v1) ———————拷貝構造 (4)vector(inputiterator first,inpu…

宇樹科技純技能要求總結

一、嵌入式開發與硬件設計 核心技能 嵌入式開發&#xff1a; 精通C/C&#xff0c;熟悉STM32、ARM開發熟悉Linux BSP開發及驅動框架&#xff08;SPI/UART/USB/FLASH/Camera/GPS/LCD&#xff09;掌握主流平臺&#xff08;英偉達、全志、瑞芯微等&#xff09; 硬件設計&#xff1a…