歡迎來到啾啾的博客🐱。
記錄學習點滴。分享工作思考和實用技巧,偶爾也分享一些雜談💬。
有很多很多不足的地方,歡迎評論交流,感謝您的閱讀和評論😄。
目錄
- 引言
- 1 一個完整的Transformer模型
- 2 需要準備的“工具包”
- 3 Demo
引言
AI使用聲明:在內容整理、結構優化和語言表達的過程中,我使用了人工智能(AI)工具作為輔助。
如果以LLM應用工程師為目標,其實我們并不需要熟練掌握PyTorch,熟練掌握Transformer,但是我們必須對這兩者與其背后的信息有基本的了解誒,進而更好的團隊協作,以及微調模型。
本篇是一個完整的從0開始構建Transformer的Demo。
代碼由QWen3-Coder生成,可以運行調試。
1 一個完整的Transformer模型
![[從零構建TransformerP2-新聞分類Demo.png]]
2 需要準備的“工具包”
工具 | 作用 |
---|---|
nn.Embedding | 詞嵌入 |
nn.Linear | 投影層 |
F.softmax , F.relu | 激活函數 |
torch.matmul | 矩陣乘法(注意力核心) |
mask (triu, masked_fill) | 實現因果注意力 |
LayerNorm , Dropout | 穩定訓練 |
nn.ModuleList | 堆疊多層 |
DataLoader | 批量加載數據 |
3 Demo
"""
基于Transformer的新聞分類模型
嚴格按照設計流程實現,每個組件都有明確設計依據
""" import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Optional, Tuple # ==============================================
# 第一部分:基礎組件設計(根據設計決策選擇)
# ============================================== class TokenEmbedding(nn.Module): """ 詞嵌入層:將輸入的詞ID映射為密集向量表示 設計依據: - 文本任務需要詞嵌入表示語義 - 乘以sqrt(d_model)穩定初始化方差(原論文做法) """ def __init__(self, vocab_size: int, d_model: int): super().__init__() self.embedding = nn.Embedding(vocab_size, d_model) self.d_model = d_model def forward(self, x: torch.Tensor) -> torch.Tensor: """ 前向傳播 參數: x: 輸入詞ID張量,形狀為(batch_size, seq_len) 返回: 嵌入后的張量,形狀為(batch_size, seq_len, d_model) """ # 原論文建議乘以sqrt(d_model)來穩定方差 return self.embedding(x) * math.sqrt(self.d_model) class PositionalEncoding(nn.Module): """ 位置編碼:為輸入序列添加位置信息 設計依據: - Transformer沒有順序感知能力,必須添加位置信息 - 選擇可學習位置編碼(更靈活,適合變長序列) """ def __init__(self, d_model: int, max_len: int = 512): super().__init__() self.pos_embedding = nn.Embedding(max_len, d_model) def forward(self, x: torch.Tensor) -> torch.Tensor: """ 前向傳播 參數: x: 輸入張量,形狀為(batch_size, seq_len, d_model) 返回: 添加位置編碼后的張量 """ batch_size, seq_len = x.size(0), x.size(1) # 生成位置ID: [0, 1, 2, ..., seq_len-1] positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1) return x + self.pos_embedding(positions) class MultiHeadAttention(nn.Module): """ 多頭注意力機制 設計依據: - 需要建模詞與詞之間的關系(自注意力) - 多頭機制允許模型在不同子空間關注不同關系 """ def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1): super().__init__() assert d_model % num_heads == 0, "d_model必須能被num_heads整除" self.d_model = d_model self.num_heads = num_heads self.d_k = d_model // num_heads # 線性變換層 self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) self.dropout = nn.Dropout(dropout) def scaled_dot_product_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ 縮放點積注意力 參數: q: 查詢張量,形狀為(batch_size, num_heads, seq_len, d_k) k: 鍵張量,形狀為(batch_size, num_heads, seq_len, d_k) v: 值張量,形狀為(batch_size, num_heads, seq_len, d_k) mask: 注意力掩碼,用于屏蔽padding或未來位置 返回: attention_output: 注意力輸出 attention_weights: 注意力權重(可用于可視化) """ attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) if mask is not None: # 將mask為0的位置設為極小值,使softmax后為0 attn_scores = attn_scores.masked_fill(mask == 0, -1e9) attn_probs = F.softmax(attn_scores, dim=-1) attn_probs = self.dropout(attn_probs) output = torch.matmul(attn_probs, v) return output, attn_probs def split_heads(self, x: torch.Tensor) -> torch.Tensor: """將輸入拆分為多個頭""" batch_size = x.size(0) x = x.view(batch_size, -1, self.num_heads, self.d_k) return x.transpose(1, 2) # (batch_size, num_heads, seq_len, d_k) def combine_heads(self, x: torch.Tensor) -> torch.Tensor: """將多個頭合并回原始形狀""" batch_size = x.size(0) x = x.transpose(1, 2).contiguous() return x.view(batch_size, -1, self.d_model) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: """ 前向傳播 參數: q, k, v: 查詢、鍵、值張量,形狀為(batch_size, seq_len, d_model) mask: 注意力掩碼 返回: 多頭注意力輸出,形狀為(batch_size, seq_len, d_model) """ q = self.split_heads(self.W_q(q)) k = self.split_heads(self.W_k(k)) v = self.split_heads(self.W_v(v)) attn_output, _ = self.scaled_dot_product_attention(q, k, v, mask) output = self.W_o(self.combine_heads(attn_output)) return output class FeedForward(nn.Module): """ 前饋神經網絡 設計依據: - 每個位置獨立處理,增強模型表示能力 - 通常d_ff = 4 * d_model(原論文比例) """ def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1): super().__init__() self.fc1 = nn.Linear(d_model, d_ff) self.fc2 = nn.Linear(d_ff, d_model) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.gelu(self.fc1(x)) x = self.dropout(x) x = self.fc2(x) return x class EncoderLayer(nn.Module): """ 編碼器層 設計依據: - 新聞分類需要雙向上下文理解 - 殘差連接和層歸一化提升訓練穩定性 """ def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1): super().__init__() self.self_attn = MultiHeadAttention(d_model, num_heads, dropout) self.ffn = FeedForward(d_model, d_ff, dropout) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: # 自注意力 + 殘差連接 + 層歸一化 attn_output = self.self_attn(x, x, x, mask) x = self.norm1(x + self.dropout(attn_output)) # 前饋網絡 + 殘差連接 + 層歸一化 ffn_output = self.ffn(x) x = self.norm2(x + self.dropout(ffn_output)) return x # ==============================================
# 第二部分:完整模型組裝(根據設計決策)
# ============================================== class NewsClassifier(nn.Module): """ 新聞分類Transformer模型 設計決策回顧: - 任務類型:文本分類(Encoder-only) - 輸入:新聞文本序列 - 輸出:新聞類別(體育、科技、娛樂等) - 架構選擇:Encoder-only(無需生成能力) - 輸入表示:Token Embedding + 可學習位置編碼 - 輸出頭:[CLS] token + 分類層 """ def __init__( self, vocab_size: int, d_model: int = 768, num_heads: int = 12, num_layers: int = 6, d_ff: int = 3072, num_classes: int = 10, max_len: int = 512, dropout: float = 0.1 ): """ 參數: vocab_size: 詞匯表大小 d_model: 模型維度(默認768,與BERT-base一致) num_heads: 注意力頭數(默認12,與BERT-base一致) num_layers: 編碼器層數(默認6,平衡性能與計算成本) d_ff: FFN隱藏層維度(默認3072=4*d_model) num_classes: 分類類別數 max_len: 最大序列長度 dropout: dropout概率 """ super().__init__() self.d_model = d_model # 1. 特殊token(設計依據:BERT-style分類需要[CLS]) self.cls_token = nn.Parameter(torch.randn(1, 1, d_model)) # 2. 詞嵌入層 self.token_embedding = TokenEmbedding(vocab_size, d_model) # 3. 位置編碼(設計依據:選擇可學習位置編碼) self.pos_encoding = PositionalEncoding(d_model, max_len) # 4. 編碼器層堆疊 self.encoder_layers = nn.ModuleList([ EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers) ]) # 5. 分類頭(設計依據:使用[CLS] token進行分類) self.classifier = nn.Sequential( nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, num_classes) ) self.dropout = nn.Dropout(dropout) # 權重初始化(設計依據:穩定訓練) self._init_weights() def _init_weights(self): """初始化模型權重""" for module in self.modules(): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) elif isinstance(module, nn.LayerNorm): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) def add_cls_token(self, x: torch.Tensor) -> torch.Tensor: """ 在序列開頭添加[CLS] token 設計依據:BERT-style分類使用[CLS]聚合全局信息 參數: x: 輸入張量,形狀為(batch_size, seq_len, d_model) 返回: 添加[CLS]后的張量,形狀為(batch_size, seq_len+1, d_model) """ batch_size = x.size(0) cls_tokens = self.cls_token.expand(batch_size, -1, -1) return torch.cat((cls_tokens, x), dim=1) def create_padding_mask(self, input_ids: torch.Tensor, pad_idx: int = 0) -> torch.Tensor: """ 創建padding掩碼 設計依據:處理變長序列,忽略padding位置 參數: input_ids: 輸入ID張量,形狀為(batch_size, seq_len) pad_idx: padding token的ID 返回: 掩碼張量,形狀為(batch_size, 1, 1, seq_len) True表示有效位置,False表示padding位置 (BoolTensor) """ # 創建布爾掩碼,非pad為True mask = (input_ids != pad_idx).unsqueeze(1).unsqueeze(2) # (batch_size, 1, 1, seq_len) return mask.bool() # 確保返回的是布爾類型 def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ 前向傳播 參數: input_ids: 輸入詞ID,形狀為(batch_size, original_seq_len) attention_mask: 可選的注意力掩碼,形狀為(batch_size, original_seq_len)。 1.0 表示有效位置,0.0 表示padding位置。 如果提供,應為浮點類型 (如 torch.float) 或布爾類型。 如果為 None,則根據 input_ids 自動創建。 返回: 分類logits,形狀為(batch_size, num_classes) """ batch_size, original_seq_len = input_ids.size() # 1. 詞嵌入 x = self.token_embedding(input_ids) # (batch_size, original_seq_len, d_model) # 2. 添加[CLS] token x = self.add_cls_token(x) # (batch_size, original_seq_len + 1, d_model) new_seq_len = x.size(1) # 獲取添加[CLS]后的序列長度 # 3. 位置編碼 x = self.pos_encoding(x) x = self.dropout(x) # 4. 準備注意力掩碼 (用于屏蔽padding) if attention_mask is not None: # 如果提供了 attention_mask,確保其為四維且為布爾類型 # 預期輸入形狀: (batch_size, original_seq_len) # 目標形狀: (batch_size, 1, 1, original_seq_len) if attention_mask.dim() == 2: # 假設非零值為有效位置 attention_mask_for_padding = (attention_mask != 0).unsqueeze(1).unsqueeze(2) elif attention_mask.dim() == 4: attention_mask_for_padding = (attention_mask.squeeze(1).squeeze(1) != 0).unsqueeze(1).unsqueeze(2) else: raise ValueError(f"attention_mask must be 2D or 4D, but got {attention_mask.dim()}D") else: # 如果沒有提供,根據 input_ids 自動創建 # 形狀: (batch_size, 1, 1, original_seq_len) attention_mask_for_padding = self.create_padding_mask(input_ids) # --- 關鍵修復:正確擴展 mask 以適應添加了 [CLS] token 后的新序列長度 --- # 創建一個針對新序列長度 (new_seq_len = original_seq_len + 1) 的掩碼 # [CLS] token (索引 0) 應該總是被 attend 到,所以我們需要擴展 mask # 1. 初始化一個全為 True 的新掩碼,形狀 (batch_size, 1, 1, new_seq_len) expanded_mask = torch.ones((batch_size, 1, 1, new_seq_len), dtype=torch.bool, device=x.device) # 2. 將原始 padding mask 復制到新 mask 的 [1:] 位置 (跳過 [CLS]) # 原始 mask 形狀: (batch_size, 1, 1, original_seq_len) # 新 mask 的 [1:] 部分形狀: (batch_size, 1, 1, original_seq_len) expanded_mask[:, :, :, 1:] = attention_mask_for_padding # 最終用于注意力的掩碼,形狀 (batch_size, 1, 1, new_seq_len) # 在 MultiHeadAttention 中,這個掩碼會被廣播用于屏蔽 key (src_seq) 的 padding 位置 final_attention_mask = expanded_mask # 5. 通過編碼器層 # 將擴展后的 mask 傳遞給每一層,以屏蔽 padding for layer in self.encoder_layers: x = layer(x, final_attention_mask) # 傳遞匹配新序列長度的 mask # 6. 取[CLS] token作為句子表示 cls_output = x[:, 0, :] # (batch_size, d_model) # 7. 分類 logits = self.classifier(cls_output) return logits # ==============================================
# 第三部分:訓練流程(根據設計決策)
# ============================================== def train_news_classifier(): """新聞分類模型訓練流程""" # 1. 超參數設置(根據設計決策) config = { "vocab_size": 30000, # 詞匯表大小(設計依據:新聞領域常用詞) "d_model": 768, # 模型維度(設計依據:平衡性能與計算成本) "num_heads": 12, # 注意力頭數(設計依據:與d_model匹配) "num_layers": 6, # 編碼器層數(設計依據:足夠捕捉復雜關系) "d_ff": 3072, # FFN維度(設計依據:4*d_model) "num_classes": 10, # 分類類別數(設計依據:新聞類別數量) "max_len": 512, # 最大序列長度(設計依據:覆蓋大多數新聞) "dropout": 0.1, # dropout概率(設計依據:防止過擬合) "batch_size": 32, # 批量大小(設計依據:GPU內存限制) "learning_rate": 2e-5, # 學習率(設計依據:微調預訓練模型常用值) "epochs": 3, # 訓練輪數(設計依據:避免過擬合) "warmup_steps": 500, # warmup步數(設計依據:穩定訓練初期) "weight_decay": 0.01 # 權重衰減(設計依據:正則化) } # 2. 創建模型 print("? 創建新聞分類模型...") model = NewsClassifier( vocab_size=config["vocab_size"], d_model=config["d_model"], num_heads=config["num_heads"], num_layers=config["num_layers"], d_ff=config["d_ff"], num_classes=config["num_classes"], max_len=config["max_len"], dropout=config["dropout"] ) # 3. 設備選擇 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) print(f" 模型將運行在: {device}") # 4. 偽造數據集(實際應用中替換為真實數據) class NewsDataset(Dataset): def __init__(self, num_samples: int = 1000, max_len: int = 512): self.num_samples = num_samples self.max_len = max_len def __len__(self): return self.num_samples def __getitem__(self, idx): # 偽造新聞文本(詞ID) seq_len = min(500, 100 + idx % 400) # 變長序列 input_ids = torch.randint(1, 30000, (seq_len,)) # 偽造類別標簽(0-9) label = torch.tensor(idx % 10, dtype=torch.long) return input_ids, label # 5. 數據加載器(處理變長序列的關鍵) def collate_fn(batch): """處理變長序列的collate函數""" input_ids, labels = zip(*batch) # 找出最大長度 max_len = max(len(ids) for ids in input_ids) # padding padded_ids = [] for ids in input_ids: padding = torch.zeros(max_len - len(ids), dtype=torch.long) padded_ids.append(torch.cat([ids, padding])) input_ids = torch.stack(padded_ids) labels = torch.stack(labels) return input_ids, labels print("? 創建數據集和數據加載器...") train_dataset = NewsDataset(num_samples=1000) train_loader = DataLoader( train_dataset, batch_size=config["batch_size"], shuffle=True, collate_fn=collate_fn ) # 6. 損失函數和優化器 print("? 配置訓練組件...") loss_fn = nn.CrossEntropyLoss() optimizer = torch.optim.AdamW( model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"] ) # 7. 學習率調度器(設計依據:warmup + linear decay) total_steps = len(train_loader) * config["epochs"] warmup_steps = config["warmup_steps"] def lr_lambda(current_step: int): if current_step < warmup_steps: return float(current_step) / float(max(1, warmup_steps)) return max( 0.0, float(total_steps - current_step) / float(max(1, total_steps - warmup_steps)) ) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) # 8. 訓練循環 print("🚀 開始訓練...") for epoch in range(config["epochs"]): model.train() total_loss = 0 for batch_idx, (input_ids, labels) in enumerate(train_loader): input_ids = input_ids.to(device) labels = labels.to(device) # 前向傳播 optimizer.zero_grad() logits = model(input_ids) loss = loss_fn(logits, labels) # 反向傳播 loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # 梯度裁剪 optimizer.step() scheduler.step() total_loss += loss.item() # 打印進度 if batch_idx % 50 == 0: avg_loss = total_loss / (batch_idx + 1) current_lr = optimizer.param_groups[0]['lr'] print(f"Epoch [{epoch + 1}/{config['epochs']}] | " f"Batch [{batch_idx}/{len(train_loader)}] | " f"Loss: {avg_loss:.4f} | " f"LR: {current_lr:.2e}") print(f"? Epoch {epoch + 1} 完成 | Average Loss: {total_loss / len(train_loader):.4f}") # 9. 保存模型 torch.save(model.state_dict(), "news_classifier.pth") print("💾 模型已保存至 news_classifier.pth") # ==============================================
# 第四部分:推理示例
# ============================================== def predict_news_category(text: str, model: NewsClassifier, tokenizer, device: torch.device): """ 新聞分類推理 設計依據: - 使用與訓練相同的預處理流程 - 取[CLS] token進行分類 參數: text: 新聞文本 model: 訓練好的模型 tokenizer: 文本分詞器 device: 設備 返回: 預測類別和概率 """ model.eval() # 1. 文本預處理 input_ids = tokenizer.encode(text, max_length=512, truncation=True, padding="max_length") input_ids = torch.tensor(input_ids).unsqueeze(0).to(device) # 2. 前向傳播 with torch.no_grad(): logits = model(input_ids) probs = F.softmax(logits, dim=-1) # 3. 獲取結果 predicted_class = torch.argmax(probs, dim=-1).item() confidence = probs[0, predicted_class].item() return predicted_class, confidence if __name__ == "__main__": # 這里只是演示結構,實際運行需要完整實現 print("=" * 50) print("Transformer新聞分類模型設計與實現") print("=" * 50) print("\n本示例演示了如何根據任務需求設計并實現一個Transformer模型") print("設計流程嚴格遵循:問題分析 → 架構選擇 → 組件設計 → 訓練實現") print("\n關鍵設計決策:") print("- 選擇Encoder-only架構(分類任務無需生成能力)") print("- 使用[CLS] token進行分類(BERT-style)") print("- 可學習位置編碼(更適合變長新聞文本)") print("- 6層編碼器(平衡性能與計算成本)") print("\n要運行完整訓練,請取消注釋train_news_classifier()調用") train_news_classifier()