文章出自:基于 2個Expert 的 MoE 架構分步指南
本篇適合 MoE 架構初學者。文章亮點在于詳細拆解 Qwen 3 MoE 架構,并用簡單代碼從零實現 MoE 路由器、RMSNorm 等核心組件,便于理解內部原理。
該方法適用于需部署高性能、高效率大模型,同時優化計算成本的商業場景。
例如,在智能客服中,不同專家處理特定問題,提升響應速度;或在個性化推薦中,快速生成用戶內容。
代碼都可以在: GitHub 倉庫找到
文章目錄
- 1. 前言
- 2. 了解 Qwen 3 MoE 架構
- 2.1. 使用 RMSNorm 進行預歸一化
- 2.2. SwiGLU 激活函數
- 2.3. 旋轉位置嵌入 (RoPE)
- 2.4. 字節對編碼 (BPE)
- 3. 初始化安裝
- 4. 為什么我們需要模型權重?
- 5. Tokenized文本
- 6. 創建令牌嵌入層
- 7. 使用 RMSNorm 進行規范化
- 8. 分組查詢注意力 (GQA)
- 9. 使用 RoPE
- 10. 計算注意力分數
- 11. 實現多頭注意力
- 12. 專家混合 (MoE) 塊
- 13. 合并層
- 14. 生成輸出
1. 前言
阿里巴巴的 Qwen 3 是目前僅次于 DeepSeek 的最佳開源 MoE AI 模型,擅長推理、編碼、數學和語言。其頂級版本在 MMLU-Pro、LiveCodeBench 和 AIME 等關鍵測試中表現出色。
在這篇博客中,我們將使用 2 位專家構建一個微型 Qwen-3 MoE,而不使用面向對象編程(OOP)原則……
因此,我們可以一次查看并理解一個矩陣乘法。
Qwen 3 采用混合專家(MoE)架構構建,每次查詢僅激活其 2350 億參數中的一個子集,從而在不犧牲質量的情況下實現高效率。它還支持高達 128K 標記上下文,處理 119 種語言,并引入了雙重“思考”與“非思考”模式,以平衡深度推理和更快的推理。
我們的 Qwen 模型擁有 8 億參數。
所有代碼(理論 + 筆記本)都可以在我的 GitHub 倉庫中找到。
正如我所說,我們不會使用面向對象編程(OOP)編碼,而只使用簡單的 Python 編程。但是,您應該對神經網絡和 Transformer 架構有基本的了解。
這是遵循本博客所需的僅有的兩個先決條件。
2. 了解 Qwen 3 MoE 架構
我們首先以中級技術人員的身份了解 Qwen MoE 架構,然后使用一個例子“貓坐”來了解它如何通過架構,從而獲得清晰的理解。
Qwen 3 MoE 架構(來自 Sebastian Raschka)
想象一下你有一項非常艱巨的工作。你不是雇傭一個對所有事情都“略知一二”的人,而是雇傭一個專家團隊,每個人都擅長某一項特定技能(比如電工、水管工、油漆工)。你還會雇傭一個經理,他會查看當前任務并將其發送給合適的專家。
AI 模型中的 MoE 有點像這樣。MoE 層不是一個試圖學習所有內容的龐大神經網絡,而是包含:
- “專家”團隊:這些是更小、更專業的神經網絡(通常是簡單的前饋網絡或 MLP)。每個專家可能擅長處理某些類型的信息或模式。
- “路由器”(經理):這是另一個小型網絡。它的工作是查看輸入數據(如一個詞或詞的一部分),并決定哪些專家最適合立即處理它。
想象一下我們的模型正在處理句子:“The cat sat.”
- 標記:首先,我們將其分解成小塊(標記):“The”、“cat”、“sat”。
- 路由器獲取標記:MoE 層接收標記
cat
(表示為一串數字,一個嵌入向量)。路由器
查看這個cat
向量。 - 路由器選擇:假設我們有 4 位專家(
E1
、E2
、E3
、E4
)。路由器
決定哪些最適合cat
。 - 也許它認為
E2
(可能擅長名詞?)和E4
(可能擅長動物概念?)是最佳選擇。它為這些選擇賦予分數或“權重”(例如,E2
為 70%,E4
為 30%)。
路由器如何決定(由 Fareed Khan 創建)
cat
向量僅發送到 專家 2
和 專家 4
。專家 1
和 專家 3
不對此標記執行任何工作,從而節省了計算!E2
處理 cat
并生成其結果(Output_E2
)。E4
處理 cat
并生成其結果(Output_E4
)。
貓詞精選專家(由 Fareed Khan 創建)
我們現在使用 路由器
權重組合所選專家的結果:Final_Output = (0.7 * Output_E2) + (0.3 * Output_E4).
這個 Final_Output
是 MoE
層為標記 cat
傳遞的內容。序列中的每個標記都會發生這種情況!不同的標記可能會被路由到不同的專家。
因此,當我們的模型處理像“The cat sat.”
這樣的文本時,整個過程如下所示:
輸入文本
進入 分詞器
。分詞器
創建數字標記 ID
。嵌入層
將 ID 轉換為有意義的數字向量(嵌入
)并添加位置信息
(稍后在注意力中使用 RoPE
)。
這些向量通過多個 Transformer 塊
。每個塊都有:
自注意力
(其中標記相互關注,由RoPE
增強)。MoE 層
(其中路由器
將標記發送到特定的專家
)。歸一化
(RMSNorm
)和殘差連接
有助于學習。
最后一個塊的輸出進入 最終層
。這一層為我們詞匯表中的每個可能的下一個標記生成 Logits
(分數)。
我們將 logits
轉換為 概率
并 預測下一個標記
。
現在我們已經了解了 MoE
如何融入整體,接下來讓我們深入了解每個 AI 模型中的較小組件。
2.1. 使用 RMSNorm 進行預歸一化
RMSNorm(均方根歸一化)應用于每個 Transformer 子層(注意力或前饋)之前。
它根據輸入的均方根縮放輸入,而不減去均值(與 LayerNorm 不同)。這有助于穩定訓練并在早期保持重要信號的強度,就像在深入研究教科書之前復習關鍵章節一樣。
均方根層歸一化論文 (https://arxiv.org/abs/1910.07467)
感興趣的讀者可以在此處探索 RMSNorm 的詳細實現。
2.2. SwiGLU 激活函數
SwiGLU(Swish + 門控線性單元)增強了模型強調重要特征的能力。
它使用帶有 Swish 激活的門控機制,這有助于控制哪些信息通過。
SwiGLU:GLU 變體改進 Transformer (https://kikaben.com/swiglu-2020/)
將其視為一個智能熒光筆,它使關鍵部分在處理過程中更加突出。
它在 PaLM 中引入,現在用于 LLaMA 3/Qwen 3 以獲得更好的性能。有關 SwiGLU 的更多詳細信息可以在相關論文中找到。
2.3. 旋轉位置嵌入 (RoPE)
RoPE 使用正弦函數和旋轉扭曲來編碼標記位置,使嵌入能夠“旋轉”以反映相對位置。
RoPE 公式(由 Fareed Khan 創建)
與固定位置嵌入不同,RoPE 支持更長的上下文和對未見位置的更好泛化。
想象一下學生在一個圓圈中移動,他們的位置會發生變化,但他們的相對距離保持不變。這有助于模型更靈活地跟蹤詞序。
2.4. 字節對編碼 (BPE)
BPE 通過合并頻繁的字符對(如“th”、“ing”)來構建標記,使模型能夠更有效地處理不常見或新詞。
BPE(來自 langformer blog)
Qwen 3 使用 BPE,它傾向于完整的已知詞(例如,“hugging”如果在詞匯表中,則保持完整)。
而 LLaMA 3 使用 SentencePiece BPE,它可能會將同一個詞拆分成多個部分(“hug”+“ging”)。這種差異會影響分詞速度以及模型理解文本的方式。
3. 初始化安裝
我們將使用少量 Python 庫,但最好安裝它們以避免遇到**“未找到模塊”**錯誤。
pip install sentencepiece tiktoken torch matplotlib huggingface_hub tokenizers safetensors
安裝完所需的庫后,我們需要下載 Qwen 3 架構權重和配置文件,這些文件將在本指南中用到。
我們正在針對一個較小的 Qwen 3 MoE 版本,其中包含兩個專家,每個專家有 0.8B 參數。必要的文件是 Qwen 3 架構的骨干。有兩種方法可以實現這一點。
(選項 1:手動) 轉到 Qwen-0.8B-2E HF 目錄并手動下載這四個文件中的每一個。
(選項 2:編碼) 我們可以使用 huggingface_hub
的 snapshot_download
模塊下載 Qwen 3 MoE 模型的整個 Hugging Face 倉庫。我們采用這種方法。
from tqdm import tqdm
from huggingface_hub import snapshot_downloadrepo_id = "huihui-ai/Huihui-MoE-0.8B-2E"
local_dir = "Huihui-MoE-0.8B-2E"snapshot_download(repo_id=repo_id,local_dir=local_dir,ignore_patterns=["*.bin"],tqdm_class=tqdm
)
下載所有文件后,我們需要導入將在本博客中使用的庫。
import torch
import torch.nn as nnfrom huggingface_hub import snapshot_download
from tokenizers import Tokenizer
from safetensors.torch import load_fileimport json
from pathlib import Path
from tqdm import tqdmimport matplotlib.pyplot as plt
接下來,我們需要了解每個文件的用途。
4. 為什么我們需要模型權重?
由于我們旨在精確復制 Qwen 3 MoE,這意味著我們的輸入文本必須產生有意義的輸出。
例如,如果我們的輸入是**“太陽的顏色是?”** ,輸出必須是**“白色”**。
實現這一點需要在大規模數據集上訓練我們的 LLM,這需要高計算能力,對我們來說是不可行的。
然而,阿里巴巴已經公開了他們的 Qwen 3 架構文件,或者更復雜地說,他們預訓練的權重供使用。我們剛剛下載了這些文件,這使我們能夠復制他們的架構,而無需訓練或大量數據集。一切都已準備就緒,我們只需在正確的位置使用正確的組件。
tokenizer.json — Qwen 3 使用字節對編碼(BPE),Andrej Karpathy 有一個非常簡潔的 BPE 實現。
tokenizer_path = Path("Huihui-MoE-0.8B-2E/tokenizer.json")tokenizer = Tokenizer.from_file(str(tokenizer_path))with open("Huihui-MoE-0.8B-2E/special_tokens_map.json", "r") as f:special_tokens_map = json.load(f)print(f"Special tokens from file: {special_tokens_map}")
Special tokens from file: {
'additional_special_tokens': ['<|im_start|>',
'<|im_end|>', '<|object_ref_start|>', '<|object_ref_end|>', '<|box_start|>'
...
}
這些特殊標記將用于包裝我們的提示,以指導我們的 Qwen 3 架構如何響應我們的查詢。
# We'll follow the encode -> decode pattern to ensure it works correctly.
prompt = "The only thing I know is that I know"# .encode() returns an Encoding object, we access the token IDs via .ids
encoded = tokenizer.encode(prompt)
print(f"\nOriginal prompt: '{prompt}'")
print(f"Encoded token IDs: {encoded.ids}")# .decode() converts the token IDs back to a string.
decoded = tokenizer.decode(encoded.ids)
print(f"Decoded back to text: '{decoded}'")# Verify the vocabulary size
vocab_size = tokenizer.get_vocab_size()
print(f"\nTokenizer vocabulary size: {vocab_size}")#### OUTPUT ####
Original prompt: 'The only thing I know is that I know'
Encoded token IDs: [785, 1172, 3166, 358, 1414, 374, 429, 358, 1414]
Decoded back to text: 'The only thing I know is that I know'
Tokenizer vocabulary size: 151669
詞匯量大小表示訓練數據中唯一字符的數量。tokenizer
的類型是一個字典。
# Get the vocabulary as a dictionary: {token_string: token_id}
vocab = tokenizer.get_vocab()# Display a slice of the vocabulary for inspection (tokens 5600 to 5609)
sample_vocab_slice = list(vocab.items())[5600:5610]
sample_vocab_slice#### OUTPUT ####
[('í??', 129382),('?Brands', 54232),('?incorporates', 51824),('à??à?£à?°à?£à?2à??', 132851),('?Resource', 79487),('??????', 80840),('hover', 17583),('Movement', 38050),('è§£??3?o?', 105826),('?onBackPressed', 70609)]
當我們從中打印 10 個隨機項時,您會看到使用 BPE 算法形成的字符串。鍵表示來自 BPE 訓練的字節序列,而值表示基于頻率的合并排名。
config.json — 包含各種參數值,例如:
# Define the path to the configuration file.
config_path = Path("Huihui-MoE-0.8B-2E/config.json")# Open and load the JSON file into a Python dictionary.
with open(config_path, "r") as f:config = json.load(f)# Print the configuration to see all the parameters.
# This gives us a complete overview of the model we're about to build.
print(json.dumps(config, indent=4))#### OUTPUT ####
{"architectures": ["Qwen3MoeForCausalLM"],"attention_bias": false,"attention_dropout": 0.0,"bos_token_id": 151643,"decoder_sparse_step": 1,"eos_token_id": 151645,"head_dim": 128,"hidden_act": "silu",..."transformers_version": "4.52.4","use_cache": true,"use_sliding_window": false,"vocab_size": 151936
}
這些值將通過指定注意力頭數、嵌入向量維度、專家數量等細節來幫助我們復制 Qwen-3 架構。
讓我們存儲這些值,以便以后使用。
# --- Main Architecture Parameters ---
# Extract model hyperparameters from the config dictionary.# Embedding dimension (hidden size of the model)
dim = config["hidden_size"]
# Number of transformer layers
n_layers = config["num_hidden_layers"]
# Number of attention heads
n_heads = config["num_attention_heads"]
# Number of key/value heads (for grouped-query attention)
n_kv_heads = config["num_key_value_heads"]
# Vocabulary size
vocab_size = config["vocab_size"]
# RMSNorm epsilon value for numerical stability
norm_eps = config["rms_norm_eps"]
# Rotary positional embedding theta parameter
rope_theta = torch.tensor(config["rope_theta"])
# Dimension of each attention head
head_dim = config["head_dim"] # For attention calculations# --- Mixture-of-Experts (MoE) Specific Parameters ---
# Number of experts in the MoE layer
num_experts = config["num_experts"]
# Number of experts selected per token by the router
num_experts_per_tok = config["num_experts_per_tok"]
# Intermediate size of the MoE feed-forward network
moe_intermediate_size = config["moe_intermediate_size"]
model.safetensors — 包含 Qwen 0.8B 2 專家模型的學習參數(權重)。這些參數包含模型如何理解和處理語言的信息,例如它如何表示標記、計算注意力、執行專家選擇以及歸一化其輸出。
model_weights_path = Path("Huihui-MoE-0.8B-2E/model.safetensors")model_weights = load_file(model_weights_path)print("First 20 keys in model_weights:")
print(json.dumps(list(model_weights.keys())[:20], indent=4))
OUTPUT:
["model.embed_tokens.weight","model.layers.0.input_layernorm.weight","model.layers.0.mlp.experts.0.down_proj.weight","model.layers.0.mlp.experts.0.gate_proj.weight","model.layers.0.mlp.experts.0.up_proj.weight","model.layers.0.mlp.experts.1.down_proj.weight",..."model.layers.1.mlp.experts.0.gate_proj.weight","model.layers.1.mlp.experts.0.up_proj.weight"...
]
如果您熟悉 Transformer 架構,您就會知道查詢、鍵矩陣等等。稍后,我們將使用這些層/權重來創建這些矩陣以及 Qwen 3 MoE 架構中的 MoE 組件。
現在我們有了分詞器模型、包含權重的架構模型和配置參數,讓我們開始從頭開始編碼我們自己的 Qwen 3 MoE。
5. Tokenized文本
標記化輸入文本(由 Fareed Khan 創建)
第一步是將我們的輸入文本轉換為標記。Qwen 3 使用帶有特殊標記(如 <|im_start|>
和 <|im_end|>
)的特定聊天模板來構建對話。這有助于模型區分用戶查詢和它自己的響應。
prompt = "The only thing I know is that I know"im_start_id = tokenizer.token_to_id("<|im_start|>")
im_end_id = tokenizer.token_to_id("<|im_end|>")
newline_id = tokenizer.encode("\n").ids[0]
user_ids = tokenizer.encode
````python
assistant_ids = tokenizer.encode("assistant").ids
prompt_ids = tokenizer.encode(prompt).idsprefix_ids = [im_start_id] + user_ids + [newline_id]
suffix_ids = [im_end_id, newline_id, im_start_id] + assistant_ids + [newline_id]
tokens_list = prefix_ids + prompt_ids + suffix_idstokens = torch.tensor(tokens_list)print(f"Final combined token IDs: {tokens}")prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]
print(f"\nPrompt split into tokens: {prompt_split_as_tokens}")
OUTPUT:
Final combined token IDs: tensor([151644, 872, ... , 8])
Prompt split into tokens: ['', 'user', '\n', 'The', ..., '\n']
我們現在已經將提示轉換為一個包含 17 個標記的結構化列表,準備好供模型使用。
6. 創建令牌嵌入層
生成標記化文本的嵌入(由 Fareed Khan 創建)
嵌入是一個密集向量,用于在高維空間中表示標記的含義。我們的 17 個標記的輸入向量需要轉換為 [17, 1024]
的張量,其中 1024 (dim
) 是嵌入維度。
embedding_layer = nn.Embedding(vocab_size, dim)embedding_layer.weight.data.copy_(model_weights["model.embed_tokens.weight"])token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)print("Shape of the token embeddings:", token_embeddings_unnormalized.shape)
OUTPUT
Shape of the token embeddings: torch.Size([17, 1024])
這些嵌入未歸一化,如果我們不進行歸一化,將產生嚴重影響。在下一節中,我們將對輸入向量執行歸一化。
7. 使用 RMSNorm 進行規范化
我們將定義 rms_norm
函數,它根據輸入的均方根值縮放輸入。這是我們 Transformer 層中的第一個預歸一化步驟。
均方根層歸一化論文 (https://arxiv.org/abs/1910.07467)
def rms_norm(tensor, norm_weights):input_dtype = tensor.dtypetensor_float = tensor.to(torch.float32)variance = tensor_float.pow(2).mean(-1, keepdim=True)normalized_tensor = tensor_float * torch.rsqrt(variance + norm_eps)return (normalized_tensor * norm_weights).to(input_dtype)
我們將使用 layers_0
的注意力權重來歸一化我們未歸一化的嵌入。使用 layer_0
的原因是,我們現在正在創建 Qwen 3 架構的第一層。
token_embeddings_normalized = rms_norm(token_embeddings_unnormalized,model_weights["model.layers.0.input_layernorm.weight"]
)
print("Shape of the normalized token embeddings:", token_embeddings_normalized.shape)
Shape of the normalized token embeddings: torch.Size([17, 1024])
形狀保持不變,但值現在已歸一化,并準備好用于注意力機制。
8. 分組查詢注意力 (GQA)
接下來,我們生成查詢 (Q)、鍵 (K) 和值 (V) 向量。預訓練權重存儲在大的組合矩陣中。我們需要重塑它們以分離出我們 16 個注意力頭的每個頭的權重。
分組查詢注意力 (GQA)(由 Fareed Khan 創建)
該模型使用一種稱為分組查詢注意力 (GQA) 的優化,其中多個查詢頭 (16) 共享少量鍵和值頭 (8)。這在不顯著降低性能的情況下減少了計算負載。
q_layer0 = model_weights["model.layers.0.self_attn.q_proj.weight"]
q_layer0 = q_layer0.view(n_heads, head_dim, dim)k_layer0 = model_weights["model.layers.0.self_attn.k_proj.weight"]
k_layer0 = k_layer0.view(n_kv_heads, head_dim, dim)v_layer0 = model_weights["model.layers.0.self_attn.v_proj.weight"]
v_layer0 = v_layer0.view(n_kv_heads, head_dim, dim)
現在,讓我們通過將歸一化嵌入乘以頭的權重來計算第一個頭的 Q、K 和 V 向量。
q_layer0_head0 = q_layer0[0]
k_layer0_head0 = k_layer0[0]
v_layer0_head0 = v_layer0[0]q_per_token = torch.matmul(token_embeddings_normalized, q_layer0_head0.T)
k_per_token = torch.matmul(token_embeddings_normalized, k_layer0_head0.T)
v_per_token = torch.matmul(token_embeddings_normalized, v_layer0_head0.T)print("Shape of Query vectors per token:", q_per_token.shape)
Shape of Query vectors per token: torch.Size([17, 128])
我們 17 個標記中的每個標記現在都有一個 128 維的 Q、K 和 V 向量,用于第一個頭。
9. 使用 RoPE
這些向量尚未知道它們的位置。我們將使用 RoPE 通過“旋轉”它們來注入這些信息。為了提高效率,我們可以預先計算所有可能位置(直到最大序列長度)的旋轉角度。
RoPE 實現(由 Fareed Khan 創建)
這將創建一個旋轉矩陣的查找表,表示為復數。
max_seq_len = config["max_position_embeddings"]
freqs = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2) / head_dim))
t = torch.arange(max_seq_len)
freqs_for_each_token = torch.outer(t, freqs)freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
這個 freqs_cis
張量現在包含將執行旋轉的復數。我們可以可視化單個標記的旋轉,以查看每個 2D 維度對如何以不同的角度旋轉。
單個標記位置上每個 2D 對的 RoPE 旋轉(由 Fareed Khan 創建)
現在,我們將這些旋轉應用于我們的 Q 和 K 向量。通過將向量視為復數并執行逐元素乘法來執行旋轉。
freqs_cis_for_tokens = freqs_cis[:len(tokens)]q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token.float().view(q_per_token.shape[0], -1, 2))
q_per_token_rotated_complex = q_per_token_as_complex_numbers * freqs_cis_for_tokens
q_per_token_rotated = torch.view_as_real(q_per_token_rotated_complex).view(q_per_token.shape)k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token.float().view(k_per_token.shape[0], -1, 2))
k_per_token_rotated_complex = k_per_token_as_complex_numbers * freqs_cis_for_tokens
k_per_token_rotated = torch.view_as_real(k_per_token_rotated_complex).view(k_per_token.shape)print("Shape of rotated Query vectors:", q_per_token_rotated.shape)
Shape of rotated Query vectors: torch.Size([17, 128])
10. 計算注意力分數
現在我們通過計算查詢和鍵矩陣的點積來計算注意力分數。這將創建一個 [17, 17]
矩陣,顯示每個標記應該“關注”其他每個標記的程度。
我們通過頭維度的平方根來縮放分數,以穩定訓練。
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)qk_per_token_scaled = qk_per_token / (head_dim**0.5)
我們可以將這些原始分數可視化為熱圖。
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)qk_per_token_scaled = qk_per_token / (head_dim**0.5)def display_qk_heatmap(qk_matrix, title="Attention Heatmap"):_, ax = plt.subplots()im = ax.imshow(qk_matrix.to(torch.float32).detach(), cmap='viridis')ax.set_xticks(range(len(prompt_split_as_tokens)))ax.set_yticks(range(len(prompt_split_as_tokens)))ax.set_xticklabels(prompt_split_as_tokens, rotation=90)ax.set_yticklabels(prompt_split_as_tokens)ax.figure.colorbar(im, ax=ax)plt.title(title)plt.show()display_qk_heatmap(qk_per_token_scaled, title="Raw Attention Scores (Before Masking)")
原始注意力分數(掩碼前)
為了防止標記在這種自回歸模型中“看到”未來,我們應用因果掩碼。這將矩陣上三角形中的所有分數設置為負無窮大,因此它們在 softmax 函數后變為零。
mask = torch.full((len(tokens), len(tokens)), float("-inf"))
mask = torch.triu(mask, diagonal=1)qk_per_token_masked = qk_per_token_scaled + mask
如果我們看看掩碼矩陣的樣子。
print(mask)
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
掩碼后的注意力分數
最后,我們應用 softmax 函數將這些分數轉換為概率(注意力權重),并將它們乘以值矩陣。這將產生值的加權和,為我們提供此注意力頭的最終輸出。
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_masked.float(), dim=1).to(torch.bfloat16)qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)print("Shape of the final attention output for Head 0:", qkv_attention.shape)
Shape of the final attention output for Head 0: torch.Size([17, 128])
最終注意力權重(由 Fareed Khan 創建)
輸出是一個新的 [17, 128]
張量,其中每個標記的向量現在包含來自所有先前標記的上下文信息。
11. 實現多頭注意力
我們現在在一個循環中對所有 16 個頭重復自注意力過程。每個頭的輸出([17, 128]
張量)被收集到一個列表中。
多頭注意力(由 Fareed Khan 創建)
qkv_attention_store = []for head in range(n_heads):q_layer0_head = q_layer0[head]k_layer0_head = k_layer0[head // (n_heads // n_kv_heads)]v_layer0_head = v_layer0[head // (n_heads // n_kv_heads)]q_per_token = torch.matmul(token_embeddings_normalized, q_layer0_head.T)k_per_token = torch.matmul(token_embeddings_normalized, k_layer0_head.T)v_per_token = torch.matmul(token_embeddings_normalized, v_layer0_head.T)q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis_for_tokensq_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)k_per_token_as_complex_numbers_rotated = k_per_token_as_complex_numbers * freqs_cis_for_tokensk_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers_rotated)k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (head_dim**0.5)qk_per_token_masked = qk_per_token + maskqk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_masked.float(), dim=1).to(torch.bfloat16)qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)qkv_attention_store.append(qkv_attention)
循環結束后,我們將 16 個頭的輸出連接成一個大小為 [17, 2048]
的大張量。然后使用輸出權重矩陣 o_proj
將其投影回模型的維度 (1024
)。
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)w_layer0 = model_weights["model.layers.0.self_attn.o_proj.weight"]embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T)
結果 embedding_delta
被加回到層的原始輸入中。這是第一個殘差連接,這是一項關鍵技術,通過允許梯度更輕松地流動,有助于訓練非常深的神經網絡。
embedding_after_attention = token_embeddings_unnormalized + embedding_delta
12. 專家混合 (MoE) 塊
這是 Transformer 塊的第二個子層。首先,我們對其輸入應用預歸一化。
Qwen 3 MoE 層(由 Fareed Khan 創建)
embedding_after_attention_normalized = rms_norm(embedding_after_attention,model_weights["model.layers.0.post_attention_layernorm.weight"]
)
接下來,路由器(一個簡單的線性層)計算分數以確定每個標記應該發送到兩個專家中的哪一個。
gate = model_weights["model.layers.0.mlp.gate.weight"]
router_logits = torch.matmul(embedding_after_attention_normalized, gate.T)routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=1).to(torch.bfloat16)
routing_expert_indices = torch.argmax(routing_weights, dim=1)print("Router logits shape:", router_logits.shape)
print("Expert chosen for each of the 17 tokens:", routing_expert_indices)
Router logits shape: torch.Size([17, 2])
Expert chosen for each of the 17 tokens: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
在這種情況下,路由器決定將所有 17 個標記發送給專家 1。我們現在通過每個標記選擇的專家的前饋網絡 (FFN) 處理每個標記的嵌入,并根據路由器的概率加權組合結果。
expert0_w1 = model_weights["model.layers.0.mlp.experts.0.gate_proj.weight"]
expert0_w2 = model_weights["model.layers.0.mlp.experts.0.down_proj.weight"]
expert0_w3 = model_weights["model.layers.0.mlp.experts.0.up_proj.weight"]expert1_w1 = model_weights["model.layers.0.mlp.experts.1.gate_proj.weight"]
expert1_w2 = model_weights["model.layers.0.mlp.experts.1.down_proj.weight"]
expert1_w3 = model_weights["model.layers.0.mlp.experts.1.up_proj.weight"]final_expert_output = torch.zeros_like(embedding_after_attention_normalized)for i, token_embedding in enumerate(embedding_after_attention_normalized):chosen_expert_index = routing_expert_indices[i]if chosen_expert_index == 0:w1, w2, w3 = expert0_w1, expert0_w2, expert0_w3else:w1, w2, w3 = expert1_w1, expert1_w2, expert1_w3silu_output = torch.nn.functional.silu(torch.matmul(token_embedding, w1.T))gated_output = silu_output * torch.matmul(token_embedding, w3.T)expert_output = torch.matmul(gated_output, w2.T)final_expert_output[i] = expert_output * routing_weights[i, chosen_expert_index]
最后,我們將 MoE 塊的輸出添加回注意力塊的輸出。這是第二個殘差連接,完成了 Transformer 層。
layer_0_embedding = embedding_after_attention + final_expert_output
13. 合并層
現在我們有了所有組件,我們可以通過遍歷所有 28 層來構建完整的模型。
一層的輸出成為下一層的輸入。
合并一切(來自 Sebastian Raschka)
final_embedding = token_embeddings_unnormalizedfor layer in range(n_layers):attention_input = rms_norm(final_embedding, model_weights[f"model.layers.{layer}.input_layernorm.weight"])q_layer = model_weights[f"model.layers.{layer}.self_attn.q_proj.weight"].view(n_heads, head_dim, dim)k_layer = model_weights[f"model.layers.{layer}.self_attn.k_proj.weight"].view(n_kv_heads, head_dim, dim)v_layer = model_weights[f"model.layers.{layer}.self_attn.v_proj.weight"].view(n_kv_heads, head_dim, dim)w_layer = model_weights[f"model.layers.{layer}.self_attn.o_proj.weight"]qkv_attention_store = []for head in range(n_heads):q_layer_head = q_layer[head]k_layer_head = k_layer[head // (n_heads // n_kv_heads)]v_layer_head = v_layer[head // (n_heads // n_kv_heads)]q_per_token = torch.matmul(attention_input, q_layer_head.T)k_per_token = torch.matmul(attention_input, k_layer_head.T)v_per_token = torch.matmul(attention_input, v_layer_head.T)q_per_token_rotated = torch.view_as_real(torch.view_as_complex(q_per_token.float().view(q_per_token.shape[0], -1, 2)) * freqs_cis_for_tokens).view(q_per_token.shape)k_per_token_rotated = torch.view_as_real(torch.view_as_complex(k_per_token.float().view(k_per_token.shape[0], -1, 2)) * freqs_cis_for_tokens).view(k_per_token.shape)qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (head_dim**0.5)qk_per_token_masked = qk_per_token + maskqk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_masked.float(), dim=1).to(torch.bfloat16)qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)qkv_attention_store.append(qkv_attention)stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)embedding_after_attention = final_embedding + embedding_deltamoe_input = rms_norm(embedding_after_attention, model_weights[f"model.layers.{layer}.post_attention_layernorm.weight"])gate = model_weights[f"model.layers.{layer}.mlp.gate.weight"]router_logits = torch.matmul(moe_input, gate.T)routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=1).to(torch.bfloat16)routing_expert_indices = torch.argmax(routing_weights, dim=1)final_expert_output = torch.zeros_like(moe_input)expert0_w1 = model_weights[f"model.layers.{layer}.mlp.experts.0.gate_proj.weight"]expert0_w2 = model_weights[f"model.layers.{layer}.mlp.experts.0.down_proj.weight"]expert0_w3 = model_weights[f"model.layers.{layer}.mlp.experts.0.up_proj.weight"]expert1_w1 = model_weights[f"model.layers.{layer}.mlp.experts.1.gate_proj.weight"]expert1_w2 = model_weights[f"model.layers.{layer}.mlp.experts.1.down_proj.weight"]expert1_w3 = model_weights[f"model.layers.{layer}.mlp.experts.1.up_proj.weight"]for i, token_embedding in enumerate(moe_input):chosen_expert_index = routing_expert_indices[i]if chosen_expert_index == 0:w1, w2, w3 = expert0_w1, expert0_w2, expert0_w3else:w1, w2, w3 = expert1_w1, expert1_w2, expert1_w3silu_output = torch.nn.functional.silu(torch.matmul(token_embedding, w1.T))gated_output = silu_output * torch.matmul(token_embedding, w3.T)expert_output = torch.matmul(gated_output, w2.T)final_expert_output[i] = expert_output * routing_weights[i, chosen_expert_index]final_embedding = embedding_after_attention + final_expert_outputprint("Shape of the final embeddings after all layers:", final_embedding.shape)
Shape of the final embeddings after all layers: torch.Size([17, 1024])
14. 生成輸出
我們現在有了最終嵌入,它代表了模型對下一個標記的預測。其形狀為 [17, 1024]
。首先,我們應用最后一次 RMSNorm。
final_embedding_normalized = rms_norm(final_embedding, model_weights["model.norm.weight"])
為了獲得最終預測,我們只需要序列中最后一個標記的嵌入。我們將這個 [1024]
向量乘以語言模型頭權重(與標記嵌入權重綁定),以獲得詞匯表中每個單詞的分數,即 logits。
lm_head_weights = model_weights["model.embed_tokens.weight"]last_token_embedding = final_embedding_normalized[-1]logits = torch.matmul(last_token_embedding, lm_head_weights.T)print("Shape of the final logits :", logits.shape)
Shape of the final logits: torch.Size([151936])
具有最高 logit 的標記是模型的預測。我們使用 argmax
來找到其索引。
next_token_id = torch.argmax(logits, dim=-1)
print(f"Predicted Token ID: {next_token_id.item()}")predicted_word = tokenizer.decode([next_token_id.item()])
print(f"\nPredicted Word: '{predicted_word}'")
Predicted Token ID: 12454
Predicted Word: 'nothing'
因此,在提示 ...assistant\n
之后,模型對下一個詞的最佳猜測是“nothing”。這只是一個單標記生成,但它表明我們從頭開始實現的 Qwen 3 MoE 架構正在正確運行。
您可以通過簡單地更改開頭的 prompt
變量并調整標記張量構造來嘗試不同的輸入文本。