1. 定義
nn.Embedding
是 PyTorch 中的 查表式嵌入層(lookup‐table),用于將離散的整數索引(如詞 ID、實體 ID、離散特征類別等)映射到一個連續的、可訓練的低維向量空間。它通過維護一個形狀為 (num_embeddings, embedding_dim)
的權重矩陣,實現高效的“索引 → 向量”轉換。
2. 輸入與輸出
-
輸入
- 類型:整型張量(
torch.long
或torch.int64
),必須是 LongTensor,其他類型會報錯。 - 形狀:任意形狀
(*, L)
,其中最內層長度L
常為序列長度,前面的*
可以是 batch 及其他維度。 - 取值范圍:
0 ≤ index < num_embeddings
;超出范圍會拋出IndexError
。
- 類型:整型張量(
-
輸出
- 類型:浮點型張量(與權重相同的
dtype
,默認為torch.float32
)。 - 形狀:
(*, L, embedding_dim)
;就是在輸入張量后追加一個維度embedding_dim
。 - 語義:若輸入某位置的值為
j
,則該位置對應輸出就是權重矩陣的第j
行。
- 類型:浮點型張量(與權重相同的
3. 底層原理
-
查表操作 vs. One-hot 乘法
- 直觀上,Embedding 相當于:
output = one_hot ( i n p u t ) × W \text{output} = \text{one\_hot}(input) \;\times\; W output=one_hot(input)×W
其中W
是(num_embeddings×embedding_dim)
的權重矩陣。 - 為避免顯式構造稀疏的 one-hot 張量,PyTorch 直接根據索引做“取行”操作,效率更高、內存更省。
- 直觀上,Embedding 相當于:
-
梯度更新
- 稠密模式(默認):整個
W
都有梯度緩沖,優化器根據梯度更新所有行。 - 稀疏模式(
sparse=True
):僅對被索引過的行計算和存儲梯度,可配合optim.SparseAdam
高效更新,適合超大字典(百萬級以上)但每次只訪問少量索引的場景。
- 稠密模式(默認):整個
-
范數裁剪
- 若指定
max_norm
,每次前向都會對輸出向量(即對應的行)做范數裁剪,保證其 L-norm_type
范數不超過max_norm
,有助于防止某些頻繁訪問的詞向量過大。
- 若指定
-
權重初始化
- 默認初始化使用均勻分布:
W i , j ~ U ( ? 1 num_embeddings , 1 num_embeddings ) W_{i,j} \sim \mathcal{U}\Bigl(-\sqrt{\tfrac{1}{\text{num\_embeddings}}},\;\sqrt{\tfrac{1}{\text{num\_embeddings}}}\Bigr) Wi,j?~U(?num_embeddings1??,num_embeddings1??) - 可以通過
_weight
參數傳入外部預訓練權重(如 Word2Vec、GloVe 等)。
- 默認初始化使用均勻分布:
4. 構造函數參數詳解
參數 | 類型及默認 | 說明 |
---|---|---|
num_embeddings | int | 必填。嵌入表行數,等于類別總數(最大索引 + 1)。 |
embedding_dim | int | 必填。每個向量的維度。 |
padding_idx | int 或 None | 默認 None 。指定該索引對應行始終輸出全零,并且該行的梯度永遠為 0,適合做序列填充。 |
max_norm | float 或 None | 默認 None 。若設為數值,每次前向時對取出的向量做范數裁剪(L-norm_type ≤ max_norm )。 |
norm_type | float ,默認 2 | 與 max_norm 配合使用時定義范數類型,如 1-范數、2-范數等。 |
scale_grad_by_freq | bool ,默認 False | 若為 True ,在反向傳播階段按照索引在 batch 中出現的頻次對梯度做縮放(出現越多,梯度越小),有助于高頻詞的梯度平滑。 |
sparse | bool ,默認 False | 若為 True ,開啟稀疏更新,僅對被訪問行生成梯度;必須配合 optim.SparseAdam 使用,不支持常規稠密優化器。 |
_weight | Tensor 或 None | 若提供,則用此張量(形狀應為 (num_embeddings, embedding_dim) )作為權重初始化,否則隨機初始化。 |
5. 使用示例
import torch
import torch.nn as nn# 1. 參數設定
vocab_size = 10000 # 詞表大小
embed_dim = 300 # 嵌入維度# 2. 創建 Embedding 層
embedding = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_dim,padding_idx=0, # 將 0 作為填充索引,輸出全 0max_norm=5.0, # 向量范數不超過 5norm_type=2.0,scale_grad_by_freq=True,sparse=False
)# 3. 構造輸入
# batch_size=2, seq_len=6
input_ids = torch.tensor([[ 1, 234, 56, 789, 0, 23],[123, 4, 567, 8, 9, 0],
], dtype=torch.long)# 4. 前向計算
# 輸出 shape = [2, 6, 300]
output = embedding(input_ids)
print(output.shape) # -> torch.Size([2, 6, 300])
加載并凍結預訓練權重
import numpy as np# 假設有預訓練權重 pre_trained.npy,shape=(10000,300)
weights = torch.from_numpy(np.load("pre_trained.npy"))
embed_pre = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_dim,_weight=weights
)
# 凍結所有權重
embed_pre.weight.requires_grad = False
6. 注意事項
- 類型與范圍
- 輸入必須為 LongTensor,且所有索引滿足
0 ≤ index < num_embeddings
。
- 輸入必須為 LongTensor,且所有索引滿足
- Padding 與 Mask
- 僅指定
padding_idx
會返回零向量,但上游網絡(如 RNN、Transformer)還需顯式 mask,避免無效位置影響注意力或累積狀態。
- 僅指定
- 性能考量
max_norm
每次前向都做范數計算和裁剪,若不需要可關閉以提升速度。
- 稀疏更新限制
sparse=True
可節省內存,但只支持SparseAdam
,且在 GPU 上效率有時不如稠密模式。
- EmbeddingBag
- 對于可變長度序列的 sum/mean/power-mean 匯聚,可使用
nn.EmbeddingBag
,避免中間張量開銷。
- 對于可變長度序列的 sum/mean/power-mean 匯聚,可使用
- 分布式與大詞表
- 在分布式訓練時,可將嵌入表切分到多個進程上(
torch.nn.parallel.DistributedDataParallel
+torch.nn.Embedding
支持參數分布式)。 - 超大詞表(千萬級)時,可考慮動態加載、分布式哈希表或專用庫(如 DeepSpeed 的嵌入稀疏優化)。
- 在分布式訓練時,可將嵌入表切分到多個進程上(