References:
- Transformer模型詳解(圖解最完整版) - 知乎: https://zhuanlan.zhihu.com/p/338817680
- GitHub - liaoyanqing666/transformer_pytorch (complete original Transformer implementation): https://github.com/liaoyanqing666/transformer_pytorch
- Attention Is All You Need: https://arxiv.org/pdf/1706.03762
1. Overall Architecture of the Transformer
The Transformer consists of two parts, an Encoder and a Decoder, each containing 6 blocks. The Transformer's workflow is roughly as follows:
Step 1: obtain a representation vector X for every word in the input sentence. X is the sum of the word's own embedding (an embedding is a feature representation extracted from the raw data) and the word's positional embedding.
Step 2: feed the resulting matrix of word representation vectors (as shown in the figure above, each row is the representation x of one word) into the Encoder. After passing through 6 Encoder blocks, we obtain the encoding matrix C for all words in the sentence. As shown in the figure below, the word-vector matrix is denoted $X_{n \times d}$, where n is the number of words in the sentence and d is the dimension of the representation vector (d = 512 in the paper). The output of every Encoder block has exactly the same dimensions as its input.
Step 3: pass the encoding matrix C output by the Encoder to the Decoder. The Decoder then predicts word i+1 based on the already translated words 1 through i, as shown in the figure below. When translating word i+1, a Mask operation must hide the words after position i+1.
In the figure above, the Decoder receives the Encoder's encoding matrix C, then first takes the start-of-translation token "<Begin>" as input and predicts the first word "I"; next it takes "<Begin>" and "I" as input and predicts "have", and so on.
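As a rough sketch of this autoregressive loop (using the Transformer implementation given at the end of this article; greedy_decode, bos_id and eos_id are illustrative names, not part of the original code):

import torch

def greedy_decode(model, src, bos_id, eos_id, max_len=50):
    # src: [1, len_src] tensor of source word IDs; the encoder is re-run each step for simplicity
    tgt = torch.tensor([[bos_id]])                    # start with the "<Begin>" token
    for _ in range(max_len):
        logits = model(src, tgt)                      # [1, current_len, vocabulary]
        next_id = logits[:, -1, :].argmax(dim=-1)     # most likely next word
        tgt = torch.cat([tgt, next_id.unsqueeze(1)], dim=1)
        if next_id.item() == eos_id:                  # stop at the end-of-sentence token
            break
    return tgt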
2. Transformer Inputs
In the Transformer, the input representation x of a word is the sum of the word embedding and the positional embedding (Positional Encoding).
2.1 Word Embedding (Token Embedding Layer)
The word embedding itself can be obtained in many ways, e.g. pre-trained with algorithms such as Word2Vec or GloVe, or learned directly inside the Transformer.
self.embedding = nn.Embedding(vocabulary, dim)
What it does:
- Purpose: converts discrete integer indices (word IDs) into continuous vector representations
- Input: an integer tensor of shape [sequence_length]
- Output: a float tensor of shape [sequence_length, dim] (n is the sequence length, d is the feature dimension)
Parameters:

Parameter | Meaning | Example value | Description
---|---|---|---
vocabulary | vocabulary size | 10000 | total number of distinct words/symbols the model can handle
dim | embedding dimension | 512 | the length of the vector each word is mapped to
How it works:
- A learnable embedding matrix of shape [vocabulary, dim] is created; for example, with vocabulary=10000 and dim=512 it is a 10000×512 matrix.
- Each integer index corresponds to one row of the matrix:

# Suppose the word "apple" has ID 42
apple_vector = embedding_matrix[42]  # shape: [512]
Its concrete role inside the Transformer:

# Input: src = torch.randint(0, 10000, (2, 10))
# Shape: [batch_size=2, seq_len=10]
src_embedded = self.embedding(src)
# Output shape: [2, 10, 512]
# Every integer word ID has been replaced by a 512-dimensional vector
Visualization:

Original input (word IDs):
[[ 25, 198, 3000, ... ],    # sentence 1
 [  1,  42,  999, ... ]]    # sentence 2

After the embedding layer (vector representations):
[[[0.2, -0.5, ..., 1.3],    # vector for ID 25
  [0.8, 0.1, ..., -0.9],    # vector for ID 198
  ...],
 [[0.9, -0.2, ..., 0.4],    # vector for ID 1
  [0.3, 0.7, ..., -1.2],    # vector for ID 42
  ...]]
Why word embeddings are needed:
- Semantic representation: similar words end up with similar vector representations
- Dimensionality reduction: discrete IDs are mapped into a continuous space (one-hot encoding would need 10000 dimensions, the embedding only 512)
- Learnable: the vectors are continuously adjusted during training to better capture semantic relationships
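A minimal, self-contained version of the lookup described above (the concrete numbers are illustrative only):

import torch
import torch.nn as nn

embedding = nn.Embedding(10000, 512)       # a learnable [10000, 512] matrix
src = torch.randint(0, 10000, (2, 10))     # two sentences, 10 word IDs each
src_embedded = embedding(src)              # each ID is replaced by its 512-dim row
print(src_embedded.shape)                  # torch.Size([2, 10, 512])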
2.2 Positional Embedding (Positional Encoding)
Positional encoding (Positional Encoding, PE) is one of the Transformer's key innovations and addresses the inherently sequential processing of traditional sequence models such as RNNs. The self-attention mechanism itself has no way of perceiving positions in a sequence, so positional encoding adds position information to the input embeddings, letting the model understand the order of the elements. The output of the positional encoding has the same dimensions as the word embedding layer, namely $n \times d$.
Core roles of positional encoding:
- Injecting position information: lets the model distinguish the same word at different positions (e.g. "bank" at the beginning vs. the end of a sentence)
- Preserving distance relationships: encodes both relative and absolute position information
- Supporting parallel computation: avoids the sequential processing that RNNs rely on
Why is positional encoding needed?
- Position invariance of self-attention: $\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$ contains no position information anywhere in its computation (a small demonstration of this follows the comparison table below)
- The order of a sequence matters:
  - Natural language: "the cat chases the dog" ≠ "the dog chases the cat"
  - Time series: the order of a stock-price sequence determines its trend

Comparison of alternatives:
Method | Pros | Cons
---|---|---
Sine/cosine | Good generalization, theoretical guarantees | Fixed pattern, not flexible
Learnable | Adapts to task-specific patterns | Length-limited, needs training
Relative position | Models relative distance directly | More complex to implement
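To make the position-invariance point above concrete, here is a small illustrative check (not part of the original article): without positional encoding, permuting the input tokens merely permutes the rows of the self-attention output, so the model cannot distinguish different orderings.

import torch

torch.manual_seed(0)
x = torch.randn(5, 8)                                    # 5 tokens, dim 8, no positional encoding
attn = torch.softmax(x @ x.T / 8 ** 0.5, dim=-1) @ x     # plain self-attention with Q = K = V = x

perm = torch.tensor([4, 2, 0, 3, 1])                     # shuffle the token order
xp = x[perm]
attn_p = torch.softmax(xp @ xp.T / 8 ** 0.5, dim=-1) @ xp

print(torch.allclose(attn_p, attn[perm], atol=1e-6))     # True: the output is just permuted the same way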
Practical effect of positional encoding:
- In early layers: helps the model build position awareness
- In later layers: the position information is fused into the semantic representation
- Visualization example:

Input:    [The, cat, sat, on, mat]
Embed:    [E_The, E_cat, E_sat, E_on, E_mat]
Position: [P0, P1, P2, P3, P4]
Final:    [E_The+P0, E_cat+P1, ..., E_mat+P4]
(1) Sine-cosine positional encoding (the scheme used in the paper)
The sine-cosine positional encoding is computed as:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

where:
- `pos` is the position of the token in the sequence (starting from 0)
- `d_model` is the model's embedding dimension (i.e. the vector dimension of each token)
- `i` is the dimension index (from 0 to d_model/2 - 1)
Properties:
- Wavelengths form a geometric progression: covers a range of frequencies
- Relative positions are easy to learn: for any offset k, $PE_{pos+k}$ can be expressed as a linear function of $PE_{pos}$
- Strong generalization: can handle sequences longer than those seen during training
- Symmetry: the sin/cos pairing allows the model to learn relative positions
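The relative-position property follows from the angle-addition identities (a standard derivation, added here for completeness):

$$
\begin{aligned}
PE_{(pos+k,\,2i)}   &= \sin\big(\omega_i(pos+k)\big) = \sin(\omega_i\, pos)\cos(\omega_i k) + \cos(\omega_i\, pos)\sin(\omega_i k) \\
PE_{(pos+k,\,2i+1)} &= \cos\big(\omega_i(pos+k)\big) = \cos(\omega_i\, pos)\cos(\omega_i k) - \sin(\omega_i\, pos)\sin(\omega_i k)
\end{aligned}
\qquad \omega_i = 10000^{-2i/d_{model}}
$$

so each (sin, cos) pair at position pos + k is a rotation of the pair at position pos whose angle depends only on the offset k, i.e. a linear transformation independent of pos.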
Code implementation:
class PositionalEncoding(nn.Module):
    # Sine-cosine positional encoding
    def __init__(self, emb_dim, max_len, freq=10000.0):
        super(PositionalEncoding, self).__init__()
        assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'
        self.emb_dim = emb_dim
        self.max_len = max_len
        self.pe = torch.zeros(max_len, emb_dim)
        pos = torch.arange(0, max_len).unsqueeze(1)
        # pos: [max_len, 1]
        div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)
        # div: [ceil(emb_dim / 2)]
        self.pe[:, 0::2] = torch.sin(pos / div)
        # torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]
        self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))
        # torch.cos(pos / div): [max_len, floor(emb_dim / 2)]

    def forward(self, x, len=None):
        if len is None:
            len = x.size(-2)
        return x + self.pe[:len, :]
For example, with emb_dim=512, max_len=100 and a sentence length of 10, the positional embedding values are computed as shown below (trigonometric functions use radians):
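These values can be reproduced directly with the class above (a usage sketch):

pe = PositionalEncoding(emb_dim=512, max_len=100)
x = torch.zeros(1, 10, 512)          # dummy batch: 1 sentence, 10 tokens, 512 dimensions
out = pe(x)                          # adding PE to zeros leaves exactly the encoding values
print(out[0, :3, :4])                # first 3 positions, first 4 dimensions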
(2) Learnable positional encoding
class LearnablePositionalEncoding(nn.Module):
    # Learnable positional encoding
    def __init__(self, emb_dim, len):
        super(LearnablePositionalEncoding, self).__init__()
        assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'
        self.emb_dim = emb_dim
        self.len = len
        self.pe = nn.Parameter(torch.zeros(len, emb_dim))

    def forward(self, x):
        return x + self.pe[:x.size(-2), :]
Characteristics:
- Position embeddings are learned directly: trained as model parameters
- Highly flexible: can adapt to task-specific position patterns
- Length-limited: can only handle up to the predefined maximum length
- Computationally efficient: a direct table lookup, no computation required
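A usage sketch mirroring the sine-cosine case (assuming the class above):

pe = LearnablePositionalEncoding(emb_dim=512, len=100)
x = torch.randn(2, 10, 512)
out = pe(x)                # equal to x at initialization, since self.pe starts at zero
print(out.shape)           # torch.Size([2, 10, 512]); self.pe is updated by the optimizer during training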
3. Self-Attention and Multi-Head Attention
The figure shows the internal structure of the Transformer, with an Encoder block on the left and a Decoder block on the right. Note that:
(1) an Encoder block contains one Multi-Head Attention;
(2) a Decoder block contains two Multi-Head Attention layers (one of which is Masked). Above each Multi-Head Attention there is also an Add & Norm layer: Add is a residual connection (Residual Connection) that prevents network degradation, and Norm is Layer Normalization, which normalizes the activations of each layer.
Multi-Head Attention is the core of the Transformer. It is built from Self-Attention, so we begin with Self-Attention.
3.1 Self-Attention
Self-Attention is the core innovation of the Transformer architecture and fundamentally changed how sequences are modeled. Unlike traditional recurrent neural networks (RNNs) and convolutional neural networks (CNNs), self-attention can directly capture the relationship between any two elements in a sequence, regardless of how far apart they are:
The input to Self-Attention is represented by a matrix $X_{n \times d}$ (n is the sequence length, d is the feature dimension), and it is computed as follows:

(1) Generate Q (query), K (key) and V (value) from learnable weight matrices:

$$Q = XW^Q, \qquad K = XW^K, \qquad V = XW^V$$

where $W^Q$, $W^K$, $W^V$ are learnable parameters.

(2) Compute the Self-Attention output:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Step by step:
- Similarity computation: $QK^\top$ computes the dot-product similarity between every query-key pair, giving an $n \times n$ matrix (n is the number of words in the sentence) that expresses the attention strength between words.
- Scaling: dividing by $\sqrt{d_k}$ prevents the dot products from becoming so large that the gradients vanish.
- Normalization: softmax turns the similarities into a probability distribution.
- Weighted sum: the attention weights are used to form a weighted sum of the value vectors, which is the final output.
Input sequence: [x1, x2, x3, x4]

Step 1: generate Q, K, V vectors for each input
x1 → q1, k1, v1
x2 → q2, k2, v2
x3 → q3, k3, v3
x4 → q4, k4, v4

Step 2: compute attention weights (taking x1 as an example)
scores  = [q1·k1, q1·k2, q1·k3, q1·k4] / √d_k
w1, w2, w3, w4 = softmax(scores)

Step 3: weighted sum
output1 = w1*v1 + w2*v2 + w3*v3 + w4*v4
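The whole procedure fits in a few lines of PyTorch; a minimal single-head sketch (function and variable names are illustrative, not taken from the original code):

import torch

def self_attention(x, w_q, w_k, w_v):
    # x: [n, d]; w_q, w_k, w_v: [d, d_k] learnable projection matrices
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / k.size(-1) ** 0.5       # [n, n] pairwise attention strengths
    weights = torch.softmax(scores, dim=-1)    # each row sums to 1
    return weights @ v                         # [n, d_k] weighted sum of value vectors

x = torch.randn(4, 8)                          # 4 tokens, dimension 8
w_q, w_k, w_v = (torch.randn(8, 8) for _ in range(3))
print(self_attention(x, w_q, w_k, w_v).shape)  # torch.Size([4, 8])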
3.2 Multi-Head Attention
The Transformer uses a multi-head mechanism to increase the model's expressive power:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$

where each attention head is

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

- h: the number of attention heads
- $W_i^Q, W_i^K, W_i^V$: the independent parameters of each head
- $W^O$: the output projection matrix
Code implementation:

(1) Splitting into heads: view splits the feature dimension into multiple heads, so that each head has dimension dim_head = dim_qk // num_heads (a standalone check of this reshaping follows after step (3)):

q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
k = ...  # handled the same way
v = ...  # handled the same way

(2) Efficient matrix computation: a single batched matrix multiplication computes the attention scores for all positions in parallel:

attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)

(3) Merging the heads: view merges the heads back together, since num_heads * d_v = dim_v:

output = output.transpose(1, 2)
output = output.contiguous().view(-1, len_q, self.dim_v)
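The head split/merge reshaping from step (1) can be checked in isolation (illustrative sizes, assuming dim_qk=512 and num_heads=8):

import torch

q = torch.randn(2, 10, 512)     # [B, len_q, dim_qk]
q = q.view(2, 10, 8, 64)        # split the feature dimension: [B, len_q, num_heads, d_qk]
q = q.transpose(1, 2)           # [B, num_heads, len_q, d_qk], ready for per-head matmul
print(q.shape)                  # torch.Size([2, 8, 10, 64])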
Below is the complete implementation of Multi-Head Attention; it already includes mask handling, which will be explained later.
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):
        super(MultiHeadAttention, self).__init__()
        dim_qk = dim if dim_qk is None else dim_qk
        dim_v = dim if dim_v is None else dim_v
        assert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, 'dim must be divisible by num_heads'
        self.dim = dim
        self.dim_qk = dim_qk
        self.dim_v = dim_v
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.w_q = nn.Linear(dim, dim_qk)
        self.w_k = nn.Linear(dim, dim_qk)
        self.w_v = nn.Linear(dim, dim_v)

    def forward(self, q, k, v, mask=None):
        # q: [B, len_q, D]
        # k: [B, len_kv, D]
        # v: [B, len_kv, D]
        assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'
        len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)
        assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'
        assert len_k == len_v, 'len_k and len_v must be equal'
        len_kv = len_v

        q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
        k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)
        v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)
        # q: [B, len_q, num_heads, dim_qk//num_heads]
        # k: [B, len_kv, num_heads, dim_qk//num_heads]
        # v: [B, len_kv, num_heads, dim_v//num_heads]
        # In the following, 'dim_(qk)//num_heads' is written as d_(qk)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # q: [B, num_heads, len_q, d_qk]
        # k: [B, num_heads, len_kv, d_qk]
        # v: [B, num_heads, len_kv, d_v]

        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)
        # attn: [B, num_heads, len_q, len_kv]
        if mask is not None:
            attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        output = torch.matmul(attn, v)
        # output: [B, num_heads, len_q, d_v]
        output = output.transpose(1, 2)
        # output: [B, len_q, num_heads, d_v]
        output = output.contiguous().view(-1, len_q, self.dim_v)
        # output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]
        return output
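A quick shape check of the module above (a usage sketch assuming the class as defined; the causal mask here mirrors the attn_mask helper shown in the complete code below):

mha = MultiHeadAttention(dim=512, num_heads=8, dropout=0.1)
x = torch.randn(2, 10, 512)                                    # self-attention: q = k = v = x
causal = torch.triu(torch.ones(10, 10, dtype=torch.bool), 1)   # True above the diagonal → masked
out = mha(x, x, x, mask=causal)
print(out.shape)                                               # torch.Size([2, 10, 512])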
6. Complete Code Implementation
import torch
import torch.nn as nn


class LearnablePositionalEncoding(nn.Module):
    # Learnable positional encoding
    def __init__(self, emb_dim, len):
        super(LearnablePositionalEncoding, self).__init__()
        assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'
        self.emb_dim = emb_dim
        self.len = len
        self.pe = nn.Parameter(torch.zeros(len, emb_dim))

    def forward(self, x):
        return x + self.pe[:x.size(-2), :]


class PositionalEncoding(nn.Module):
    # Sine-cosine positional encoding
    def __init__(self, emb_dim, max_len, freq=10000.0):
        super(PositionalEncoding, self).__init__()
        assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'
        self.emb_dim = emb_dim
        self.max_len = max_len
        self.pe = torch.zeros(max_len, emb_dim)
        pos = torch.arange(0, max_len).unsqueeze(1)
        # pos: [max_len, 1]
        div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)
        # div: [ceil(emb_dim / 2)]
        self.pe[:, 0::2] = torch.sin(pos / div)
        # torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]
        self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))
        # torch.cos(pos / div): [max_len, floor(emb_dim / 2)]

    def forward(self, x, len=None):
        if len is None:
            len = x.size(-2)
        return x + self.pe[:len, :]


class MultiHeadAttention(nn.Module):
    def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):
        super(MultiHeadAttention, self).__init__()
        dim_qk = dim if dim_qk is None else dim_qk
        dim_v = dim if dim_v is None else dim_v
        assert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, 'dim must be divisible by num_heads'
        self.dim = dim
        self.dim_qk = dim_qk
        self.dim_v = dim_v
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.w_q = nn.Linear(dim, dim_qk)
        self.w_k = nn.Linear(dim, dim_qk)
        self.w_v = nn.Linear(dim, dim_v)

    def forward(self, q, k, v, mask=None):
        # q: [B, len_q, D]
        # k: [B, len_kv, D]
        # v: [B, len_kv, D]
        assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'
        len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)
        assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'
        assert len_k == len_v, 'len_k and len_v must be equal'
        len_kv = len_v

        q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
        k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)
        v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)
        # q: [B, len_q, num_heads, dim_qk//num_heads]
        # k: [B, len_kv, num_heads, dim_qk//num_heads]
        # v: [B, len_kv, num_heads, dim_v//num_heads]
        # In the following, 'dim_(qk)//num_heads' is written as d_(qk)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # q: [B, num_heads, len_q, d_qk]
        # k: [B, num_heads, len_kv, d_qk]
        # v: [B, num_heads, len_kv, d_v]

        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)
        # attn: [B, num_heads, len_q, len_kv]
        if mask is not None:
            attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        output = torch.matmul(attn, v)
        # output: [B, num_heads, len_q, d_v]
        output = output.transpose(1, 2)
        # output: [B, len_q, num_heads, d_v]
        output = output.contiguous().view(-1, len_q, self.dim_v)
        # output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]
        return output


class Feedforward(nn.Module):
    def __init__(self, dim, hidden_dim=2048, dropout=0., activate=nn.ReLU()):
        super(Feedforward, self).__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = activate

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


def attn_mask(len):
    """
    :param len: length of sequence
    :return: mask tensor, False for not replaced, True for replaced as -inf
    e.g. attn_mask(3) =
        tensor([[[False,  True,  True],
                 [False, False,  True],
                 [False, False, False]]])
    """
    mask = torch.triu(torch.ones(len, len, dtype=torch.bool), 1)
    return mask


def padding_mask(pad_q, pad_k):
    """
    :param pad_q: pad label of query (0 is padding, 1 is not padding), [B, len_q]
    :param pad_k: pad label of key (0 is padding, 1 is not padding), [B, len_k]
    :return: mask tensor, False for not replaced, True for replaced as -inf
    e.g. pad_q = tensor([[1, 1, 0], [1, 0, 1]])
         padding_mask(pad_q, pad_q) =
        tensor([[[False, False,  True],
                 [False, False,  True],
                 [ True,  True,  True]],
                [[False,  True, False],
                 [ True,  True,  True],
                 [False,  True, False]]])
    """
    assert pad_q.ndim == pad_k.ndim == 2, 'pad_q and pad_k must be 2-dimensional'
    assert pad_q.size(0) == pad_k.size(0), 'batch size mismatch'
    mask = pad_q.bool().unsqueeze(2) * pad_k.bool().unsqueeze(1)
    mask = ~mask
    # mask: [B, len_q, len_k]
    return mask


class EncoderLayer(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):
        super(EncoderLayer, self).__init__()
        self.attn = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)
        self.ffn = Feedforward(dim, dim * 4, dropout)
        self.pre_norm = pre_norm
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        if self.pre_norm:
            res1 = self.norm1(x)
            x = x + self.attn(res1, res1, res1, mask)
            res2 = self.norm2(x)
            x = x + self.ffn(res2)
        else:
            x = self.attn(x, x, x, mask) + x
            x = self.norm1(x)
            x = self.ffn(x) + x
            x = self.norm2(x)
        return x


class Encoder(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x


class DecoderLayer(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):
        super(DecoderLayer, self).__init__()
        self.attn1 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)
        self.attn2 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)
        self.ffn = Feedforward(dim, dim * 4, dropout)
        self.pre_norm = pre_norm
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, enc, self_mask=None, pad_mask=None):
        if self.pre_norm:
            res1 = self.norm1(x)
            x = x + self.attn1(res1, res1, res1, self_mask)
            res2 = self.norm2(x)
            x = x + self.attn2(res2, enc, enc, pad_mask)
            res3 = self.norm3(x)
            x = x + self.ffn(res3)
        else:
            x = self.attn1(x, x, x, self_mask) + x
            x = self.norm1(x)
            x = self.attn2(x, enc, enc, pad_mask) + x
            x = self.norm2(x)
            x = self.ffn(x) + x
            x = self.norm3(x)
        return x


class Decoder(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])

    def forward(self, x, enc, self_mask=None, pad_mask=None):
        for layer in self.layers:
            x = layer(x, enc, self_mask, pad_mask)
        return x


class Transformer(nn.Module):
    def __init__(self, dim, vocabulary, num_heads=1, num_layers=1, dropout=0., learnable_pos=False, pre_norm=False):
        super(Transformer, self).__init__()
        self.dim = dim
        self.vocabulary = vocabulary
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.learnable_pos = learnable_pos
        self.pre_norm = pre_norm
        self.embedding = nn.Embedding(vocabulary, dim)
        self.pos_enc = LearnablePositionalEncoding(dim, 100) if learnable_pos else PositionalEncoding(dim, 100)
        self.encoder = Encoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)
        self.decoder = Decoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)
        self.linear = nn.Linear(dim, vocabulary)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, pad_mask=None):
        # src.shape: torch.Size([2, 10])
        src = self.embedding(src)
        # src.shape: torch.Size([2, 10, 512])
        src = self.pos_enc(src)
        # src.shape: torch.Size([2, 10, 512])
        src = self.encoder(src, src_mask)
        # src.shape: torch.Size([2, 10, 512])

        # tgt.shape: torch.Size([2, 8])
        tgt = self.embedding(tgt)
        # tgt.shape: torch.Size([2, 8, 512])
        tgt = self.pos_enc(tgt)
        # tgt.shape: torch.Size([2, 8, 512])
        tgt = self.decoder(tgt, src, tgt_mask, pad_mask)
        # tgt.shape: torch.Size([2, 8, 512])

        output = self.linear(tgt)
        # output.shape: torch.Size([2, 8, 10000])
        return output

    def get_mask(self, tgt, src_pad=None):
        # Under normal circumstances, target padding is handled when computing the loss, so it isn't necessarily needed in the decoder
        if src_pad is not None:
            src_mask = padding_mask(src_pad, src_pad)
        else:
            src_mask = None
        tgt_mask = attn_mask(tgt.size(1))
        if src_pad is not None:
            # Cross-attention only hides padded source positions, so all target positions are marked as non-padding
            pad_mask = padding_mask(torch.ones_like(tgt), src_pad)
        else:
            pad_mask = None
        # src_mask: [B, len_src, len_src]
        # tgt_mask: [len_tgt, len_tgt]
        # pad_mask: [B, len_tgt, len_src]
        return src_mask, tgt_mask, pad_mask


if __name__ == '__main__':
    model = Transformer(dim=512, vocabulary=10000, num_heads=8, num_layers=6, dropout=0.1, learnable_pos=False, pre_norm=True)
    src = torch.randint(0, 10000, (2, 10))  # torch.Size([2, 10])
    tgt = torch.randint(0, 10000, (2, 8))  # torch.Size([2, 8])
    src_pad = torch.randint(0, 2, (2, 10))  # torch.Size([2, 10])
    src_mask, tgt_mask, pad_mask = model.get_mask(tgt, src_pad)
    model(src, tgt, src_mask, tgt_mask, pad_mask)
    # output.shape: torch.Size([2, 8, 10000])