Transformer結構與代碼實現詳解

參考:
Transformer模型詳解(圖解最完整版) - 知乎https://zhuanlan.zhihu.com/p/338817680GitHub - liaoyanqing666/transformer_pytorch: 完整的原版transformer程序,complete origin transformer programhttps://github.com/liaoyanqing666/transformer_pytorcharxiv.org/pdf/1706.03762https://arxiv.org/pdf/1706.03762

一. Transformer的整體架構

Transformer 由 Encoder (編碼器)和 Decoder (解碼器)兩個部分組成,Encoder 和 Decoder 都包含 6 個 block(塊)。Transformer 的工作流程大體如下:

第一步:獲取輸入句子的每一個單詞的表示向量?X,X由單詞本身的 Embedding(Embedding就是從原始數據提取出來的特征(Feature)) 和單詞位置的 Embedding 相加得到。

第二步:將得到的單詞表示向量矩陣 (如上圖所示,每一行是一個單詞的表示?x)傳入 Encoder 中,經過 6 個 Encoder block (編碼器塊)后可以得到句子所有單詞的編碼信息矩陣?C。如下圖,單詞向量矩陣用 X_{n\times d}表示, n 是句子中單詞個數,d 是表示向量的維度(論文中 d=512)。每一個 Encoder block (編碼器塊)輸出的矩陣維度與輸入完全一致。

第三步:將 Encoder (編碼器)輸出的編碼信息矩陣?C傳遞到 Decoder(解碼器)中,Decoder(解碼器) 依次會根據當前翻譯過的單詞 1~ i 翻譯下一個單詞 i+1,如下圖所示。在使用的過程中,翻譯到單詞 i+1 的時候需要通過?Mask (掩蓋)?操作遮蓋住 i+1 之后的單詞。

上圖 Decoder 接收了 Encoder 的編碼矩陣?C,然后首先輸入一個翻譯開始符 "<Begin>",預測第一個單詞 "I";然后輸入翻譯開始符 "<Begin>" 和單詞 "I",預測單詞 "have",以此類推。

二. Transformer 的輸入

Transformer 中單詞的輸入表示?x?單詞本身的 Embedding?和單詞位置 Embedding?(Positional Encoding)相加得到。

2.1 單詞 Embedding(詞嵌入層)

單詞本身的 Embedding 有很多種方式可以獲取,例如可以采用 Word2Vec、Glove 等算法預訓練得到,也可以在 Transformer 中訓練得到。

self.embedding = nn.Embedding(vocabulary, dim)

功能解釋:

  1. 作用:將離散的整數索引(單詞ID)轉換為連續的向量表示

  2. 輸入:形狀為?[sequence_length]?的整數張量

  3. 輸出:形狀為?[sequence_length, dim]?的浮點數張量(X_{n\times d},n是序列長度,d是特征維度)

參數詳解:

參數含義示例值說明
vocabulary詞匯表大小10000表示模型能處理的不同單詞/符號總數
dim嵌入維度512每個單詞被表示成的向量長度

工作原理:

  1. 創建一個可學習的嵌入矩陣[vocabulary, dim],例如當?vocabulary=10000,?dim=512?時,是一個?10000×512?的矩陣;

  2. 每個整數索引對應矩陣中的一行:

# 假設單詞"apple"的ID=42
apple_vector = embedding_matrix[42]  # 形狀 [512]

在Transformer中的具體作用:

# 輸入:src = torch.randint(0, 10000, (2, 10))
# 形狀:[batch_size=2, seq_len=10]src_embedded = self.embedding(src)# 輸出形狀變為:[2, 10, 512]
# 每個整數單詞ID被替換為512維的向量

可視化表現:

原始輸入 (單詞ID):
[ [ 25,  198, 3000, ... ],   # 句子1[ 1,   42,  999,  ... ] ]  # 句子2經過嵌入層后 (向量表示):
[ [ [0.2, -0.5, ..., 1.3],   # ID=25的向量[0.8, 0.1, ..., -0.9],   # ID=198的向量... ],[ [0.9, -0.2, ..., 0.4],   # ID=1的向量[0.3, 0.7, ..., -1.2],   # ID=42的向量... ] ]

為什么需要詞嵌入:

  • 語義表示:相似的單詞會有相似的向量表示

  • 降維:將離散的ID映射到連續空間(one-hot編碼需要10000維 → 嵌入只需512維)

  • 可學習:在訓練過程中,這些向量會不斷調整以更好地表示語義關系

2.2??位置 Embedding(位置編碼)

Transformer 的位置編碼(Positional Encoding,PE)是模型的關鍵創新之一,它解決了傳統序列模型(如 RNN)固有的順序處理問題。Transformer 的自注意力機制本身不具備感知序列位置的能力,位置編碼通過向輸入嵌入添加位置信息,使模型能夠理解序列中元素的順序關系。位置編碼計算之后的輸出維度和詞嵌入層相同,均為(X_{n\times d})。

位置編碼的核心作用:

  1. 注入位置信息:讓模型區分不同位置的相同單詞(如 "bank" 在句首 vs 句尾)

  2. 保持距離關系:編碼相對位置和絕對位置信息

  3. 支持并行計算:避免像 RNN 那樣依賴順序處理

為什么需要位置編碼?

  1. 自注意力的位置不變性
    Attention(Q,K,V)=softmax\left ( \frac{QK^{T}}{\sqrt{d_k}} \right )V,計算過程不包含位置信息

  2. 序列順序的重要性

  • 自然語言:"貓追狗" ≠ "狗追貓"
  • 時序數據:股價序列的順序決定趨勢替代方案對比
方法優點缺點
正弦/余弦泛化性好,理論保證固定模式不靈活
可學習適應任務特定模式長度受限,需訓練
相對位置直接建模相對距離實現復雜

位置編碼的實際效果

  1. 早期層作用:幫助模型建立位置感知

  2. 后期層作用:位置信息被融合到語義表示中

  3. 可視化示例

Input:    [The,   cat,   sat,   on,   mat]
Embed:    [E_The, E_cat, E_sat, E_on, E_mat]
Position: [P0,    P1,    P2,    P3,   P4]Final: [E_The+P0, E_cat+P1, ... E_mat+P4]
(1)正余弦位置編碼(論文采用)

正余弦位置編碼的計算公式:

其中:

  • ?`pos` 是token在序列中的位置(從0開始)
  • ?`d_model` 是模型的嵌入維度(即每個token的向量維度)
  • ?`i` 是維度的索引(從0到d_model/2-1)

特點:

  • 波長幾何級數:覆蓋不同頻率
  • 相對位置可學習:位置偏移的線性變換 PE_{pos+k} 可表示為 PE_{pos} 的線性函數
  • 泛化性強:可處理比訓練時更長的序列
  • 對稱性:sin/cos 組合允許模型學習相對位置

代碼實現:

class PositionalEncoding(nn.Module):# Sine-cosine positional codingdef __init__(self, emb_dim, max_len, freq=10000.0):super(PositionalEncoding, self).__init__()assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'self.emb_dim = emb_dimself.max_len = max_lenself.pe = torch.zeros(max_len, emb_dim)pos = torch.arange(0, max_len).unsqueeze(1)# pos: [max_len, 1]div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)# div: [ceil(emb_dim / 2)]self.pe[:, 0::2] = torch.sin(pos / div)# torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))# torch.cos(pos / div): [max_len, floor(emb_dim / 2)]def forward(self, x, len=None):if len is None:len = x.size(-2)return x + self.pe[:len, :]

例如,指定emb_dim=512和max_len=100,句子長度為10,則位置embedding的數值計算如下(三角函數取弧度制):

\begin{bmatrix} sin\left ( \frac{0}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{0}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{0}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{0}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{0}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{0}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{1}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{1}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{1}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{1}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{1}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{1}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{2}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{2}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{2}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{2}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{2}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{2}{10000^{\frac{510}{512}}} \right )\\ ... & ... & ... & ... & ... & ... & ...\\ sin\left ( \frac{7}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{7}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{7}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{7}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{7}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{7}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{8}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{8}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{8}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{8}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{8}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{8}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{9}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{9}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{9}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{9}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{9}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{9}{10000^{\frac{510}{512}}} \right )\\ \end{bmatrix}_{10\times 512}=\begin{bmatrix} 0 & 1 & 0 & ... & 1 & 0 & 1\\ 0.8415 & 0.5403 & 0.8219 & ... & 1.0000 & 1.0366\times 10^{-4} & 1.0000\\ 0.9093 & -0.4161 & 0.9364 & ... & 1.0000 & 2.0733\times 10^{-4} & 1.0000\\ ... & ... & ... & ... & ... & ... & ...\\ 0.6570& 0.7539 & 0.4524 & ... & 1.0000 & 7.2564\times 10^{-4} & 1.0000\\ 0.9894 & -0.1455 & 0.9907 & ... & 1.0000 & 8.2931\times 10^{-4} & 1.0000\\ 0.4121 & -0.9111 & 0.6764 & ... & 1.0000 & 9.3297\times 10^{-4} & 1.0000 \end{bmatrix}_{10\times 512}

(2)可學習位置編碼
class LearnablePositionalEncoding(nn.Module):# Learnable positional encodingdef __init__(self, emb_dim, len):super(LearnablePositionalEncoding, self).__init__()assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'self.emb_dim = emb_dimself.len = lenself.pe = nn.Parameter(torch.zeros(len, emb_dim))def forward(self, x):return x + self.pe[:x.size(-2), :]

特性

  • 直接學習位置嵌入:作為模型參數訓練
  • 靈活性高:可適應特定任務的位置模式
  • 長度受限:只能處理預定義的最大長度
  • 計算效率高:直接查表無需計算

三. Self-Attention(自注意力機制)和Multi-Head Attention(多頭自注意力)

Transformer 的內部結構圖,左側為 Encoder block(編碼器),右側為 Decoder block(解碼器)。可以看到:

(1)Encoder block 包含一個 Multi-Head Attention;

(2)Decoder block 包含兩個 Multi-Head Attention (其中有一個用到 Masked)。Multi-Head Attention 上方還包括一個 Add & Norm 層,Add 表示殘差連接(Residual Connection),用于防止網絡退化,Norm 表示Layer Normalization,用于對每一層的激活值進行歸一化。

Multi-Head Attention 是 Transformer 的重點,它由?Self-Attention 演變而來,我們先從?Self-Attention 講起。

3.1? Self-Attention(自注意力機制)

Self-Attention(自注意力)是 Transformer 架構的核心創新,它徹底改變了序列建模的方式。與傳統的循環神經網絡(RNN)和卷積神經網絡(CNN)不同,self-attention 能夠直接捕捉序列中任意兩個元素之間的關系,無論它們之間的距離有多遠:

Self-Attention 的輸入用矩陣X_{n\times d}(n是序列長度,d是特征維度)進行表示,計算如下:

(1)通過可學習的權重矩陣生成Q(查詢),K(鍵值),V(值):

\left\{\begin{matrix} Q = XW^Q \\ K = XW^K \\ V = XW^V \end{matrix}\right.

其中W^Q,W^K,W^V是可學習參數。

(2)計算?Self-Attention 的輸出:Attention(Q,K,V)=softmax\left ( \frac{QK^{T}}{\sqrt{d_k}} \right )V

步驟分解:

  1. 相似度計算QK^T計算所有查詢-鍵對之間的點積相似度,QK^T得到的矩陣行列數都為 n,n為句子單詞數,這個矩陣可以表示單詞之間的 attention 強度。

  2. 縮放:除以\sqrt{d_k}防止點積過大導致梯度消失

  3. 歸一化:softmax 將相似度轉換為概率分布

  4. 加權求和:用注意力權重對值向量加權求和,得到最終的輸出

輸入序列: [x1, x2, x3, x4]步驟1: 為每個輸入生成Q,K,V向量
x1 → q1, k1, v1
x2 → q2, k2, v2
x3 → q3, k3, v3
x4 → q4, k4, v4步驟2: 計算注意力權重 (以x1為例)
權重1 = softmax(q1·k1 / √d_k)
權重2 = softmax(q1·k2 / √d_k)
權重3 = softmax(q1·k3 / √d_k)
權重4 = softmax(q1·k4 / √d_k)步驟3: 加權求和
輸出1 = 權重1*v1 + 權重2*v2 + 權重3*v3 + 權重4*v4

3.2? Multi-Head Attention(多頭注意力)

Transformer 使用多頭機制增強模型表達能力:

MultiHead(Q,K,V)=Concat(head_1,head_2...head_h)W^O

其中每個注意力頭:

head_i=Attention(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V})

  • h:注意力頭的數量

  • W_i^Q, W_i^K, W_i^V:每個頭的獨立參數

  • W^O:輸出投影矩陣

代碼實現:

(1)多頭分割處理:使用view將特征維度分割為多個頭,確保每個頭的維度:dim_head = dim_qk // num_heads

q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
k = ... # 類似處理
v = ... # 類似處理

(2)高效的矩陣運算:使用矩陣乘法并行計算所有位置的注意力分數

attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)

(3)多頭合并:使用view合并多頭:num_heads * d_v = dim_v

output = output.transpose(1, 2)
output = output.contiguous().view(-1, len_q, self.dim_v)

完整Multi-Head Attention(多頭注意力)的代碼實現,這里已經考慮了掩碼處理的實現,關于掩碼將在后面介紹。

class MultiHeadAttention(nn.Module):def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):super(MultiHeadAttention, self).__init__()dim_qk = dim if dim_qk is None else dim_qkdim_v = dim if dim_v is None else dim_vassert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, 'dim must be divisible by num_heads'self.dim = dimself.dim_qk = dim_qkself.dim_v = dim_vself.num_heads = num_headsself.dropout = nn.Dropout(dropout)self.w_q = nn.Linear(dim, dim_qk)self.w_k = nn.Linear(dim, dim_qk)self.w_v = nn.Linear(dim, dim_v)def forward(self, q, k, v, mask=None):# q: [B, len_q, D]# k: [B, len_kv, D]# v: [B, len_kv, D]assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'assert len_k == len_v, 'len_k and len_v must be equal'len_kv = len_vq = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)# q: [B, len_q, num_heads, dim_qk//num_heads]# k: [B, len_kv, num_heads, dim_qk//num_heads]# v: [B, len_kv, num_heads, dim_v//num_heads]# The following 'dim_(qk)//num_heads' is writen as d_(qk)q = q.transpose(1, 2)k = k.transpose(1, 2)v = v.transpose(1, 2)# q: [B, num_heads, len_q, d_qk]# k: [B, num_heads, len_kv, d_qk]# v: [B, num_heads, len_kv, d_v]attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)# attn: [B, num_heads, len_q, len_kv]if mask is not None:attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)attn = torch.softmax(attn, dim=-1)attn = self.dropout(attn)output = torch.matmul(attn, v)# output: [B, num_heads, len_q, d_v]output = output.transpose(1, 2)# output: [B, len_q, num_heads, d_v]output = output.contiguous().view(-1, len_q, self.dim_v)# output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]return output

六.完整代碼實現

import torch
import torch.nn as nnclass LearnablePositionalEncoding(nn.Module):# Learnable positional encodingdef __init__(self, emb_dim, len):super(LearnablePositionalEncoding, self).__init__()assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'self.emb_dim = emb_dimself.len = lenself.pe = nn.Parameter(torch.zeros(len, emb_dim))def forward(self, x):return x + self.pe[:x.size(-2), :]class PositionalEncoding(nn.Module):# Sine-cosine positional codingdef __init__(self, emb_dim, max_len, freq=10000.0):super(PositionalEncoding, self).__init__()assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'self.emb_dim = emb_dimself.max_len = max_lenself.pe = torch.zeros(max_len, emb_dim)pos = torch.arange(0, max_len).unsqueeze(1)# pos: [max_len, 1]div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)# div: [ceil(emb_dim / 2)]self.pe[:, 0::2] = torch.sin(pos / div)# torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))# torch.cos(pos / div): [max_len, floor(emb_dim / 2)]def forward(self, x, len=None):if len is None:len = x.size(-2)print(self.pe[:len, :])return x + self.pe[:len, :]class MultiHeadAttention(nn.Module):def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):super(MultiHeadAttention, self).__init__()dim_qk = dim if dim_qk is None else dim_qkdim_v = dim if dim_v is None else dim_vassert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, 'dim must be divisible by num_heads'self.dim = dimself.dim_qk = dim_qkself.dim_v = dim_vself.num_heads = num_headsself.dropout = nn.Dropout(dropout)self.w_q = nn.Linear(dim, dim_qk)self.w_k = nn.Linear(dim, dim_qk)self.w_v = nn.Linear(dim, dim_v)def forward(self, q, k, v, mask=None):# q: [B, len_q, D]# k: [B, len_kv, D]# v: [B, len_kv, D]assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'assert len_k == len_v, 'len_k and len_v must be equal'len_kv = len_vq = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)# q: [B, len_q, num_heads, dim_qk//num_heads]# k: [B, len_kv, num_heads, dim_qk//num_heads]# v: [B, len_kv, num_heads, dim_v//num_heads]# The following 'dim_(qk)//num_heads' is writen as d_(qk)q = q.transpose(1, 2)k = k.transpose(1, 2)v = v.transpose(1, 2)# q: [B, num_heads, len_q, d_qk]# k: [B, num_heads, len_kv, d_qk]# v: [B, num_heads, len_kv, d_v]attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)# attn: [B, num_heads, len_q, len_kv]if mask is not None:attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)attn = torch.softmax(attn, dim=-1)attn = self.dropout(attn)output = torch.matmul(attn, v)# output: [B, num_heads, len_q, d_v]output = output.transpose(1, 2)# output: [B, len_q, num_heads, d_v]output = output.contiguous().view(-1, len_q, self.dim_v)# output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]return outputclass Feedforward(nn.Module):def __init__(self, dim, hidden_dim=2048, dropout=0., activate=nn.ReLU()):super(Feedforward, self).__init__()self.dim = dimself.hidden_dim = hidden_dimself.dropout = nn.Dropout(dropout)self.fc1 = nn.Linear(dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, dim)self.act = activatedef forward(self, x):x = self.act(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return xdef attn_mask(len):""":param len: length of sequence:return: mask tensor, False for not replaced, True for replaced as -infe.g. attn_mask(3) =tensor([[[False,  True,  True],[False, False,  True],[False, False, False]]])"""mask = torch.triu(torch.ones(len, len, dtype=torch.bool), 1)return maskdef padding_mask(pad_q, pad_k):""":param pad_q: pad label of query (0 is padding, 1 is not padding), [B, len_q]:param pad_k: pad label of key (0 is padding, 1 is not padding), [B, len_k]:return: mask tensor, False for not replaced, True for replaced as -infe.g. pad_q = tensor([[1, 1, 0]], [1, 0, 1])padding_mask(pad_q, pad_q) =tensor([[[False, False,  True],[False, False,  True],[ True,  True,  True]],[[False,  True, False],[ True,  True,  True],[False,  True, False]]])"""assert pad_q.ndim == pad_k.ndim == 2, 'pad_q and pad_k must be 2-dimensional'assert pad_q.size(0) == pad_k.size(0), 'batch size mismatch'mask = pad_q.bool().unsqueeze(2) * pad_k.bool().unsqueeze(1)mask = ~mask# mask: [B, len_q, len_k]return maskclass EncoderLayer(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):super(EncoderLayer, self).__init__()self.attn = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.ffn = Feedforward(dim, dim * 4, dropout)self.pre_norm = pre_normself.norm1 = nn.LayerNorm(dim)self.norm2 = nn.LayerNorm(dim)def forward(self, x, mask=None):if self.pre_norm:res1 = self.norm1(x)x = x + self.attn(res1, res1, res1, mask)res2 = self.norm2(x)x = x + self.ffn(res2)else:x = self.attn(x, x, x, mask) + xx = self.norm1(x)x = self.ffn(x) + xx = self.norm2(x)return xclass Encoder(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):super(Encoder, self).__init__()self.layers = nn.ModuleList([EncoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])def forward(self, x, mask=None):for layer in self.layers:x = layer(x, mask)return xclass DecoderLayer(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):super(DecoderLayer, self).__init__()self.attn1 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.attn2 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.ffn = Feedforward(dim, dim * 4, dropout)self.pre_norm = pre_normself.norm1 = nn.LayerNorm(dim)self.norm2 = nn.LayerNorm(dim)self.norm3 = nn.LayerNorm(dim)def forward(self, x, enc, self_mask=None, pad_mask=None):if self.pre_norm:res1 = self.norm1(x)x = x + self.attn1(res1, res1, res1, self_mask)res2 = self.norm2(x)x = x + self.attn2(res2, enc, enc, pad_mask)res3 = self.norm3(x)x = x + self.ffn(res3)else:x = self.attn1(x, x, x, self_mask) + xx = self.norm1(x)x = self.attn2(x, enc, enc, pad_mask) + xx = self.norm2(x)x = self.ffn(x) + xx = self.norm3(x)return xclass Decoder(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):super(Decoder, self).__init__()self.layers = nn.ModuleList([DecoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])def forward(self, x, enc, self_mask=None, pad_mask=None):for layer in self.layers:x = layer(x, enc, self_mask, pad_mask)return xclass Transformer(nn.Module):def __init__(self, dim, vocabulary, num_heads=1, num_layers=1, dropout=0., learnable_pos=False, pre_norm=False):super(Transformer, self).__init__()self.dim = dimself.vocabulary = vocabularyself.num_heads = num_headsself.num_layers = num_layersself.dropout = dropoutself.learnable_pos = learnable_posself.pre_norm = pre_normself.embedding = nn.Embedding(vocabulary, dim)self.pos_enc = LearnablePositionalEncoding(dim, 100) if learnable_pos else PositionalEncoding(dim, 100)self.encoder = Encoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)self.decoder = Decoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)self.linear = nn.Linear(dim, vocabulary)def forward(self, src, tgt, src_mask=None, tgt_mask=None, pad_mask=None):# src.shape: torch.Size([2, 10])src = self.embedding(src)# src.shape: torch.Size([2, 10, 512])src = self.pos_enc(src)# src.shape: torch.Size([2, 10, 512])src = self.encoder(src, src_mask)# src.shape: torch.Size([2, 10, 512])# tgt.shape: torch.Size([2, 8])tgt = self.embedding(tgt)# tgt.shape: torch.Size([2, 8, 512])tgt = self.pos_enc(tgt)# tgt.shape: torch.Size([2, 8, 512])tgt = self.decoder(tgt, src, tgt_mask, pad_mask)# tgt.shape: torch.Size([2, 8, 512])output = self.linear(tgt)# output.shape: torch.Size([2, 8, 10000])return outputdef get_mask(self, tgt, src_pad=None):# Under normal circumstances, tgt_pad will perform mask processing when calculating loss, and it isn't necessarily in decoderif src_pad is not None:src_mask = padding_mask(src_pad, src_pad)else:src_mask = Nonetgt_mask = attn_mask(tgt.size(1))if src_pad is not None:pad_mask = padding_mask(torch.zeros_like(tgt), src_pad)else:pad_mask = None# src_mask: [B, len_src, len_src]# tgt_mask: [len_tgt, len_tgt]# pad_mask: [B, len_tgt, len_src]return src_mask, tgt_mask, pad_maskif __name__ == '__main__':model = Transformer(dim=512, vocabulary=10000, num_heads=8, num_layers=6, dropout=0.1, learnable_pos=False, pre_norm=True)src = torch.randint(0, 10000, (2, 10))  # torch.Size([2, 10])tgt = torch.randint(0, 10000, (2, 8))   # torch.Size([2, 8])src_pad = torch.randint(0, 2, (2, 10))  # torch.Size([2, 10])src_mask, tgt_mask, pad_mask = model.get_mask(tgt, src_pad)model(src, tgt, src_mask, tgt_mask, pad_mask)# output.shape: torch.Size([2, 8, 10000])

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/912296.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/912296.shtml
英文地址,請注明出處:http://en.pswp.cn/news/912296.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Adobe InDesign 2025

Adobe InDesign 2025(ID2025)桌面出版軟件和在線發布工具,報刊雜志印刷排版設計軟件。Adobe InDesign中文版主要用于傳單設計,海報設計,明信片設計,電子書設計,排版,手冊設計,數字雜志,iPad應用程序和在線交互文檔。它是首款支持Unicode文本處理的主流DTP應用程序,率先使用新型…

Linux下獲取指定時間內某個進程的平均CPU使用率

一、引言 通過pidstat工具可以測量某個進程在兩個時間點之間的平均CPU利用率。 二、pidstat工具的安裝 pidstat屬于sysstat套件的一部分。以Ubuntu系統為例&#xff0c;執行下面命令下載安裝sysstat套件&#xff1a; apt-get install sysstat 執行完后&#xff0c;終端執行p…

1.4 蜂鳥E203處理器NICE接口詳解

一、NICE接口的概念 NICE&#xff08;Nuclei Instruction Co-unit Extension&#xff09;接口是蜂鳥E203處理器中用于擴展自定義指令的協處理器接口&#xff0c;基于RISC-V標準協處理器擴展機制設計。它允許用戶在不修改處理器核流水線的情況下&#xff0c;通過外部硬件加速特…

Oracle 遞歸 + Decode + 分組函數實現復雜樹形統計進階(第二課)

在上篇文章基礎上&#xff0c;我們進一步解決層級數據遞歸匯總問題 —— 讓上級部門的統計結果自動包含所有下級部門數據&#xff08;含多級子部門&#xff09;&#xff0c;并新增請假天數大于 3 天的統計維度。通過遞歸 CTE、DECODE函數與分組函數的深度結合&#xff0c;實現真…

MySQL 數據類型全面指南:詳細說明與關鍵注意事項

MySQL 數據類型全面指南&#xff1a;詳細說明與關鍵注意事項 MySQL 提供了豐富的數據類型&#xff0c;合理選擇對數據庫性能、存儲效率和數據準確性至關重要。以下是所有數據類型的詳細說明及使用注意事項&#xff1a; 一、數值類型 整數類型 類型字節有符號范圍無符號范圍說…

leetcode437-路徑總和III

leetcode 437 思路 利用前綴和hash map解答 前綴和在這里的含義是&#xff1a;從根節點到當前節點的路徑上所有節點值的總和 我們使用一個 Map 數據結構來記錄這些前綴和及其出現的次數 具體思路如下&#xff1a; 初始化&#xff1a;創建一個 Map &#xff0c;并將前綴和 …

UI前端與數字孿生融合探索新領域:智慧家居的可視化設計與實現

hello寶子們...我們是艾斯視覺擅長ui設計、前端開發、數字孿生、大數據、三維建模、三維動畫10年經驗!希望我的分享能幫助到您!如需幫助可以評論關注私信我們一起探討!致敬感謝感恩! 一、引言&#xff1a;智慧家居的數字化轉型浪潮 在物聯網與人工智能技術的推動下&#xff0c…

數據結構知識點總結--緒論

1.1 數據結構的基本概念 1.1.1 基本概念和術語 主要涉及概念有&#xff1a; 數據、數據元素、數據對象、數據類型、數據結構 #mermaid-svg-uyyvX6J6ofC9rFSB {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-uyyvX6…

pip install mathutils 安裝 Blender 的 mathutils 模塊時,編譯失敗了

你遇到的問題是因為你試圖通過 pip install mathutils 安裝 Blender 的 mathutils 模塊時&#xff0c;編譯失敗了&#xff0c;主要原因是&#xff1a; 2018年 的老版本也不行 pip install mathutils2.79 ? 報錯核心總結&#xff1a; 缺失頭文件 BLI_path_util.h&#xff1a;…

編譯安裝交叉工具鏈 riscv-gnu-toolchain

參考鏈接&#xff1a; https://zhuanlan.zhihu.com/p/258394849 1&#xff0c;下載源碼 git clone https://gitee.com/mirrors/riscv-gnu-toolchain 2&#xff0c;進入目錄 cd riscv-gnu-toolchain 3&#xff0c;去掉qemu git rm qemu 4&#xff0c;初始化 git submodule…

復制 生成二維碼

一、安裝插件 1、復制 npm install -g copy-to-clipboard import copy from copy-to-clipboard; 2、生成二維碼 & 下載 npm install -g qrcode import QRCode from qrcode.react; 二、功能&#xff1a;生成二維碼 & 下載 效果圖 1、常規使用&#xff08;下載圖片模糊…

自由職業的經營視角

“領導力的核心是幫助他人看到自己看不到的東西。” — 彼得圣吉 最近與一些自由職業者的交流中&#xff0c;發現很多專業人士都會從專業視角來做交流&#xff0c;這也讓我更加理解我們海外戰略顧問莊老師在每月輔導時的提醒——經營者視角和專業人士視角的不同。這不僅讓大家獲…

MR30分布式 IO在物流堆垛機的應用

在現代物流行業蓬勃發展的浪潮中&#xff0c;物流堆垛機作為自動化倉儲系統的核心設備&#xff0c;承擔著貨物的高效存取與搬運任務。它憑借自動化操作、高精度定位等優勢&#xff0c;極大地提升了倉儲空間利用率和貨物周轉效率。然而&#xff0c;隨著物流行業的高速發展&#…

告別固定密鑰!在單一賬戶下用 Cognito 實現 AWS CLI 的 MFA 單點登錄

大家好&#xff0c;很多朋友&#xff0c;特別是通過合作伙伴或服務商使用 AWS 的同學&#xff0c;可能會發現自己的 IAM Identity Center 功能受限&#xff0c;無法像在組織管理賬戶里那樣輕松配置 CLI 的 SSO (aws configure sso)。那么&#xff0c;我們就要放棄治療&#xff…

未來機器視覺軟件將更注重成本控制,邊緣性能,魯棒性、多平臺支持、模塊優化與性能提升,最新版本opencv-4.11.0更新了什么

OpenCV 4.11.0 作為 4.10.0 的后續版本,雖然沒有在提供的搜索結果中直接列出詳細更新內容,但結合 OpenCV 4.10.0 的重大改進方向(發布于 2024 年 6 月),可以合理推斷 4.11.0 版本可能延續了對多平臺支持、模塊優化和性能提升的強化。以下是基于 OpenCV 近期更新模式的推測…

小程序入門:數據請求全解析

在微信小程序開發中&#xff0c;數據請求是實現豐富功能的關鍵環節。本文將帶你深入了解小程序數據請求的相關知識&#xff0c;包括請求限制、配置方法以及不同請求方式的實現&#xff0c;還會介紹如何在頁面加載時自動請求數據&#xff0c;同時附上詳細代碼示例&#xff0c;讓…

開源版gpt4o 多模態MiniGPT-4 實現原理詳解

MiniGPT-4是開源的GPT-4的平民版。本文用帶你快速掌握多模態大模型MiniGPT-4的模型架構、訓練秘訣、實戰亮點與改進方向。 1 模型架構全景&#xff1a;三層協同 &#x1f4ca; 模型底部實際輸入圖像&#xff0c;經 ViT Q-Former 編碼。藍色方塊 (視覺編碼器)&#xff1a;左側…

Flutter基礎(控制器)

第1步&#xff1a;找個遙控器&#xff08;創建控制器&#xff09;? // 就像買新遙控器要裝電池 TextEditingController myController TextEditingController(); ??第2步&#xff1a;連上你的玩具&#xff08;綁定到組件&#xff09;?? TextField(controller: myContro…

Spring Boot使用Redis常用場景

Spring Boot使用Redis常用場景 一、概述&#xff1a;Redis 是什么&#xff1f;為什么要用它&#xff1f; Redis&#xff08;Remote Dictionary Server&#xff09;是一個內存中的數據存儲系統&#xff08;類似一個“超級大字典”&#xff09;&#xff0c;它能存各種類型的數據…

CAD文件處理控件Aspose.CAD教程:在 C# 中將 DXF 文件轉換為 SVG - AutoCAD C# 示例

概述 使用 C# 輕松將DXF文件轉換為SVG。此轉換可更好地兼容 Web 應用程序&#xff0c;并增強 CAD 圖紙的視覺呈現效果。使用Aspose.CAD for .NET &#xff0c;開發人員可以輕松實現此轉換過程。該 SDK 提供強大的功能&#xff0c;使其成為 C# 開發人員的可靠選擇。Aspose.CAD …