1. 定義
nn.Linear
是 PyTorch 中最基礎的全連接(fully‐connected)線性層,也稱仿射變換層(affine layer)。它對輸入張量做一次線性變換:
output = x W T + b \text{output} = x W^{T} + b output=xWT+b
其中, W W W是形狀為 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)的權重矩陣, b b b是長度為 out_features \text{out\_features} out_features的偏置向量。
2. 輸入與輸出
-
輸入(Input)
- 類型:浮點型張量(如
torch.float32
,也可與權重同 dtype) - 形狀: ( … , i n _ f e a t u r e s ) (\dots, \mathrm{in\_features}) (…,in_features),最后一維必須與
in_features
匹配;前面的…
可以是 batch 大小或任意額外維度。
- 類型:浮點型張量(如
-
輸出(Output)
- 類型:浮點型張量
- 形狀: ( … , o u t _ f e a t u r e s ) (\dots, \mathrm{out\_features}) (…,out_features),即將最后一維從
in_features
變換為out_features
,前面的維度保持不變。
3. 底層原理
- 矩陣乘法
對輸入張量的最后一維做矩陣乘法:
y [ . . . , j ] = ∑ i = 1 i n _ f e a t u r e s x [ . . . , i ] × W j , i ( ? j = 1 … o u t _ f e a t u r e s ) y[..., j] = \sum_{i=1}^{\mathrm{in\_features}} x[..., i] \times W_{j,i} \quad (\forall\,j=1\ldots\mathrm{out\_features}) y[...,j]=i=1∑in_features?x[...,i]×Wj,i?(?j=1…out_features) - 加偏置
若bias=True
,則在乘積結果上逐元素加偏置 b j b_j bj?:
y [ . . . , j ] = ( x W T ) [ . . . , j ] + b j y[..., j] \;=\; (x W^T)[..., j] \;+\; b_j y[...,j]=(xWT)[...,j]+bj? - 梯度更新
- 權重 W W W:梯度 ? L ? W j , i = ∑ x [ . . . , i ] × δ y [ . . . , j ] \frac{\partial \mathcal{L}}{\partial W_{j,i}} = \sum x[..., i]\times \delta y[..., j] ?Wj,i??L?=∑x[...,i]×δy[...,j]
- 偏置 b b b:梯度 ∑ δ y [ . . . , j ] \sum \delta y[..., j] ∑δy[...,j]
優化器(如 SGD、Adam 等)根據梯度更新參數。
4. 構造函數參數詳解
參數 | 類型 & 默認 | 說明 |
---|---|---|
in_features | int | 必填。輸入特征維度,即每個樣本最后一維的大小。 |
out_features | int | 必填。輸出特征維度,即映射后的最后一維大小(神經元個數)。 |
bias | bool ,默認 True | 是否使用偏置向量 b 。若為 False ,則不加偏置。 |
device | torch.device 或 None | 指定權重和偏置所在設備(CPU/GPU);若為 None ,默認繼承父模塊設備。 |
dtype | torch.dtype 或 None | 指定權重和偏置的數據類型;若為 None ,默認繼承父模塊 dtype。 |
權重與偏置初始化
- 默認情況下,
W
按均勻分布初始化:
W j , i ~ U ( ? 1 i n _ f e a t u r e s , 1 i n _ f e a t u r e s ) W_{j,i}\sim \mathcal{U}\Bigl(-\sqrt{\tfrac{1}{\mathrm{in\_features}}},\;\sqrt{\tfrac{1}{\mathrm{in\_features}}}\Bigr) Wj,i?~U(?in_features1??,in_features1??) - 偏置 b b b初始化為全零。
5. 使用示例
import torch
import torch.nn as nn# 1. 定義線性層
in_dim = 64 # 輸入維度
out_dim = 10 # 輸出維度(例如分類 10 類的 logits)
linear = nn.Linear(in_features=in_dim,out_features=out_dim,bias=True
)# 2. 構造輸入
# 假設 batch_size=8
x = torch.randn(8, in_dim) # shape=[8,64]# 3. 前向計算
# 輸出 shape=[8,10]
y = linear(x)
print(y.shape) # torch.Size([8, 10])
如果不希望使用偏置:
linear_no_bias = nn.Linear(in_dim, out_dim, bias=False)
在更高維場景下也可:
# 輸入 shape=[batch, seq_len, in_dim]
x_seq = torch.randn(8, 5, in_dim)
# 輸出 shape=[8, 5, out_dim]
y_seq = linear(x_seq)
6. 注意事項
- 維度匹配
- 確保輸入最后一維等于
in_features
,否則會報維度不匹配錯誤。
- 確保輸入最后一維等于
- 批量處理
- 對多維輸入,
nn.Linear
自動應用到最后一維,無需手動 reshape(除非想將多個維度合并后統一處理)。
- 對多維輸入,
- 初始化
- 默認初始化適用于常見場景;若訓練不穩定,可手動調用
nn.init
系列方法(如kaiming_uniform_
,xavier_normal_
等)重新初始化。
- 默認初始化適用于常見場景;若訓練不穩定,可手動調用
- 無偏置場景
- 對于某些網絡結構(如批歸一化緊跟線性層),可關閉
bias
減少參數不用多余偏置。
- 對于某些網絡結構(如批歸一化緊跟線性層),可關閉
- 設備與 dtype
- 在多 GPU 或混合精度訓練時,可通過
device
與dtype
參數顯式控制,避免后續.to()
調用。
- 在多 GPU 或混合精度訓練時,可通過
- 與 Conv1d 的聯系
- 本質上,
nn.Conv1d(in_dim, out_dim, kernel_size=1)
相當于在時間維度(或序列維度)上對每個位置做一個nn.Linear
;理解這一點有助于模型設計。
- 本質上,