Decoder-Only整體結構
我們以模型Llama-3.1-8B-Instruct
為例,打印其結構如下(后面會慢慢解析每一部分,莫慌):
LlamaForCausalLM((model): LlamaModel((embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(layers): ModuleList((0-31): 32 x LlamaDecoderLayer((self_attn): LlamaAttention((qkv_proj): QKVParallelLinear(in_features=4096, output_features=6144, bias=False, tp_size=1, gather_output=False)(o_proj): RowParallelLinear(input_features=4096, output_features=4096, bias=False, tp_size=1, reduce_results=True)(rotary_emb): Llama3RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=131072, base=500000.0, is_neox_style=True)(attn): RadixAttention())(mlp): LlamaMLP((gate_up_proj): MergedColumnParallelLinear(in_features=4096, output_features=28672, bias=False, tp_size=1, gather_output=False)(down_proj): RowParallelLinear(input_features=14336, output_features=4096, bias=False, tp_size=1, reduce_results=True)(act_fn): SiluAndMul())(input_layernorm): RMSNorm()(post_attention_layernorm): RMSNorm()))(norm): RMSNorm())(lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(logits_processor): LogitsProcessor()(pooler): Pooler()
)
Decoder-Only處理流程
我們以Llama-3.1-8B-Instruct模型為例,結合一個具體的聊天對話場景,詳細說明Decoder-Only模型的處理流程,從用戶輸入到最終輸出回答。整個過程會逐步拆解,并標注每個步驟的輸入輸出形狀(假設batch_size=1,seq_len=10,hidden_dim=4096,詞表大小=128000)。
1. 用戶輸入與聊天模板處理
場景:用戶問:“如何做西紅柿炒雞蛋?”
模型需求:需要根據歷史對話和當前問題生成回答。
聊天模板處理
- 輸入文本text:原始用戶輸入(如“如何做西紅柿炒雞蛋?”)
- 模板化prompt:模型需要將輸入包裝成特定格式的prompt,例如:
[系統指令]:你是一個烹飪助手,請回答以下問題。 [用戶]:如何做西紅柿炒雞蛋? [助手]:
- 作用:模板化prompt讓模型明確任務目標(如回答問題),并模擬對話上下文。
輸入輸出形狀:
- 輸入文本長度:假設為10個字符(實際長度取決于具體輸入)。
- 模板化后的prompt長度:假設為30個字符(包含系統指令、用戶問題和占位符)。
2. Tokenizer處理:從prompt到input_ids
步驟:
- Tokenization:將模板化prompt拆分為模型能理解的Token(如“西紅柿”→“西紅柿”,“炒”→“炒”)。
- 映射到input_ids:每個Token被映射為對應的ID(例如,“西紅柿”→1234,“炒”→5678)。
示例:
假設模板化Prompt被拆分為10個Token,其input_ids為:
[101, 1234, 5678, 8901, 2345, 6789, 102, 3456, 7890, 102]
(其中101和102是特殊標記,如<BOS>
和<EOS>
,表示開始和結束)
輸入輸出形狀:
input_ids
的形狀為(batch_size, seq_len)
→(1, 10)
attention_mask
(可選)的形狀為(1, 10)
,標記哪些位置是有效Token(1)或填充(0)。
3. 嵌入層:input_ids → hidden_states
步驟:
- Token Embedding:將input_ids映射為高維向量(如4096維)。
- Positional Encoding:添加位置信息,讓模型知道每個Token在序列中的位置。
示例:
- input_ids
[101, 1234, 5678, ...]
→ 隱藏狀態hidden_states
的形狀為(1, 10, 4096)
。 - 每個Token對應的向量包含其語義和位置信息(例如,“西紅柿”對應的食物相關特征,以及它在句子中的位置)。
輸入輸出形狀:
hidden_states
的形狀為(batch_size, seq_len, hidden_dim)
→(1, 10, 4096)
4. Decoder Block處理:逐層計算
核心流程:
-
Masked Self-Attention(帶掩碼的自注意力):
- 每個Token只能看到自己及之前的Token(防止“偷看”未來內容)。
- 例如,在生成“西紅柿炒雞蛋”時,模型會先處理“西紅柿”,再處理“炒”,確保生成邏輯連貫。
-
前饋網絡(FFN):
- 對每個Token的隱藏狀態進行非線性變換,增強表達能力。
示例:
- 假設模型有32層Decoder Block,每層都會更新
hidden_states
。 - 最終的
hidden_states
保留了完整的上下文信息(如“西紅柿炒雞蛋”的步驟描述)。
輸入輸出形狀:
- 每層Decoder Block的輸入輸出形狀不變,仍為
(1, 10, 4096)
5. LM Head:從hidden_states到下一個詞
步驟:
- 線性層:將最后一個Token的隱藏狀態(形狀為
(1, 10, 4096)
)映射到詞表維度(128000)。- 例如,對最后一個位置(
seq_len=9
)的隱藏狀態取值:hidden_states[:, 9, :]
→ 形狀(1, 4096)
。
- 例如,對最后一個位置(
- Softmax:將輸出轉換為概率分布(每個詞的概率)。
示例:
- 假設模型預測下一個詞是“步驟一”,其ID為9876,則概率分布中9876的值最高。
輸入輸出形狀:
- 線性層輸出形狀:
(1, 128000)
- 概率分布形狀:
(1, 128000)
6. 采樣策略:從概率分布到下一個詞
方法:
- Top-k采樣:從概率最高的前k個詞(如k=50)中隨機選一個。
- Greedy Search:直接選概率最高的詞(如“步驟一”)。
示例:
- 模型選擇“步驟一”作為下一個詞,并將其ID(9876)添加到
input_ids
中。 - 新的
input_ids
變為:[101, 1234, 5678, ..., 9876]
(長度+1)。
輸入輸出形狀:
- 新的
input_ids
形狀為(1, 11)
7. 迭代生成:重復步驟3-6直到完成
流程:
- 將新的
input_ids
和hidden_states
送回Decoder Block。 - 重復計算,逐步生成完整回答(如“步驟一:熱鍋涼油…”)。
- 直到生成終止標記(如
<EOS>
)或達到最大長度(如2048 Token)。
示例:
- 生成完整回答后,
input_ids
的長度可能變為200(假設生成190個新Token)。 - 最終的
input_ids
包含原始Prompt和生成的回答。
8. Tokenizer反向處理:從input_ids到用戶文本
步驟:
- 將生成的
input_ids
(含prompt和回答)截取回答部分(去掉prompt)。 - 使用Tokenizer將
input_ids
轉換回自然語言文本(如“步驟一:熱鍋涼油…”)。
輸入輸出形狀:
- 截取后的
input_ids
形狀為(1, 190)
- 最終輸出文本長度取決于生成內容(如“步驟一:熱鍋涼油…”)
總結流程圖
用戶輸入 → 模板化Prompt → Tokenizer → input_ids (1,10) → 嵌入層 → hidden_states (1,10,4096) → Decoder Block ×32 → hidden_states (1,10,4096) → LM Head → 概率分布 (1,128000) → 采樣 → 新input_ids (1,11) → 重復生成 → input_ids (1,200) → Tokenizer反向 → 用戶文本
LlamaForCausalLM結構分析
以模型Llama-3.1-8B-Instruct
為例,將一部分子結構信息折疊起來,將顯示如下:
LlamaForCausalLM((model): LlamaModel((embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(layers): ModuleList((0-31): 32 x LlamaDecoderLayer(...))(norm): RMSNorm())(lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(logits_processor): LogitsProcessor()(pooler): Pooler()
)
可以看到LlamaForCausalLM
主要由幾個關鍵部分組成:model, lm_head, logits_processor和pooler。這幾個組件作用各不相同,我們現在來介紹一下他們。
1. model
:核心解碼器結構
(1) embed_tokens
:詞嵌入層
- 作用:將輸入的Token ID(如“西紅柿”→ID=1234)映射為4096維的向量,表示Token的語義和位置信息。
- 技術細節:
- 使用VocabParallelEmbedding(并行詞嵌入,僅需了解,無需深入),支持分布式訓練。
- 詞表大小為128256,覆蓋多語言和特殊符號(如
<BOS>
、<EOS>
)。
- 輸入輸出形狀:
- 輸入:
(batch_size, seq_len)
→(1, 10)
(假設輸入10個Token) - 輸出:
(batch_size, seq_len, hidden_dim)
→(1, 10, 4096)
- 輸入:
(2) layers
:32層Decoder Block
- 核心結構:
- 多頭注意力(MHA):通過Grouped-Query Attention (GQA) 提高推理效率(Llama 3.1新增)。
- 查詢(Q)、鍵(K)、值(V)的維度:
d_model=4096
,num_heads=32
,head_dim=128
。 - GQA機制:將K/V頭數減少為
num_key_value_heads=8
,降低計算開銷。
- 查詢(Q)、鍵(K)、值(V)的維度:
- 前饋網絡(MLP):使用SwiGLU激活函數(Sigmoid + Gated Linear Unit),替代傳統ReLU。
- 輸入:
4096
維 → 中間層:11008
維 → 輸出:4096
維。
- 輸入:
- 歸一化:每層使用RMSNorm(均方根歸一化),穩定訓練并加速收斂。
- 多頭注意力(MHA):通過Grouped-Query Attention (GQA) 提高推理效率(Llama 3.1新增)。
- 輸入輸出形狀:
- 每層輸入/輸出:
(1, 10, 4096)
(與輸入形狀一致)
- 每層輸入/輸出:
(3) norm
:最終歸一化層
- 作用:對32層Decoder Block的輸出進行最后一次歸一化,確保數值穩定性。
- 技術細節:
- 使用RMSNorm,無需計算均值,直接對向量的模長標準化。
- 公式:
hidden_states = hidden_states / sqrt(variance + ε)
,其中ε=1e-6
。
2. lm_head
:語言模型頭部
- 作用:將最終的隱藏狀態(
hidden_dim=4096
)映射為詞表大小(vocab_size=128256
)的概率分布,預測下一個詞。 - 技術細節:
- 使用ParallelLMHead(并行線性層),加速大規模詞表的計算。
- 參數量:
4096 × 128256 ≈ 5.16B
(占模型總參數量的約76%)。
- 輸入輸出形狀:
- 輸入:
(1, 4096)
(取最后一個位置的隱藏狀態) - 輸出:
(1, 128256)
(每個詞的概率值)
- 輸入:
3. logits_processor
:概率分布處理器
- 作用:對
lm_head
輸出的概率分布進行后處理,控制生成策略。 - 常用功能:
- 溫度調節(Temperature):降低溫度(
<1
)使輸出更確定,升高溫度(>1
)增加多樣性。 - Top-k/Top-p采樣:從概率最高的
k
個詞或累積概率達p
的詞中隨機選擇,平衡質量和多樣性。 - 重復懲罰(Repetition Penalty):抑制重復生成相同詞(如避免“西紅柿西紅柿”)。
- 溫度調節(Temperature):降低溫度(
- 輸入輸出形狀:
- 輸入:
(1, 128256)
(原始概率分布) - 輸出:
(1, 128256)
(處理后的概率分布)
- 輸入:
4. pooler
:池化層
- 作用:將整個序列的隱藏狀態壓縮為固定長度的向量,用于下游任務(如分類、相似度計算)。
- 技術細節:
- 默認取第一個Token(如
<BOS>
)的隱藏狀態作為全局表示。 - 或使用平均池化/最大池化,但Llama 3.1通常直接取
<BOS>
。
- 默認取第一個Token(如
- 輸入輸出形狀:
- 輸入:
(1, 10, 4096)
(全序列隱藏狀態) - 輸出:
(1, 4096)
(固定長度的全局向量)
- 輸入:
總結:組件協同工作流程
- 輸入處理:用戶輸入文本 → 模板化Prompt →
embed_tokens
→(1, 10, 4096)
- 特征提取:32層Decoder Block →
hidden_states
→(1, 10, 4096)
- 歸一化:
norm
→ 穩定輸出 - 生成預測:
lm_head
→(1, 128256)
概率分布logits_processor
→ 調整概率分布- 采樣生成下一個詞 → 更新
input_ids
- 迭代生成:重復步驟1-4,直到生成終止標記(
<EOS>
)或達到最大長度。 - 任務適配:
pooler
提取全局向量 → 用于分類、相似度等任務。
model
:像一個廚師,逐步處理食材(Token)并調整火候(注意力機制)。lm_head
:廚師的“味覺”,決定下一步該加什么調料(預測下一個詞)。logits_processor
:廚房的“規則制定者”,確保菜譜不重復且口味可控。pooler
:食客的“總結筆記”,用一句話概括整道菜的風味(全局語義)。