LLM筆記(九)KV緩存(2)

文章目錄

    • 1. 背景與動機
    • 2. 不使用 KV Cache 的情形
      • 2.1 矩陣形式展開
      • 2.2 計算復雜度
    • 3. 使用 KV Cache 的優化
      • 3.1 核心思想
      • 3.2 矩陣形式展開
      • 3.3 計算復雜度對比
    • 4. 總結
    • 5. GPT-2 中 KV 緩存的實現分析
      • 5.1 緩存的數據結構與類型
      • 5.2 在注意力機制 (`GPT2Attention`) 中使用緩存
      • 5.3 緩存的更新機制 (`Cache.update`)
      • 5.4 在模型整體 (`GPT2Model`) 的 `forward` 方法中處理
      • 5.5 因果掩碼 (Causal Mask) 與 KV 緩存的配合
      • 5.6 支持多種高效的注意力實現
      • 5.7 KV 緩存的完整工作流程 (自回歸生成)
        • 5.7.1 初始步驟 (t=0):
        • 5.7.2 后續步驟 (t > 0):
      • KV 緩存的顯著優勢

  1. 看圖學kv 很形象清楚
  2. gpt2源碼
  3. 分析transformer模型的參數量、計算量、中間激活、KV cache量化分析了緩存
  4. kv解讀

1. 背景與動機

在自回歸生成(autoregressive generation)任務中,Transformer 解碼器需要在每一步中根據前面已生成的所有 token 重新計算注意力(Attention),這會產生大量重復計算。引入 KV Cache(Key–Value Cache)后,能夠將已計算的鍵值對緩存下來,僅對新增的 Query 進行點乘與加權,從而大幅降低時間與算力開銷。

2. 不使用 KV Cache 的情形

2.1 矩陣形式展開

  • 第 1 步(生成第一個 token)

    Q 1 , K 1 , V 1 ∈ R 1 × d Q_1, K_1, V_1 \in \mathbb{R}^{1\times d} Q1?,K1?,V1?R1×d

    A t t e n t i o n 1 = s o f t m a x ( Q 1 K 1 ? d ) , V 1 Attention_1 = \mathrm{softmax}\Bigl(\frac{Q_1 K_1^\top}{\sqrt d}\Bigr),V_1 Attention1?=softmax(d ?Q1?K1???),V1?

  • 第 2 步(生成第二個 token)
    構造全序列的矩陣:
    image.png

    需重算完整注意力矩陣:

    A t t e n t i o n 1 : 2 = s o f t m a x ( Q 1 : 2 K 1 : 2 ? d ) , V 1 : 2 Attention_{1:2} = \mathrm{softmax}\Bigl(\frac{Q_{1:2}K_{1:2}^\top}{\sqrt d}\Bigr),V_{1:2} Attention1:2?=softmax(d ?Q1:2?K1:2???),V1:2?

    計算出一個 2 × 2 2\times 2 2×2 矩陣,但我們只取最后一行作為輸出。

  • 第 n 步

    Q 1 : n , K 1 : n , V 1 : n ∈ R n × d , A t t e n t i o n 1 : n = s o f t m a x ( Q 1 : n K 1 : n ? d ) , V 1 : n Q_{1:n},K_{1:n},V_{1:n}\in\mathbb{R}^{n\times d},\quad Attention_{1:n} = \mathrm{softmax}\Bigl(\tfrac{Q_{1:n}K_{1:n}^\top}{\sqrt d}\Bigr),V_{1:n} Q1:n?,K1:n?,V1:n?Rn×d,Attention1:n?=softmax(d ?Q1:n?K1:n???),V1:n?

    每步均重新構建并計算 n × n n\times n n×n 注意力矩陣。

2.2 計算復雜度

  • 注意力矩陣構建 O ( n 2 ? d ) O(n^2\cdot d) O(n2?d)

  • 整體推理階段:若生成總長度為 N N N,則總復雜度近似為

    ∑ n = 1 N O ( n 2 d ) ; = ; O ( N 3 d ) \sum_{n=1}^N O(n^2 d);=;O(N^3 d) n=1N?O(n2d);=;O(N3d),

    由于每步都做重復計算,效率極低。

3. 使用 KV Cache 的優化

3.1 核心思想

  • 緩存已計算的 K, V:對于前序列位置的鍵值對,只需計算一次并存儲。

  • 僅對新增 Query 進行點乘:第 n n n 步僅需計算 Q n Q_n Qn? 與所有緩存 K 的點乘,得到長度為 n n n 的注意力權重,再加權疊加對應的 V。

3.2 矩陣形式展開

  • 第 1 步:如前,無緩存,計算
    A t t e n t i o n 1 = s o f t m a x ( Q 1 K 1 ? / d ) , V 1 Attention_1 = \mathrm{softmax}(Q_1K_1^\top/\sqrt d),V_1 Attention1?=softmax(Q1?K1??/d ?),V1?.

  • 第 2 步

    • 新增 Q 2 ∈ R 1 × d Q_2\in\mathbb{R}^{1\times d} Q2?R1×d

    • 緩存矩陣已擴展為

      image.png

    • 只做一次 1 × 2 1\times 2 1×2 點乘:

      A t t e n t i o n 2 = s o f t m a x ( Q 2 K c a c h e ? d ) , V c a c h e Attention_2 = \mathrm{softmax}\Bigl(\tfrac{Q_2 K_{\mathrm{cache}}^\top}{\sqrt d}\Bigr),V_{\mathrm{cache}} Attention2?=softmax(d ?Q2?Kcache???),Vcache?,

      輸出即為所需的 1 × d 1\times d 1×d 向量。

  • 第 n 步

    K c a c h e ∈ R n × d , V c a c h e ∈ R n × d , A t t e n t i o n n = s o f t m a x ( Q n K c a c h e ? d ) , V c a c h e K_{\mathrm{cache}}\in\mathbb{R}^{n\times d},\quad V_{\mathrm{cache}}\in\mathbb{R}^{n\times d},\quad Attention_n = \mathrm{softmax}\Bigl(\tfrac{Q_n K_{\mathrm{cache}}^\top}{\sqrt d}\Bigr),V_{\mathrm{cache}} Kcache?Rn×d,Vcache?Rn×d,Attentionn?=softmax(d ?Qn?Kcache???),Vcache?.

3.3 計算復雜度對比

模式每步復雜度總體復雜度(生成長度 N N N
無 Cache O ( n 2 d ) O(n^2 d) O(n2d) O ( N 3 d ) O(N^3 d) O(N3d)
有 KV Cache O ( n d ) O(n d) O(nd) ∑ n = 1 N O ( n d ) = O ( N 2 d ) \displaystyle\sum_{n=1}^N O(n d)=O(N^2 d) n=1N?O(nd)=O(N2d)
  • 加速比:從二次方級別 O ( n 2 ) O(n^2) O(n2) 降到線性級別 O ( n ) O(n) O(n),對長序列提升顯著。

4. 總結

  1. 多頭注意力(Multi-Head)
    每個 head 獨立緩存自己的 K, V 矩陣,計算時分別點乘再拼接。總體計算與存儲線性可擴展。

  2. 緩存管理

    • 內存占用:緩存矩陣大小隨生成長度增長,應考慮清理過舊不再需要的序列(如 sliding window)。

    • Batch 推理:對多條序列并行生成時,可為每條序列維護獨立緩存,或統一按最大長度對齊。

  3. 硬件優化

    • 內存帶寬:KV Cache 減少重復內存載入,對帶寬友好;

    • 并行度:線性點乘更易與矩陣乘加(GEMM)指令級并行融合。

  4. 實踐中常見問題

    • Cache 不命中:若使用 prefix-tuning 等技術動態修改 key/value,需謹慎處理緩存一致性。
    • 數值穩定性:長序列高維 softmax 易出現梯度消失/爆炸,可結合溫度系數或分段歸一化。

5. GPT-2 中 KV 緩存的實現分析

GPT-2(以及許多其他基于 Transformer 的自回歸模型)在生成文本時,為了提高效率,會使用一種稱為 KV 緩存 (Key-Value Cache) 的機制。其核心思想是:在生成第 t 個 token 時,計算注意力所需的鍵 (Key) 和值 (Value) 向量可以部分來自于已經生成的 t-1 個 token。通過緩存這些歷史的 K 和 V 向量,可以避免在每一步生成時都對整個已生成序列重新進行昂貴的 K 和 V 計算。

5.1 緩存的數據結構與類型

Hugging Face Transformers 庫為 GPT-2 提供了靈活的緩存管理機制,主要通過 Cache 基類及其子類實現。

  • Cache (基類): 定義了緩存對象的基本接口,例如 update (更新緩存) 和 get_seq_length (獲取當前緩存的序列長度) 等方法。
  • DynamicCache:
    • 這是自回歸生成時最常用的緩存類型。
    • 它允許緩存的序列長度動態增長。當生成新的 token 時,新計算出的 K 和 V 向量會被追加到已有的緩存后面。
    • 不需要預先分配固定大小的內存,更加靈活,但可能在內存管理上有一些開銷。
  • StaticCache:
    • 在創建時就需要預先分配固定大小的內存空間來存儲 K 和 V 向量。
    • 適用于已知最大生成長度或需要更可控內存占用的場景。
    • 如果生成的序列長度超過了預分配的大小,可能會出錯或需要特殊處理。
  • EncoderDecoderCache:
    • 主要用于 Encoder-Decoder 架構的模型 (如 T5, BART)。
    • 它內部會分別管理編碼器-解碼器注意力(交叉注意力)的 KV 緩存和解碼器自注意力的 KV 緩存。
    • GPT-2 是一個僅解碼器 (Decoder-only) 模型,所以主要關注自注意力的緩存。
# 相關類的導入,展示了緩存工具的多樣性
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache

5.2 在注意力機制 (GPT2Attention) 中使用緩存

GPT2Attention 類的 forward 方法是 KV 緩存機制的核心應用點。

class GPT2Attention(nn.Module):  ...  def forward(  self,  hidden_states: Optional[Tuple[torch.FloatTensor]],  layer_past: Optional[Tuple[torch.Tensor]] = None, # 舊版本的緩存參數名  past_key_value: Optional[Cache] = None,           # 新版本的緩存對象  attention_mask: Optional[torch.FloatTensor] = None,  head_mask: Optional[torch.FloatTensor] = None,  use_cache: Optional[bool] = False,  output_attentions: Optional[bool] = False,  cache_position: Optional[torch.LongTensor] = None, # 指示新token在緩存中的位置  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:  # 1. 計算當前輸入 hidden_states 的 Q, K, V        # self.c_attn 是一個線性層,通常一次性計算出 Q, K, V 然后分割  query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)  # 2. 將 Q, K, V 重塑為多頭形式 (batch_size, num_heads, seq_len, head_dim)        query = self._split_heads(query, self.num_heads, self.head_dim)  key = self._split_heads(key, self.num_heads, self.head_dim)  value = self._split_heads(value, self.num_heads, self.head_dim)  # 3. KV 緩存處理  if past_key_value is not None:  # 如果是 EncoderDecoderCache,根據是否交叉注意力選擇正確的緩存  if isinstance(past_key_value, EncoderDecoderCache):  # ... (GPT-2 不直接使用此邏輯,但展示了其通用性)  pass  # 使用 cache_position 來更新緩存中的特定位置  cache_kwargs = {"cache_position": cache_position}  # 調用緩存對象的 update 方法  # key 和 value 是當前新計算的 K, V            # self.layer_idx 標識當前是哪一層的緩存  key, value = past_key_value.update(key, value, self.layer_idx, cache_kwargs)  # 此時的 key 和 value 包含了歷史信息和當前新計算的信息  # 4. 計算注意力權重 (Q @ K^T)        # ...        attn_weights = torch.matmul(query, key.transpose(-1, -2))  # ... 應用注意力掩碼 (causal mask, padding mask) ...  # 5. 計算注意力輸出 (attn_weights @ V)        attn_output = torch.matmul(attn_weights, value)  # ... 合并多頭,返回結果 ...  if use_cache:  # 如果使用緩存,則 present_key_value 就是更新后的 past_key_value            present_key_value = past_key_value  else:  present_key_value = None  return attn_output, present_key_value # 返回注意力的輸出和更新后的緩存

關鍵點解釋:

  • past_key_value (或 layer_past): 這是從上一個時間步或上一個調用傳遞過來的緩存對象。它包含了到目前為止所有先前 token 的 K 和 V 向量。
  • cache_position: 這是一個非常重要的參數,尤其是在使用了諸如 Flash Attention 2 等更高級的注意力實現時。它告訴緩存 update 方法以及注意力計算函數,新的 K 和 V 向量應該被放置在緩存張量的哪個位置。這對于正確地處理填充(padding)和動態序列長度至關重要。例如,如果當前輸入的是第 t 個 token(從0開始計數),cache_position 可能就是 t
  • self.layer_idx: Transformer 模型通常由多個相同的注意力層堆疊而成。每一層都有自己獨立的 KV 緩存。layer_idx 用于標識當前正在處理的是哪一層的緩存,確保數據被正確地存取。
  • use_cache: 控制是否使用和返回緩存。在訓練時通常為 False(除非進行特定類型的訓練,如 teacher forcing 的逐token訓練),在推理(生成)時為 True

5.3 緩存的更新機制 (Cache.update)

Cache 對象的 update 方法是實現緩存的核心。雖然具體的實現會因 DynamicCacheStaticCache 而異,但其基本邏輯是:

class DynamicCache(Cache):  def __init__(self):  self.key_cache: List[torch.Tensor] = [] # 每層一個 tensor        self.value_cache: List[torch.Tensor] = [] # 每層一個 tensor        self.seen_tokens = 0 # 已緩存的token數量  def update(  self,  key_states: torch.Tensor,    # 新計算的 key        value_states: torch.Tensor,  # 新計算的 value        layer_idx: int,              # 當前層索引  cache_kwargs: Optional[Dict[str, Any]] = None,  ) -> Tuple[torch.Tensor, torch.Tensor]:  # 獲取 cache_position        cache_position = cache_kwargs.get("cache_position")  # 如果是第一次更新這一層 (或緩存為空)  if layer_idx >= len(self.key_cache):  # 初始化該層的緩存張量  # ... 根據 key_states 和 value_states 的形狀以及預估的最大長度(或動態調整)  self.key_cache.append(torch.zeros_like(key_states_preallocated))  self.value_cache.append(torch.zeros_like(value_states_preallocated))  # 將新的 key_states 和 value_states 寫入到緩存的指定位置  # 對于 DynamicCache,通常是直接拼接或在預分配空間中按位置寫入  if cache_position is not None:  # 使用 cache_position 精確地更新緩存的特定部分  # 例如: self.key_cache[layer_idx][:, :, cache_position, :] = key_states            #       self.value_cache[layer_idx][:, :, cache_position, :] = value_states            # 這里的維度可能需要根據實際實現調整  # 重要的是理解 cache_position 的作用  # 例如,如果 key_states 的形狀是 (batch, num_heads, new_seq_len, head_dim)            # cache_position 的形狀可能是 (batch, new_seq_len) 或廣播的 (new_seq_len)            # 需要將 key_states 放置到 self.key_cache[layer_idx] 的正確"槽位"  # 對于自回歸,通常 new_seq_len = 1            self.key_cache[layer_idx].index_copy_(dim=2, index=cache_position, source=key_states)  self.value_cache[layer_idx].index_copy_(dim=2, index=cache_position, source=value_states)  # 更新已見過的token數量  self.seen_tokens = cache_position[-1] + 1 # 取最后一個新token的位置加1  else: # 舊的、不使用 cache_position 的邏輯(通常是簡單拼接)  self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)  self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)  self.seen_tokens += key_states.shape[2]  # 返回包含所有歷史信息(包括剛更新的)的 K 和 V 狀態  return self.key_cache[layer_idx], self.value_cache[layer_idx]

update 方法的關鍵職責:

  1. 接收當前新計算的 key_statesvalue_states
  2. 根據 layer_idx 找到對應層的緩存。
  3. (可選,但推薦)使用 cache_position 將新的 K, V 向量精確地放置到緩存張量的正確位置。這對于處理批處理中不同樣本有不同歷史長度的情況(例如,在束搜索beam search后或 speculative decoding 后),或者在有填充 token 時非常重要。
  4. 返回完整的、包含所有歷史信息和當前新信息的 K, V 向量,供后續的注意力計算使用。
  5. 更新內部狀態,如已緩存的 token 數量 (seen_tokens)。

5.4 在模型整體 (GPT2Model) 的 forward 方法中處理

GPT2Modelforward 方法負責協調整個模型的流程,包括緩存的初始化、傳遞和 cache_position 的計算。

class GPT2Model(GPT2PreTrainedModel):  def forward(  self,  input_ids: Optional[torch.LongTensor] = None,  past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, # 舊版緩存元組  attention_mask: Optional[torch.FloatTensor] = None,  # ...  use_cache: Optional[bool] = None,  output_attentions: Optional[bool] = None,  output_hidden_states: Optional[bool] = None,  return_dict: Optional[bool] = None,  cache_position: Optional[torch.LongTensor] = None,  ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:  # ... (處理輸入ID和嵌入) ...  inputs_embeds = self.wte(input_ids) # 詞嵌入  position_embeds = self.wpe(position_ids) # 位置嵌入  hidden_states = inputs_embeds + position_embeds  # 1. 緩存初始化和類型轉換  if use_cache:  if past_key_values is None: # 如果是第一次調用或沒有提供緩存  # 根據配置決定使用哪種緩存,通常是 DynamicCache                # 例如:self.config.cache_implementation == "dynamic"  past_key_values = DynamicCache()  elif not isinstance(past_key_values, Cache):  # 為了兼容舊的元組形式的緩存,將其轉換為新的 Cache 對象  past_key_values = DynamicCache.from_legacy_cache(past_key_values)  # else: past_key_values 保持為 None  # 2. 計算 cache_position        if cache_position is None: # 如果外部沒有提供 cache_position            # 獲取當前緩存中已有的 token 數量  past_seen_tokens = past_key_values.get_seq_length(self.config.num_hidden_layers) if past_key_values is not None else 0  # 當前輸入序列的長度  current_seq_length = inputs_embeds.shape[1]  # cache_position 從 past_seen_tokens 開始,長度為 current_seq_length            cache_position = torch.arange(  past_seen_tokens, past_seen_tokens + current_seq_length, device=inputs_embeds.device  )  # else: 使用外部傳入的 cache_position  # ... (準備注意力掩碼,考慮因果關系和緩存長度) ...  # 3. 逐層傳遞和更新緩存  all_hidden_states = () if output_hidden_states else None  all_self_attentions = () if output_attentions else None  # next_decoder_cache 用于收集下一輪的緩存 (如果 use_cache 為 True)        # 在新的 Cache 對象設計中,past_key_values 本身會被原地更新或返回更新后的版本  # 因此,這個 next_decoder_cache 可能不再是必需的,或者其角色由 past_key_values 自身承擔  for i, block in enumerate(self.h): # self.h 是 GPT2Block 的列表  # ...  # 將當前層的緩存 (如果存在) 和 cache_position 傳遞給 GPT2Block            # GPT2Block 內部會再將其傳遞給 GPT2Attention            layer_outputs = block(  hidden_states,  layer_past=None, # 舊參數,通常為None  attention_mask=extended_attention_mask,  head_mask=head_mask[i],  encoder_hidden_states=None,  encoder_attention_mask=None,  use_cache=use_cache,  output_attentions=output_attentions,  past_key_value=past_key_values, # 傳遞整個緩存對象  cache_position=cache_position,  )  hidden_states = layer_outputs[0] # 更新 hidden_states            # 如果 use_cache,block 會返回更新后的緩存,這里 past_key_values 已被更新  # (在 Cache 對象實現中,update 方法通常返回更新后的完整緩存狀態,  #  或者直接在對象內部修改,取決于具體實現)  # ... (處理輸出) ...  return BaseModelOutputWithPast(  last_hidden_state=hidden_states,  past_key_values=past_key_values if use_cache else None, # 返回更新后的緩存  hidden_states=all_hidden_states,  attentions=all_self_attentions,  )

5.5 因果掩碼 (Causal Mask) 與 KV 緩存的配合

在自回歸生成中,模型只能注意到當前 token 及其之前的所有 token,不能注意到未來的 token。這是通過因果掩碼實現的。當使用 KV 緩存時,因果掩碼的構建需要考慮到緩存中已有的 token 數量。

class GPT2Attention(_GPT2Attention):  def _update_causal_mask(  self,  attention_mask: torch.Tensor, # 原始的 attention_mask (可能包含 padding)        input_tensor: torch.Tensor,   # 當前輸入的 hidden_states        cache_position: torch.Tensor,  past_key_values: Cache,       # 當前的緩存對象  output_attentions: bool,  ):  # 獲取當前輸入的序列長度 (通常為1,在自回歸生成的每一步)  input_seq_length = input_tensor.shape[1]  # 獲取緩存中已有的序列長度  past_seen_tokens = past_key_values.get_seq_length(self.layer_idx)  # 總的上下文長度 = 緩存長度 + 當前輸入長度  total_context_length = past_seen_tokens + input_seq_length  # _prepare_4d_causal_attention_mask_with_cache_position 會生成一個正確的掩碼  # 這個掩碼會確保:  # 1. 查詢 Q (來自當前輸入) 只能注意到鍵 K (來自緩存+當前輸入) 中對應位置及之前的部分。  # 2. 處理好 padding (如果 attention_mask 中有指示)。  # 形狀通常是 (batch_size, 1, query_length, key_length)        # 其中 query_length 是當前輸入的長度 (如1)  # key_length 是總的上下文長度 (past_seen_tokens + input_seq_length)        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(  attention_mask,  input_shape=(input_tensor.shape[0], input_seq_length), # 當前輸入的形狀  target_length=total_context_length, # K, V 的總長度  dtype=input_tensor.dtype,  cache_position=cache_position, # 關鍵!用于確定當前 Q 在 K,V 序列中的相對位置  )  return causal_mask

_prepare_4d_causal_attention_mask_with_cache_position 這個輔助函數會創建一個上三角矩陣(或類似結構),其中未來的位置會被掩蓋掉(例如,設置為一個非常小的負數,以便 softmax 后變為0)。cache_position 在這里的作用是,確保即使當前查詢 Q 的序列長度很短(例如為1),它在與歷史的 K, V 進行比較時,依然能正確地只關注到歷史和當前 K, V 中該 Q 之前的部分。

5.6 支持多種高效的注意力實現

Hugging Face Transformers 庫允許 GPT-2(以及其他模型)利用更高效的注意力后端實現,例如:

  • eager: PyTorch 的標準、原生注意力實現。
  • sdpa (Scaled Dot Product Attention): PyTorch 2.0 引入的高度優化的注意力函數 torch.nn.functional.scaled_dot_product_attention。它通常比 eager模式更快,內存效率也更高,并且可以自動選擇最優的底層實現(如 FlashAttention 或 memory-efficient attention)。
  • flash_attention_2: 直接集成 FlashAttention v2 庫。這是一種專門為現代 GPU 設計的、IO 感知的精確注意力算法,速度非常快,內存占用小。

KV 緩存機制的設計需要與這些高效實現兼容。例如,torch.nn.functional.scaled_dot_product_attention 和 FlashAttention 都支持直接傳入包含歷史和當前信息的完整 K, V 張量。cache_position 在這里尤為重要,因為它可以幫助這些高效后端理解哪些部分是新的,哪些是舊的,以及如何正確應用因果掩碼。

# 在 GPT2Attention 的 forward 方法中
self.config._attn_implementation 存儲了選擇的注意力實現方式 ("eager", "sdpa", "flash_attention_2")  ... (計算 query, key, value) ...  
... (更新 key, value 使用 past_key_value 和 cache_position) ...  
此時 key 和 value 是拼接/更新后的完整 K, V  if self.config._attn_implementation == "sdpa":  # 使用 PyTorch SDPA    # is_causal=True 會自動應用因果掩碼  # attn_mask 可能需要根據 SDPA 的要求進行調整  attn_output = torch.nn.functional.scaled_dot_product_attention(  query, key, value, attn_mask=adjusted_attn_mask, dropout_p=self.attn_dropout.p, is_causal=True  )  
elif self.config._attn_implementation == "flash_attention_2":  # from flash_attn import flash_attn_func  # 可能需要對 query, key, value 的形狀或數據類型進行調整以適應 flash_attn_func    # causal=True 會應用因果掩碼  attn_output = flash_attn_func(  query.transpose(1, 2), # FlashAttention 可能期望 (batch, seq_len, num_heads, head_dim)        torch.stack((key.transpose(1,2), value.transpose(1,2)), dim=0), # K, V 打包  dropout_p=self.attn_dropout.p,  causal=True,  )  
else: # "eager"  # ... (標準的 PyTorch matmul 實現) ...

5.7 KV 緩存的完整工作流程 (自回歸生成)

5.7.1 初始步驟 (t=0):
  • 用戶提供初始的 input_ids (例如,一個 [BOS] token 或者一段提示文本)。
  • past_key_valuesNone
  • 模型 forward 方法被調用。
  • use_cache 通常為 True
  • 初始化一個空的 DynamicCache 對象作為 past_key_values
  • 計算 cache_position,此時它通常是從 0 開始的序列 (e.g., torch.arange(0, initial_input_len)).
  • 對于每一注意力層:
    • 計算當前 input_ids 對應的 Q, K, V。
    • 由于 past_key_values 剛被初始化(內部緩存為空),update 方法會將這些新計算的 K, V 存入緩存的第一批位置。
    • 使用這些 K, V (此時它們只包含當前輸入的信息) 和 Q 進行注意力計算。
  • 模型輸出 logits (用于預測下一個 token) 和更新后的 past_key_values (現在包含了第一個輸入的 K,V)。
5.7.2 后續步驟 (t > 0):
  • 從上一步的 logits 中采樣得到新的 input_ids (通常是一個新的 token)。
  • 將上一步返回的 past_key_values (包含了 t-1 步及之前所有 token 的 K,V) 作為輸入傳遞給模型。
  • 模型 forward 方法再次被調用。
  • use_cacheTrue
  • 計算 cache_position。此時,past_key_values.get_seq_length() 會返回已緩存的 token 數量 (例如 t)。新的 cache_position 會是 torch.tensor([t]),表示這個新 token 是序列中的第 t+1 個元素 (如果從1開始計數的話,或者第 t 個位置如果從0開始計數)。
  • 對于每一注意力層:
    • 只對新輸入的單個 token 計算其 Q, K, V (這些是"小"張量)。
    • 調用 past_key_values.update(new_key, new_value, layer_idx, cache_kwargs={"cache_position": cache_position})
      • update 方法會將這個新 token 的 K, V 追加到對應層緩存中已有的 K, V 之后,并返回完整的 K (包含所有 t+1 個 token) 和完整的 V。
    • 使用新 token 的 Q 和完整的 (歷史+當前) K, V 計算注意力。因果掩碼會確保 Q 只注意到 K,V 中它自己及之前的部分。
  • 模型輸出 logits 和再次更新后的 past_key_values

這個過程一直重復,直到生成了 [EOS] token 或達到最大長度。

KV 緩存的顯著優勢

  1. 避免冗余計算: 這是最核心的優勢。在生成第 t 個 token 時,前 t-1 個 token 的 K 和 V 向量已經計算并存儲在緩存中,無需重新計算。注意力機制只需要為新的當前 token 計算 K 和 V,然后將它們與緩存中的歷史 K,V 結合起來。
  2. 顯著提高生成速度: 尤其對于長序列生成,每次迭代的計算量從 O(N2)(N為當前總長度)降低到接近 O(N)(主要是新 Q 與歷史 K,V 的交互),因為主要計算瓶頸(K,V的生成)只針對新token進行。
  3. 支持高效的批處理生成: 雖然每個樣本在批次中可能有不同的已生成長度(特別是在使用可變長度輸入或某些采樣策略時),通過 cache_position 和可能的填充/掩碼機制,KV 緩存可以有效地處理這種情況。
  4. 與先進注意力實現的兼容性: 如前所述,KV 緩存的設計與 SDPA、FlashAttention 等高效后端良好集成,使得模型可以同時享受到算法優化和底層硬件加速的好處。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/83818.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/83818.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/83818.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

2025年滲透測試面試題總結-各廠商二面試題02(題目+回答)

網絡安全領域各種資源,學習文檔,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各種好玩的項目及好用的工具,歡迎關注。 目錄 各廠商二面試題02 模塊六:基礎技術擴展 1. HTTP請求方式 2. 域名解析工具與技術 3. Web十…

專業漏洞掃描機構如何助力企業保障安全并提升競爭力?

在這個信息化的當下,專業漏洞掃描機構扮演著至關重要的角色。他們運用專業的技術和手段,對各種軟件和系統進行細致的漏洞檢測,確保其安全可靠,同時幫助企業提高產品質量和市場競爭力。 服務項目 我們專注于向客戶供應周到詳盡的…

卷積神經網絡中的二維卷積與三維卷積詳解

【內容摘要】 本文聚焦卷積神經網絡中的二維卷積與三維卷積,詳細解析兩者的區別、操作原理及應用場景,涵蓋二維/三維卷積操作示意圖、多通道輸入處理方式,以及RGB圖像不采用三維卷積的原因,助力理解不同卷積類型的特性與適用場景。…

Oracle 的 ASSM 表空間

Oracle 的 ASSM(Automatic Segment Space Management)表空間 是一種自動管理段空間的技術,通過位圖(Bitmap)機制跟蹤數據塊的使用情況,替代傳統的手動管理(MSSM,即 Freelist 管理&am…

螞蟻金服大數據面經及參考答案

Java 如何保證跨平臺性?請從 JVM 底層適配機制及向上提供的統一接口角度說明 Java 的跨平臺性是其核心優勢之一,依賴于 JVM(Java Virtual Machine)的底層適配機制和向上層提供的統一接口。從底層來看,JVM 針對不同操作系統和硬件平臺進行了定制化實現,負責解析和執行 Ja…

P1009 [NOIP 1998 普及組] 階乘之和

題目描述 用高精度計算出 S1!2!3!?n!(n≤50)。 其中 ! 表示階乘,定義為 n!n(n?1)(n?2)?1。例如,5!54321120。 輸入格式 一個正整數 n。 輸出格式 一個正整數 S,表示計算結果。 輸入輸出樣例 輸入 3 輸出…

Python 的 os 庫常見使用方法(操作目錄及文件)

前言: os 模塊是 Python 標準庫中用于與操作系統交互的核心模塊,提供了許多操作文件和目錄的功能。以下是常見的使用方法: 1. 目錄操作 方法功能說明示例os.getcwd()獲取當前工作目錄print(os.getcwd())os.chdir(path)切換當前工作目錄os.ch…

vue3 el-table實現字段可編輯

在Vue 3中,如果你想讓el-table(Element Plus的表格組件)的字段可編輯,你可以通過以下方式來實現: 使用cell-mouse-enter和cell-mouse-leave事件動態顯示編輯圖標或控件 你可以在鼠標進入單元格時顯示一個編輯圖標或輸…

基于shardingsphere的分庫分表方案

一、準備docker容器 啟動兩個mysql的docker容器 docker run -v /root/mysql_volume/data:/var/lib/mysql -v /root/mysql_volume/conf:/etc/mysql/conf.d -v /root/mysql_volume/my.cnf:/etc/my.cnf -p 3306:3306 --name mysql --restartalways --privilegedtrue -e MYSQL_RO…

SearxNG本地搜索引擎

SearxNG 是一個強大、開源的 元搜索引擎(meta search engine),它不會存儲用戶信息,注重隱私保護,并支持從多個搜索引擎聚合結果,用戶可以自建部署,打造一個無廣告、可定制的搜索平臺。 ?? 什么是 SearxNG? SearxNG 是 Searx 的一個積極維護的分支(fork),意在改進…

Vue3.5 企業級管理系統實戰(十九):菜單管理

篇幅原因,本節先探討菜單管理頁面增刪改查相關功能,角色菜單,菜單權限,動態菜單等內容放在后面。 1 菜單 api 在 src/api/menu.ts 中添加菜單 api,代碼如下: //src/api/menu.ts import service from &qu…

【android bluetooth 協議分析 01】【HCI 層介紹 8】【ReadLocalVersionInformation命令介紹】

1. HCI_Read_Local_Version_Information 命令介紹 1. 功能(Description) HCI_Read_Local_Version_Information 命令用于讀取本地 Bluetooth Controller 的版本信息,包括 HCI 和 LMP 層的版本,以及廠商 ID 和子版本號。 這類信息用…

React底層架構深度解析:從虛擬DOM到Fiber的演進之路

一、虛擬DOM:性能優化的基石 1.1 核心工作原理 React通過JSX語法將組件轉換為輕量級JavaScript對象(即虛擬DOM),而非直接操作真實DOM。這一過程由React.createElement()實現,其結構包含元素類型、屬性和子節點等信息&a…

從AlphaGo到ChatGPT:AI技術如何一步步改變世界?

從AlphaGo到ChatGPT:AI技術如何一步步改變世界? 這里給大家分享一個人工智能學習網站。點擊跳轉到網站。 https://www.captainbed.cn/ccc 前言 在科技發展的歷史長河中,人工智能(AI)技術無疑是最為璀璨的明珠之一。從…

關于在Unity項目中使用Post Processing插件打包到web端出現的問題

關于在Unity項目中使用Post Processing插件打包到web端出現的問題 解決方法:是不激活攝像機上的Post Processing有關組件,拉低場景中的Directional Light平行光的強度進行web端打包。 (烘焙燈光時是可以激活。) web端支持這個Pos…

MySQL - 如何突破單庫性能瓶頸

數據庫服務器硬件優化 我們來看看對數據庫所在的服務器是如何進行優化的,服務器是數據庫的宿主,其性能直接影響了數據庫的性能,所以服務器的優化也是數據庫優化的第一步。 數據庫服務器通常是從 CPU、內存、磁盤三個角度進行硬件優化的&…

用 CodeBuddy 搭建「MiniGoal 小目標打卡器」:一次流暢的 UniApp 開發體驗

我正在參加CodeBuddy「首席試玩官」內容創作大賽,本文所使用的 CodeBuddy 免費下載鏈接:騰訊云代碼助手 CodeBuddy - AI 時代的智能編程伙伴 在日常生活中,我們總是希望能夠堅持一些小習慣,比如每天鍛煉十分鐘、讀一頁書、早睡十分…

OpenCV 環境搭建與概述

// //OpenCV-4.11.0 C VS2019 // 一、OpenCV學習路線 1、入門: OpenCV圖像讀寫、視頻讀寫、基本像素處理、基本卷積處理、基本C開發知識。 2、初級: OpenCV自定義卷積操作、圖像梯度、邊緣提取、二值分析、視頻分析、形態學處理、幾何變換與透視變換。 3、中級: 角點查找、BL…

如何快速更換電腦瀏覽器ip:教程與注意事項

無論是為了訪問地域限制內容、保護隱私,還是解決網絡問題,快速更換瀏覽器IP地址的需求日益增多。以下是快速更換電腦瀏覽器IP地址的幾種常用方法及注意事項,結合了多種場景下的解決方案: 一、快速更換瀏覽器IP的方法 1. 代理服務…

【kafka】kafka概念,使用技巧go示例

1. Kafka基礎概念 1.1 什么是Kafka? Kafka是一個分布式流處理平臺,用于構建實時數據管道和流式應用。核心特點: 高吞吐量:每秒可處理百萬級消息持久化存儲:消息按Topic分區存儲在磁盤分布式架構:支持水平…