torch.norm 是什么
torch.norm(dot_product, p=2, dim=-1)
是 PyTorch 中用于計算張量 L2 范數的函數,
1. 各參數解析
dot_product
:輸入張量,在代碼中形狀為[batch_size, seq_len]
(每個元素是 token 隱藏狀態與關注向量的點積)。p=2
:指定計算L2 范數(歐幾里得范數),公式為:對于向量[x?, x?, ..., x?]
,L2 范數 =√(x?2 + x?2 + ... + x?2)
。dim=-1
:指定計算范數的維度。-1
表示“最后一個維度”,在[batch_size, seq_len]
中即seq_len
維度(序列長度維度)。
2. 計算邏輯(結合代碼上下文)
假設 dot_product
的形狀為 [2, 3]
(batch_size=2
,seq_len=3