1. KV Cache的定義與核心原理
KV Cache(Key-Value Cache)是一種在Transformer架構的大模型推理階段使用的優化技術,通過緩存自注意力機制中的鍵(Key)和值(Value)矩陣,避免重復計算,從而顯著提升推理效率。
原理:
-
自注意力機制:在Transformer中,注意力計算基于公式:
Attention ( Q , K , V ) = softmax ( Q K ? d k ) V = ∑ i = 1 n w i v i (加權求和形式) \begin{split} \text{Attention}(Q, K, V) &= \text{softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right) V \\ &= \sum_{i=1}^n w_i v_i \quad \text{(加權求和形式)} \end{split} Attention(Q,K,V)?=softmax(dk??QK??)V=i=1∑n?wi?vi?(加權求和形式)?
其中,Q(Query)、K(Key)、V(Value)由輸入序列線性變換得到。 -
緩存機制:在生成式任務(如文本生成)中,模型以自回歸方式逐個生成token。首次推理時,計算所有輸入token的K和V并緩存;后續生成時,僅需為新token計算Q,并從緩存中讀取歷史K和V進行注意力計算。
-
復雜度優化:傳統方法的計算復雜度為O(n2),而KV Cache將后續生成的復雜度降為O(n),避免重復計算歷史token的K和V。
2. KV Cache的核心作用
-
加速推理:通過復用緩存的K和V,減少矩陣計算量,提升生成速度。例如,某聊天機器人應用響應時間從0.5秒縮短至0.2秒。
-
降低資源消耗:顯存占用減少約30%-50%(例如移動端模型從1GB降至0.6GB),支持在資源受限設備上部署大模型。
-
支持長文本生成:緩存機制使推理耗時不再隨文本長度線性增長,可穩定處理長序列(如1024 token以上)。
-
保持模型性能:僅優化計算流程,不影響輸出質量。
3. 技術實現與優化策略
實現方式:
-
數據結構
- KV Cache以張量形式存儲,Key Cache和Value Cache的形狀分別為
(batch_size, num_heads, seq_len, k_dim)
和(batch_size, num_heads, seq_len, v_dim)
。
- KV Cache以張量形式存儲,Key Cache和Value Cache的形狀分別為
-
兩階段推理:
- 初始化階段:計算初始輸入的所有K和V,存入緩存。
- 迭代階段:僅計算新token的Q,結合緩存中的K和V生成輸出,并更新緩存。
? 代碼示例(Hugging Face Transformers):設置model.generate(use_cache=True)
即可啟用KV Cache。
優化策略:
-
稀疏化(Sparse):僅緩存部分重要K和V,減少顯存占用。
-
量化(Quantization):將K和V矩陣從FP32轉為INT8/INT4,降低存儲需求。
共享機制(MQA/GQA):
-
Multi-Query Attention (MQA):所有注意力頭共享同一組K和V,顯存占用降低至1/頭數。
-
Grouped-Query Attention (GQA):將頭分組,組內共享K和V,平衡性能和顯存。
4. 挑戰與局限性
-
顯存壓力:隨著序列長度增加,緩存占用顯存線性增長(如1024 token占用約1GB顯存),可能引發OOM(內存溢出)。
-
冷啟動問題:首次推理仍需完整計算K和V,無法完全避免初始延遲。
5、python實現
import torch
import torch.nn as nn# 超參數
d_model = 4
n_heads = 1
seq_len = 3
batch_size = 3# 初始化參數(兼容多頭形式)
Wq = nn.Linear(d_model, d_model, bias=False)
Wk = nn.Linear(d_model, d_model, bias=False)
Wv = nn.Linear(d_model, d_model, bias=False)# 生成模擬輸入(整個序列一次性輸入)
input_sequence = torch.randn(batch_size, seq_len, d_model) # [B, L, D]# 初始化 KV 緩存(兼容多頭格式)
kv_cache = {"keys": torch.empty(batch_size, 0, n_heads, d_model // n_heads), # [B, T, H, D/H]"values": torch.empty(batch_size, 0, n_heads, d_model // n_heads)
}# 因果掩碼預先生成(覆蓋最大序列長度)
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() # [L, L]'''
本循環是將整句話中的token一個一個輸入,并更新KV緩存;
所以無需顯示的因果掩碼,因為因果掩碼只用于計算注意力權重時,而計算注意力權重時,KV緩存中的key和value已經包含了因果掩碼的信息。'''for step in range(seq_len):# 1. 獲取當前時間步的輸入(整個批次)current_token = input_sequence[:, step, :] # [B, 1, D]# 2. 計算當前時間步的 Q/K/V(保持三維結構)q = Wq(current_token) # [B, 1, D]k = Wk(current_token) # [B, 1, D]v = Wv(current_token) # [B, 1, D]# 3. 調整維度以兼容多頭格式(關鍵修改點)def reshape_for_multihead(x):return x.view(batch_size, 1, n_heads, d_model // n_heads).transpose(1, 2) # [B, H, 1, D/H]# 4. 更新 KV 緩存(增加時間步維度)kv_cache["keys"] = torch.cat([kv_cache["keys"], reshape_for_multihead(k).transpose(1, 2) # [B, T+1, H, D/H]], dim=1)kv_cache["values"] = torch.cat([kv_cache["values"],reshape_for_multihead(v).transpose(1, 2) # [B, T+1, H, D/H]], dim=1)# 5. 多頭注意力計算(支持批量處理)q_multi = reshape_for_multihead(q) # [B, H, 1, D/H]k_multi = kv_cache["keys"].transpose(1, 2) # [B, H, T+1, D/H]print("q_multi shape:", q_multi.shape)print("k_multi shape:", k_multi.shape)# 6. 計算注意力分數(帶因果掩碼)attn_scores = torch.matmul(q_multi, k_multi.transpose(-2, -1)) / (d_model ** 0.5)print("attn_scores shape:", attn_scores.shape)# attn_scores = attn_scores.masked_fill(causal_mask[:step+1, :step+1], float('-inf'))# print("attn_scores shape:", attn_scores.shape)# 7. 注意力權重計算attn_weights = torch.softmax(attn_scores, dim=-1) # [B, H, 1, T+1]# 8. 加權求和output = torch.matmul(attn_weights, kv_cache["values"].transpose(1, 2)) # [B, H, 1, D/H]# 9. 合并多頭輸出output = output.contiguous().view(batch_size, 1, d_model) # [B, 1, D]print(f"Step {step} 輸出:", output.shape)