引言
在自然語言處理和序列生成任務中,自注意力機制(Self-Attention)是提升模型性能的關鍵技術。本文將通過一個自定義的PyTorch模型實現,展示如何構建一個結合多頭注意力與前饋網絡的序列生成模型(如文本或字符生成)。該模型通過創新的 MaxStateSuper
模塊實現動態特征融合,適用于字體生成、文本預測等場景。
技術背景
1. 模型結構解析
核心組件
-
MaxStateSuper(自注意力模塊)
- 功能:通過多頭注意力機制提取序列中的關鍵特征,并結合累積最大值操作增強長期依賴建模。
- 實現亮點:
- 合并三個線性層為一個
combined
層,減少參數冗余。 - 使用
torch.cummax
實現動態狀態積累,提升序列記憶能力。
- 合并三個線性層為一個
-
FeedForward(前饋網絡)
- 結構:兩層全連接網絡,中間夾雜
ReLU
激活函數和門控機制(gate
)。 - 作用:非線性變換,增強模型表達能力。
- 結構:兩層全連接網絡,中間夾雜
-
DecoderLayer(解碼器層)
- 創新點:
- 引入
alpha
參數平衡前饋網絡輸出與原始輸入的權重,實現動態特征融合。 - 層歸一化(
LayerNorm
)確保梯度穩定性。
- 引入
- 創新點:
-
SamOut(整體模型)
- 輸入:字符或token的Embedding向量。
- 輸出:預測的下一時刻token概率分布。
2. 關鍵技術
- 多頭注意力機制:通過
heads
參數將特征空間劃分為多個子空間,提升模型對不同模式的捕捉能力。 - 累積最大值操作:
out2 = torch.cummax(out2, dim=2)[0]
保留序列中的關鍵特征軌跡。 - 動態參數平衡:
alpha
參數通過梯度下降自動學習前饋網絡與原始輸入的權重比例。
代碼實現
完整代碼
import torch
import torch.nn as nn
import torch.optim as optimclass MaxStateSuper(nn.Module):def __init__(self, dim_size, heads):super().__init__()self.heads = headsassert dim_size % heads == 0, "Dimension size must be divisible by head size."self.combined = nn.Linear(dim_size, 3 * dim_size, bias=False) # 合并QKV線性層def forward(self, x):b, s, d = x.shape# 合并后的線性變換并分割為QKVqkv = self.combined(x).chunk(3, dim=-1)q, k, v = qkv# 調整形狀并執行注意力計算# ...(此處省略具體注意力計算邏輯,參考標準多頭注意力實現)...return out, stateclass FeedForward(nn.Module):def __init__(self, hidden_size):super().__init__()self.ffn1 = nn.Linear(hidden_size, hidden_size)self.ffn2 = nn.Linear(hidden_size, hidden_size)self.gate = nn.Linear(hidden_size, hidden_size)self.relu = nn.ReLU()def forward(self, x):x1 = self.ffn1(x)x2 = self.relu(self.gate(x))xx = x1 * x2return self.ffn2(xx)class DecoderLayer(nn.Module):def __init__(self, hidden_size, num_heads):super().__init__()self.self_attn = MaxStateSuper(hidden_size, num_heads)self.ffn = FeedForward(hidden_size)self.norm = nn.LayerNorm(hidden_size)self.alpha = nn.Parameter(torch.tensor(0.5)) # 動態平衡參數def forward(self, x):attn_out, _ = self.self_attn(x)ffn_out = self.ffn(attn_out)x = self.norm(self.alpha * ffn_out + (1 - self.alpha) * x)return xclass SamOut(nn.Module):def __init__(self, voc_size, hidden_size, num_heads, num_layers):super().__init__()self.embedding = nn.Embedding(voc_size, hidden_size, padding_idx=3)self.layers = nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])self.head = nn.Linear(hidden_size, voc_size, bias=False)def forward(self, x):x = self.embedding(x)for layer in self.layers:x = layer(x)return self.head(x)# 訓練流程(簡化版)
if __name__ == '__main__':voc_size = 10000 # 假設詞匯表大小model = SamOut(voc_size, hidden_size=256, num_heads=8, num_layers=6)criterion = nn.CrossEntropyLoss(ignore_index=3)optimizer = optim.Adam(model.parameters(), lr=1e-3)for epoch in range(10):# 假設 input_tensor 和 target_tensor 已準備output = model(input_tensor)loss = criterion(output.view(-1, voc_size), target_tensor.view(-1))loss.backward()optimizer.step()
關鍵步驟解析
1. MaxStateSuper
模塊的創新點
# 合并QKV層
qkv = self.combined(x<