pytorch小記(十三):pytorch中`nn.ModuleList` 詳解
- PyTorch 中的 `nn.ModuleList` 詳解
- 1. 什么是 `nn.ModuleList`?
- 2. 為什么不直接使用普通的 Python 列表?
- 3. `nn.ModuleList` 的基本用法
- 示例:構建一個包含兩層全連接網絡的模型
- 4. 使用 `nn.ModuleList` 計算參數總數(與普通列表對比)
- 示例代碼
- 5. `nn.ModuleList` 的其他應用
- 示例:構建動態 MLP 模型
- Transformers中的多頭注意力機制
- 6. 總結
PyTorch 中的 nn.ModuleList
詳解
在構建深度學習模型時,經常需要管理多個網絡層(例如多個 nn.Linear
、nn.Conv2d
等)。在 PyTorch 中,nn.ModuleList
是一個非常有用的容器,可以幫助我們存儲多個子模塊,并自動注冊它們的參數。這對于確保所有參數能夠參與訓練非常重要。本文將詳細介紹 nn.ModuleList
的作用、使用方法及與普通 Python 列表的區別,并給出清晰的代碼示例。
1. 什么是 nn.ModuleList
?
nn.ModuleList
是一個類似于 Python 列表的容器,但專門用來存儲 PyTorch 的子模塊(也就是繼承自 nn.Module
的對象)。其主要特點是:
-
自動注冊子模塊:將
nn.Module
存儲在ModuleList
中后,這些模塊的參數會自動被添加到父模塊的參數列表中。這意味著當你調用model.parameters()
時,這些子模塊的參數也會被包含進去,從而參與梯度計算和優化。 -
靈活管理:它可以像普通列表一樣進行索引、迭代和切片操作,方便構建動態網絡結構。
注意:
nn.ModuleList
不會像nn.Sequential
那樣自動定義前向傳播(forward)流程。你需要在模型的forward()
方法中手動遍歷ModuleList
并調用各個子模塊。
2. 為什么不直接使用普通的 Python 列表?
雖然可以將 nn.Module
對象存儲在普通列表中,但這樣做有一個主要問題:
普通列表中的模塊不會自動注冊為父模塊的子模塊。
這會導致:
- 調用
model.parameters()
時無法獲取這些模塊的參數; - 優化器無法更新這些參數,從而影響模型訓練。
而使用 nn.ModuleList
可以避免這個問題,因為它會自動將內部所有的模塊注冊到父模塊中。
3. nn.ModuleList
的基本用法
下面通過一個簡單的示例來說明如何使用 nn.ModuleList
構建一個簡單的神經網絡模型。
示例:構建一個包含兩層全連接網絡的模型
import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 創建一個 ModuleList 來存儲各層self.layers = nn.ModuleList([nn.Linear(10, 20), # 第 1 層:輸入 10 個特征,輸出 20 個特征nn.ReLU(), # 激活層nn.Linear(20, 5) # 第 2 層:輸入 20 個特征,輸出 5 個特征])def forward(self, x):# 手動遍歷 ModuleList 中的每個模塊,并依次調用 forwardfor layer in self.layers:x = layer(x)return x# 創建模型實例
model = MyModel()# 打印模型結構
print("模型結構:")
print(model)# 生成一組示例輸入
input_tensor = torch.randn(3, 10) # 3 個樣本,每個樣本 10 個特征# 得到模型輸出
output = model(input_tensor)
print("\n模型輸出:")
print(output)
模型結構:
MyModel((layers): ModuleList((0): Linear(in_features=10, out_features=20, bias=True)(1): ReLU()(2): Linear(in_features=20, out_features=5, bias=True))
)模型輸出:
tensor([[ 0.3741, 0.0883, 0.3550, -0.3930, 0.5173],[ 0.2171, -0.0978, -0.0585, -0.4568, 0.3331],[ 0.1232, -0.1491, 0.2026, -0.0978, 0.5478]],grad_fn=<AddmmBackward0>)
說明:
- 在
__init__()
方法中,我們將各個層放在了nn.ModuleList
中。 - 在
forward()
方法中,我們使用了一個簡單的 for 循環,依次調用self.layers
中的每個子模塊。
4. 使用 nn.ModuleList
計算參數總數(與普通列表對比)
為了進一步說明 nn.ModuleList
與普通列表的區別,我們分別計算一下兩種方式下模型的參數總數。
示例代碼
import torch.nn as nn# 使用 ModuleList 存儲模型層
layers_ml = nn.ModuleList([nn.Linear(10, 20),nn.Linear(20, 5)
])# 計算 ModuleList 中的參數總數
ml_params = 0
for p in layers_ml.parameters():ml_params += p.numel()# 使用普通 Python 列表存儲模型層
layers_list = [nn.Linear(10, 20),nn.Linear(20, 5)
]# 計算普通列表中的參數總數
list_params = 0
# 先遍歷列表中的每個層
for layer in layers_list:# 再遍歷每個層的參數for p in layer.parameters():list_params += p.numel()print("ModuleList 參數總數:", ml_params)
print("普通列表參數總數:", list_params)
ModuleList 參數總數: 325
普通列表參數總數: 325
說明:
- 第一個 for 循環遍歷
layers_ml.parameters()
,直接累加所有參數的元素數。 - 第二部分中,我們先遍歷普通列表中的每個
layer
,再單獨遍歷每個層的參數。這樣做使每一步都清晰易懂。
5. nn.ModuleList
的其他應用
示例:構建動態 MLP 模型
當網絡結構比較復雜或層數不固定時,可以利用列表生成器動態構建 ModuleList
。
class DynamicMLP(nn.Module):def __init__(self, layer_sizes):super(DynamicMLP, self).__init__()# 使用 for 循環構造每一層,存儲在 ModuleList 中layers = [] # 先用普通列表保存層for i in range(len(layer_sizes) - 1):linear_layer = nn.Linear(layer_sizes[i], layer_sizes[i + 1])layers.append(linear_layer)# 將普通列表轉換為 ModuleListself.layers = nn.ModuleList(layers)def forward(self, x):# 遍歷每一層(沒有嵌套循環,逐個執行)for layer in self.layers:x = torch.relu(layer(x))return x# 創建一個動態 MLP:輸入 10,隱藏層 20, 30,輸出 5
dynamic_model = DynamicMLP([10, 20, 30, 5])
print("動態 MLP 模型:")
print(dynamic_model)# 測試模型
input_tensor = torch.randn(4, 10) # 4 個樣本,每個樣本 10 個特征
output = dynamic_model(input_tensor)
print("\n動態 MLP 模型輸出:")
print(output)
說明:
- 在
__init__()
方法中,我們使用一個普通列表layers
存儲每個nn.Linear
層,然后再將它轉換為nn.ModuleList
。 - 在
forward()
方法中,使用單獨的 for 循環逐個調用每一層,并對輸出應用 ReLU 激活函數。 - 這種寫法適用于層數動態變化的網絡(例如 MLP、RNN、Transformer 中部分模塊)。
Transformers中的多頭注意力機制
class SingleHeadAttention(nn.Module):def __init__(self, embed_dim, head_dim):super().__init__()self.query = nn.Linear(embed_dim, head_dim)self.key = nn.Linear(embed_dim, head_dim)self.value = nn.Linear(embed_dim, head_dim)def forward(self, x):# 實現注意力計算邏輯...return attended_valuesclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.head_dim = embed_dim // num_heads# 顯式創建每個注意力頭self.head1 = SingleHeadAttention(embed_dim, self.head_dim)self.head2 = SingleHeadAttention(embed_dim, self.head_dim)self.head3 = SingleHeadAttention(embed_dim, self.head_dim)# 使用ModuleList管理多個頭self.heads = nn.ModuleList([self.head1,self.head2,self.head3])self.output_proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):# 分別處理每個頭head1_out = self.head1(x)head2_out = self.head2(x) head3_out = self.head3(x)# 拼接結果combined = torch.cat([head1_out, head2_out, head3_out], dim=-1)return self.output_proj(combined)
關鍵點解析:
-
顯式聲明每個注意力頭(避免循環)
-
使用ModuleList統一管理注意力頭
-
在forward中分別調用每個頭
-
保持各頭獨立性,便于后續調試
6. 總結
nn.ModuleList
是專門用于存儲多個子模塊的容器,它會自動注冊子模塊,確保所有參數能參與訓練。- 與普通 Python 列表相比,
ModuleList
可以直接通過model.parameters()
獲取其中所有參數,從而方便地進行優化。 - 使用
ModuleList
時,前向傳播需要手動遍歷其中的模塊,這提供了更大的靈活性,但也要求開發者理解循環過程。