一、LLM推理的核心過程:自回歸生成
LLM(如DeepSeek、ChatGPT、LLaMA系列等)的推理本質是自回歸生成:從初始輸入(如[CLS]
或用戶prompt)開始,逐token預測下一個詞,直到生成結束符(如[EOS]
)。其核心分為兩個階段:
1. Initialization階段(初始化)
- 目標:準備第一個token的生成條件。
- 關鍵步驟:
- 輸入編碼:將初始prompt轉換為token序列(如
[CLS]你好
),嵌入為向量x_0
。 - 初始隱藏狀態:通過Transformer的編碼器(或直接使用預訓練參數)生成第一層的隱藏狀態
h_0
。 - KV Cache初始化:為每一層的每個注意力頭創建空的Key/Value緩存(形狀:
[batch, heads, seq_len, head_dim]
)。此時seq_len=0
,因為尚無歷史token。
- 輸入編碼:將初始prompt轉換為token序列(如
示例:生成首詞“今天”時,輸入為[CLS]
,初始化后僅計算第一層的h_0
,KV Cache為空。
在LLM推理中,Initialization階段(初始化階段)又稱“預填充階段”(Prefill Stage)。這一命名源于其核心功能:為后續的逐token生成預填充(Prefill)KV Cache和初始隱藏狀態。
工程實現
Hugging Face的transformers
庫、NVIDIA的FasterTransformer均采用prefill
和generation
區分這兩個階段。例如:
# 偽代碼:Hugging Face生成邏輯
outputs = model.prefill(prompt) # 預填充KV Cache(Initialization)
for _ in range(max_new_tokens):outputs = model.generate_step(outputs) # 解碼階段,逐token生成
術語對比:Initialization vs Prefill
場景 | 常用術語 | 含義側重 |
---|---|---|
學術描述 | Initialization | 強調“初始化隱藏狀態和緩存” |
工程實踐 | Prefill | 強調“預填充固定長度的輸入” |
用戶視角 | 輸入處理階段 | 對應“用戶輸入的prompt處理” |
本質是同一階段,但“Prefill”更直觀反映了其“為生成提前準備歷史KV”的工程目標。
2. Decoding階段(解碼)
- 目標:逐token生成,每步復用歷史計算結果。
- 核心邏輯(以生成第
t
個token為例):- 當前token處理:將第
t-1
步生成的token嵌入x_t
,與前一步隱藏狀態拼接,輸入Transformer層。 - 注意力計算優化:
- 查詢(Query):僅計算當前token的Query向量
Q_t
(因為只關注當前位置)。 - 鍵值(Key/Value):復用KV Cache中的歷史Key/Value,并追加當前token的Key_t、Value_t。
- 注意力得分:計算
Q_t
與所有歷史Key的相似度(僅需一次矩陣乘法,而非重復全量計算)。
- 查詢(Query):僅計算當前token的Query向量
- 更新KV Cache:將當前層的Key_t、Value_t追加到緩存中(
seq_len += 1
)。 - 生成概率:通過LM頭輸出第
t
個token的概率分布,選擇下一詞(貪心/采樣)。
- 當前token處理:將第
3. 舉個栗子🌰
- 輸入:用戶prompt“請寫一首詩:”(4個token)。
- Prefill階段:
- 計算這4個token的所有層Key/Value,填充到KV Cache(此時緩存長度=4)。
- 生成第一個待擴展的隱藏狀態(對應第4個token的輸出)。
- Decoding階段:
逐句生成詩句,每步:- 計算當前token的Q(僅1個token)。
- 復用Prefill的4個KV + 之前生成的KV,計算注意力。
- 追加當前token的KV到緩存(緩存長度逐步增加到4+N)。
通過“預填充”,避免了每次生成新token時重復計算prompt的KV,這正是LLM實現高效推理的關鍵優化之一。
二、原始Transformer的效率瓶頸:O(n2)的重復計算
- 時間復雜度:訓練時并行計算所有token的注意力(O(n2)),但推理時需自回歸生成,每步需重新計算所有歷史token的Key/Value,導致總復雜度為O(n3)(n為序列長度)。
- 空間復雜度:每次推理需保存所有中間層的Key/Value,內存占用隨n線性增長,長文本(如n=4k)時顯存爆炸。
- 現實痛點:生成1000字的文章需重復計算百萬次注意力,傳統Transformer無法支持實時交互。
三、KV Cache:用空間換時間的核心優化
1. 方法本質
緩存歷史層的Key/Value,避免重復計算。每個Transformer層維護獨立的KV Cache,存儲該層所有已生成token的Key/Value向量。
2. 具體實現步驟(以單batch為例)
-
初始化緩存(t=0):
- 每層創建空緩存:
K_cache = []
,V_cache = []
(形狀:[num_layers, heads, 0, head_dim]
)。
- 每層創建空緩存:
-
第t步生成(t≥1):
- 前向傳播:輸入當前token嵌入,通過Transformer層計算當前層的
Q_t, K_t, V_t
。 - 拼接緩存:
K_cache[t_layer] = torch.cat([K_cache[t_layer], K_t], dim=2) # 在seq_len維度追加 V_cache[t_layer] = torch.cat([V_cache[t_layer], V_t], dim=2)
- 注意力計算:
attn_scores = Q_t @ K_cache[t_layer].transpose(-2, -1) # Q_t: [1, heads, 1, d], K_cache: [1, heads, t, d] attn_probs = softmax(attn_scores / sqrt(d)) @ V_cache[t_layer] # 僅需O(t)計算
- 更新隱藏狀態:將注意力輸出傳入下一層,直到LM頭生成token。
- 前向傳播:輸入當前token嵌入,通過Transformer層計算當前層的
-
循環:重復步驟2,直到生成
[EOS]
或達到最大長度。
3. 優化效果
- 時間:每步注意力從O(n2)→O(n),總復雜度O(n2)(接近線性)。
- 空間:緩存占用O(n)(每層存儲歷史K/V),但避免了重復計算的中間變量,實際顯存節省50%+。
- 典型案例:LLaMA-2 70B在4k序列長度下,KV Cache使推理速度提升4倍(NVIDIA官方數據)。
四、延伸:KV Cache的局限性與改進
- 顯存瓶頸:長上下文(如100k token)的KV Cache占用巨大(每層約4k token×4byte×2(KV)≈32KB,64層×100k≈2GB)。
- 優化方向:
- 分頁緩存(Paged Attention):NVIDIA提出,用非連續內存存儲KV,減少碎片化(2023年突破)。
- 動態緩存:僅保留最近相關token的KV(如檢索增強LLM)。
KV Cache是LLM落地的基石,其設計思想(復用歷史計算)貫穿現代推理優化(如FlashAttention、QLoRA),最終實現了從“實驗室模型”到“實時對話”的跨越。