大語言模型加速技術之KV Cache

大語言模型加速技術之KV Cache

  • Why we need KV Cache ?
  • Self-Attention Without Cache
  • Self-Attention With Cache
  • Huggingface 官方代碼實現

Why we need KV Cache ?

生成式generative模型的推理過程很有特點,我們給一個輸入文本,模型會輸出一個回答(長度為N),其實該過程中執行了N次推理過程。即GPT類模型一次推理只輸出一個token,輸出token會與輸入tokens 拼接在一起,然后作為下一次推理的輸入,這樣不斷反復直到遇到終止符。

如上描述是我們通常認知的GPT推理過程。代碼描述如下:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizerdef main():# 加載模型和 tokenizermodel = GPT2LMHeadModel.from_pretrained("gpt2").eval()tokenizer = GPT2Tokenizer.from_pretrained("gpt2")# 初始輸入in_text = "Open AI is a"in_tokens = torch.tensor(tokenizer.encode(in_text)).unsqueeze(0)  # [1, seq_len]token_eos = torch.tensor([198])  # line break symbolout_token = Nonei = 0with torch.no_grad():while out_token != token_eos:outputs = model(in_tokens)logits = outputs.logitsout_token = torch.argmax(logits[0, -1, :], dim=-1, keepdim=True).unsqueeze(0)  # [1, 1]in_tokens = torch.cat((in_tokens, out_token), dim=1)text = tokenizer.decode(in_tokens[0])print(f'step {i} input: {text}', flush=True)i += 1out_text = tokenizer.decode(in_tokens[0])print(f'\nInput: {in_text}')print(f'Output: {out_text}')if __name__ == "__main__":main()

輸出:

step 0 input: Open AI is a new
step 1 input: Open AI is a new way
step 2 input: Open AI is a new way to
step 3 input: Open AI is a new way to build
step 4 input: Open AI is a new way to build AI
step 5 input: Open AI is a new way to build AI that
step 6 input: Open AI is a new way to build AI that is
step 7 input: Open AI is a new way to build AI that is more
step 8 input: Open AI is a new way to build AI that is more efficient
step 9 input: Open AI is a new way to build AI that is more efficient and
step 10 input: Open AI is a new way to build AI that is more efficient and more
step 11 input: Open AI is a new way to build AI that is more efficient and more efficient
step 12 input: Open AI is a new way to build AI that is more efficient and more efficient than
step 13 input: Open AI is a new way to build AI that is more efficient and more efficient than traditional
step 14 input: Open AI is a new way to build AI that is more efficient and more efficient than traditional AI
step 15 input: Open AI is a new way to build AI that is more efficient and more efficient than traditional AI.
step 16 input: Open AI is a new way to build AI that is more efficient and more efficient than traditional AI.Input: Open AI is a
Output: Open AI is a new way to build AI that is more efficient and more efficient than traditional AI.

在上面的推理過程中,每 step 內,輸入一個 token序列,經過Embedding層將輸入token序列變為一個三維張量 [b, s, h],經過一通計算,最后經 logits 層將計算結果映射至詞表空間,輸出張量維度為 [b, s, vocab_size]。

當前輪輸出token與輸入tokens拼接,并作為下一輪的輸入tokens,反復多次。可以看出第 i+1 輪輸入數據只比第 i 輪輸入數據新增了一個 token,其他全部相同!

因此第 i+1 輪推理時必然包含了第 i 輪的部分計算。KV Cache 的出發點就在這里,緩存當前輪可重復利用的計算結果,下一輪計算時直接讀取緩存結果。

上面所舉例子并沒有使用KV Cache進行推理,請注意。

Self-Attention Without Cache

下圖給出了無 Cache 情況下,類GPT式生成式模型進行推理的過程:

在這里插入圖片描述

這種方式的問題是: 每生成一個 token,就要重新計算所有之前 token 的 Q/K/V + Attention + FFN

Self-Attention With Cache

下圖給出了有 Cache 情況下,類GPT式生成式模型進行推理的過程:

在這里插入圖片描述

Huggingface 官方代碼實現

本節將根據 Huggingface 官方代碼實現進行 KV Cache 實現講解 (只展示核心代碼,移除了大量與本文無關的邏輯)。

官方代碼鏈接: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py

下面將給出使用了 KV Cache 進行推理的代碼:

import torch
from transformers import GPT2Tokenizer, GPT2Config
from modeling_gpt2 import GPT2LMHeadModel  # copy from huggingface , 刪除了大量無關代碼def generate_text(model, tokenizer, prompt, max_new_tokens=50, eos_token_id=198):model.eval()input_ids = tokenizer.encode(prompt, return_tensors="pt")past_key_values = Noneoutput_ids = input_ids.clone()with torch.no_grad():for step in range(max_new_tokens):outputs = model(input_ids=input_ids,past_key_values=past_key_values,use_cache=True)logits = outputs.logitspast_key_values = outputs.past_key_valuesnext_token_logits = logits[:, -1, :]next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)output_ids = torch.cat([output_ids, next_token], dim=-1)if next_token.item() == eos_token_id:breakinput_ids = next_token  # 采用KV Cache后,推理過程修改的關鍵: 下一步只送入新 tokenprint(f"step {step}: {tokenizer.decode(output_ids[0])}", flush=True)return tokenizer.decode(output_ids[0])def main():config = GPT2Config()tokenizer = GPT2Tokenizer.from_pretrained("gpt2")model = GPT2LMHeadModel(config)prompt = "Once upon a time"output = generate_text(model, tokenizer, prompt)print("\nFinal output:")print(output)if __name__ == "__main__":main()

KV Cache 的引入是為了加速自回歸模型的推理速度,具體體現在:

  1. 每輪推理時,只需要計算當前輪新增 token 的 Q/K/V,而不需要重新計算所有之前 token 的 Q/K/V。

  2. 緩存當前輪計算結果,下一輪推理時直接讀取緩存結果。

在首輪推理的過程中,我們傳入的是 promt 提示詞列表,并且 KV Cache 此時為空,還未進行初始化。因此首輪推理過程需要完成 promt 提示詞列表的 keys 和 values 的緩存;由于 GPT2 由多層 GPT2Block 堆疊而成,而每一層 GPT2Block 都有一個 GPT2Attention 模塊, 因此 KV Cache 需要準備好每一層 GPT2Attention 模塊的 keys 和 values 緩存 (分層Cache - legacy_cache)。

class GPT2Model(GPT2PreTrainedModel):def forward(self,input_ids=None,past_key_values=None, cache_position=None,attention_mask=None,position_ids=None,head_mask=None,use_cache=None,):          return_legacy_cache = Falseif use_cache:# 1. 首輪推理,先進行 Legacy Cache 初始化if past_key_values is None:return_legacy_cache = Truepast_key_values = DynamicCache()# 2. 后續推理,將模型以元組形式返回的緩存重新封裝為Legacy Cache形式elif not isinstance(past_key_values, Cache):return_legacy_cache = Truepast_key_values = DynamicCache.from_legacy_cache(past_key_values)# 3. 詞嵌入 inputs_embeds = self.wte(input_ids)# 4. 位置編碼計算if cache_position is None:# 4.1 已經緩存的詞序列長度past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0# 4.2 只為當前傳入的詞生成位置序列cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)    if position_ids is None:position_ids = cache_position.unsqueeze(0) # 添加batch維度# 4.3 生成位置編碼position_embeds = self.wpe(position_ids)# 5. 詞嵌入 + 位置編碼hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)# 6. 進入堆疊GPT2Block模塊前向傳播流程for i, block in enumerate(self.h):hidden_states = block(hidden_states,past_key_values if not (self.gradient_checkpointing and self.training) else None, # 訓練時,不啟用KV Cachecache_position,causal_mask,use_cache=use_cache,)hidden_states = self.ln_f(hidden_states)hidden_states = hidden_states.view(output_shape)# 7. 將KV Cache用元組的形式進行返回 past_key_values = past_key_values if use_cache else Noneif return_legacy_cache:past_key_values = past_key_values.to_legacy_cache()return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states,past_key_values=past_key_values,hidden_states=all_hidden_states,attentions=all_self_attentions,cross_attentions=all_cross_attentions,)

下圖展示的是步驟7中以元組形式返回的KV Cache結構:

在這里插入圖片描述

下面將展示GPT2Block模塊的實現邏輯,由于不涉及KV Cache的實現細節,所以不過多展開:

class GPT2Block(GradientCheckpointingLayer):def forward(self,hidden_states: Optional[tuple[torch.FloatTensor]],past_key_value: Optional[Cache] = None,cache_position: Optional[torch.LongTensor] = None,attention_mask: Optional[torch.FloatTensor] = None,use_cache: Optional[bool] = False,) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:# 1. 歸一化residual = hidden_stateshidden_states = self.ln_1(hidden_states)# 2. 自注意力運算attn_output, self_attn_weights = self.attn(hidden_states,past_key_value=past_key_value,cache_position=cache_position,attention_mask=attention_mask,use_cache=use_cache,)# 3. residual connectionhidden_states = attn_output + residual# 4. 歸一化 + MLP +  residual connectionresidual = hidden_stateshidden_states = self.ln_2(hidden_states)feed_forward_hidden_states = self.mlp(hidden_states)hidden_states = residual + feed_forward_hidden_statesreturn hidden_states

推理時的常規流程(無 KV Cache), 每生成一個新 token,都要:

  • 重新輸入全部歷史 token

  • 對所有歷史 token 重新計算 key 和 value

  • 這意味著重復計算,效率低,計算開銷線性增長


有了 KV Cache 后的改進:

  1. 第一次輸入完整句子,計算并緩存其 key/value;

  2. 后續每次生成新 token 時:

    • 只計算新 token 的 query、key、value;

    • 把新 token 的 key/value 插入緩存中(代碼中用 past_key_value.update(...) 完成);

    • attention 直接使用「歷史緩存 key/value + 當前新 token 的 key/value」來完成;

  3. 整個注意力的 query 只有一個(當前 token),key/value 是歷史緩存 + 當前 token

class GPT2Attention(nn.Module):def __init__(self, config, is_cross_attention=False, layer_idx=None):self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) # 輸入維度: (batch,seq_len,embed_dim) , 變換后的輸出維度: (batch,seq_len,3*embed_dim)self.c_proj = Conv1D(self.embed_dim, self.embed_dim)def forward(self,hidden_states: Optional[tuple[torch.FloatTensor]],past_key_value: Optional[Cache] = None,cache_position: Optional[torch.LongTensor] = None,attention_mask: Optional[torch.FloatTensor] = None,) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:# 1. 一維卷積進行線性變換和升維,然后切分成query,key,valuequery_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)# 2. (batch,seq_len,-1,head_dim) , head_dim 是多頭自注意力中每個頭切分到的維度 shape_q = (*query_states.shape[:-1], -1, self.head_dim)shape_kv = (*key_states.shape[:-1], -1, self.head_dim)# 3. 維度統一: (batch,heads,seq_len,head_dim)query_states = query_states.view(shape_q).transpose(1, 2)key_states = key_states.view(shape_kv).transpose(1, 2)value_states = value_states.view(shape_kv).transpose(1, 2)# 4. KV Cache 不為空 if past_key_value is not None:# 4.1 cache_position 記錄當前詞對應輸入詞序列中的索引cache_kwargs = {"cache_position": cache_position}# 4.2 將當前詞的key和val進行緩存,根據所在GPTBlock層級(layer_idx說明),和位于詞序列的索引(cache_kwargs說明),插入對應層的list緩存中去,同時返回對應的key和val listkey_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs)# 5. 進行經典的多頭自注意力運算(不展開細聊) attn_output, attn_weights = attention_interface(self,query_states, # 當前輸入詞的querykey_states,   # cache key list + 輸入詞的keyvalue_states,  # cache val list + 輸入詞的valattention_mask, # padding maskdropout=self.attn_dropout.p if self.training else 0.0,)attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()attn_output = self.c_proj(attn_output)attn_output = self.resid_dropout(attn_output)return attn_output, attn_weights

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/90813.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/90813.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/90813.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

代碼隨想錄算法訓練營第五十三天|圖論part4

110.字符串接龍 題目鏈接&#xff1a;110. 字符串接龍文章講解&#xff1a;代碼隨想錄思路&#xff1a; 把每個字符串看成圖的一個節點。 轉換為求無權圖兩節點的的最短路徑。求最短路徑用bfs #include <string> #include <vector> #include <iostream> #i…

Java進階4:泛型、序列化和反序列化

Java泛型 Java泛型是JDK5引入的一個新的特性&#xff0c;泛型提供了編譯時的類型安全檢測機制&#xff0c;這個機制運行程序員在編譯的時候檢測到非法的類型。泛型的本質是參數化類型&#xff0c;也就是所操作的數據類型被指定為一個參數。 泛型方法 可以寫一個泛型方法&#x…

RAG實戰指南 Day 24:上下文構建與提示工程

【RAG實戰指南 Day 24】上下文構建與提示工程 文章內容 開篇 歡迎來到"RAG實戰指南"系列的第24天&#xff01;今天我們將深入探討RAG系統中至關重要的上下文構建與提示工程技術。在檢索增強生成系統中&#xff0c;如何有效地組織檢索到的文檔片段&#xff0c;并將…

AWD的攻擊和防御手段

一、AWD相關介紹 AWD&#xff08;Attack With Defence&#xff09;是 CTF 線下賽中最接近真實攻防場景、觀賞性和對抗性最強的賽制之一。 賽制本質 人人對抗&#xff1a;所有戰隊互為攻擊者與防守者。 零和記分&#xff1a;你拿到的每一分都是別人的失分&#xff0c;總積分恒…

泛微OA8前臺SQL注入

漏洞URL&#xff1a; http://106.15.190.147/js/hrm/getdata.jsp?cmdgetSelectAllId&sql***注入點 在getdata.jsp中&#xff0c;直接將request對象交給 weaver.hrm.common.AjaxManager.getData(HttpServletRequest, ServletContext) : 方法處理 在getData方法中&#xff0…

Android 藍牙學習

在Android中&#xff0c;進行藍牙設備掃描startDiscovery需要申請位置權限&#xff0c;但有的時候并不需要申請位置權限&#xff0c;就有了android:usesPermissionFlags"neverForLocation"&#xff0c;設置這個就不用申請位置權限來進行藍牙掃描。 android:usesPerm…

Earth靶機攻略

一.環境準備 1.1Earth靶機環境準備 首先將我們解壓好的的Earth.ova放入虛擬機里&#xff0c;并配置環境 將網絡連接換成NET連接 二.信息搜集 2.1ip搜集 2.1.1使用netdiscover命令掃描靶機的ip地址,發現地址為192.168.182.137 2.2端口掃描 2.2.1使用nmap工具對目標機進行端…

java8 List常用基本操作(去重,排序,轉換等)

參考簡述網址: java8 List 根據對象某個字段或多個字段去重、篩選、List轉Map、排序、分組、統計計數等等 list簡單方法示例如下&#xff1a; 一、先定義一個訂單對象&#xff08;Order&#xff09; public class Order {private Long id;private Long userId;private Strin…

【linux】HAProxy七層代理

HAPorxy簡介HAProxy&#xff08;High Availability Proxy&#xff09;是一款高性能、開源的負載均衡器與反向代理服務器&#xff0c;主要用于 TCP 和 HTTP 協議的流量分發&#xff0c;廣泛應用于高并發、高可用的網絡架構中核心特性超高性能與低資源消耗基于單進程、多線程模型…

Android 的16 KB內存頁設備需要硬件支持嗎,還是只需要手機升級到Android15系統就可以

Android 的16 KB內存頁設備需要硬件支持嗎&#xff0c;還是只需要手機升級到Android15系統就可以&#xff1f; deepseek回答&#xff1a; DeepSeek R1 思考 3 秒 思考過程 好的&#xff0c;用戶現在問的是Android的16 KB內存頁設備是否需要硬件支持&#xff0c;還是只需要升級到…

相機內外參矩陣:從3D世界坐標到2D像素坐標變換

相機內外參矩陣&#xff1a;從3D世界坐標到2D像素坐標變換介紹**1. 內參矩陣&#xff08;Intrinsic Matrix, K&#xff09;****2. 外參矩陣&#xff08;Extrinsic Matrix, [R|t]&#xff09;****3. 完整投影過程&#xff08;世界坐標 → 像素坐標&#xff09;****步驟1&#xf…

哈希指針與數據結構:構建可信數字世界的基石

一、哈希指針的核心原理哈希指針是一種創新型數據結構&#xff0c;融合了傳統指針的定位功能與密碼學哈希的驗證能力&#xff1a;雙重功能&#xff1a;既存儲數據地址&#xff0c;又包含該數據的哈希值&#xff0c;實現數據定位與完整性驗證的統一。抗篡改機制&#xff1a;數據…

java實現一個方法,isTure則程序繼續往下,為false則return的鏈式寫法

以下是實現鏈式條件檢查的Java方法&#xff0c;采用函數式風格設計。代碼包含一個Chainable類&#xff0c;支持連續的check方法和多個終止操作&#xff08;如then, orElse等&#xff09;&#xff0c;滿足在條件為false時中斷鏈式調用并返回默認值的需求&#xff1a;import java…

數據結構學習之堆

本篇我們將學習新的數據結構——二叉樹。 作者的個人gitee&#xff1a;樓田莉子 (riko-lou-tian) - Gitee.com 目錄 樹的概念 樹形結構 非樹形結構 樹的相關術語 樹的表示 樹在實際生活上的應用 二叉樹 慢二叉樹 完全二叉樹 二叉樹的儲存結構 二叉樹的存儲結構 順序結構…

【csdn問答社區分析】前端開發熱點問題全解析

前端時間我在csdn問答社區的前端部分"視察”了一圈發現了大家的問題主要集中在以下方面一、框架與組件庫使用問題 Vue相關問題 組件化開發&#xff1a;如avue-crud組件自定義樣式不生效、el-select大數據分頁懶加載、element-plus表格動態列校驗等。功能實現&#xff1a;包…

Pycharm2025 安裝教程 免費分享 沒任何套路

Pycharm 安裝也是很簡單的&#xff0c;簡單過一下流程&#xff0c;如果需要的可以轉存下載到自己電腦上。我用夸克網盤分享了「pycharm2025」&#xff0c;復制鏈接瀏覽器打開轉存后即可下載。鏈接&#xff1a;https://pan.quark.cn/s/4bb74a939332備注&#xff1a;附帶2023-202…

Javaweb————什么是超文本傳輸協議?

&#x1f3cd;?&#x1f3cd;?&#x1f3cd;?引言&#xff1a;什么是協議&#xff1f; 協議是一種約定&#xff0c;規定好一種信息的格式&#xff0c;如果發送方按照這種請求格式發送信息,那么接 收端就要按照這樣的格式解析數據,否則就會出錯&#xff0c;這就是協議 常用協…

UniappDay03

1.熱門推薦-準備工作// 用defineProps獲取頁面參數,query const query defineProps<{type: string }>() const currHot hotMap.find((v) > v.type query.type) // 動態設置標題 uni.setNavigationBarTitle({ title: currHot!.title }) </script>2.獲取熱門推…

基于動態增強的 LLM 置信度方法研究

基于動態增強的 LLM 置信度方法研究 一、引言(Introduction) 大型語言模型(LLM)的性能提升高度依賴于對模型內部表征的精準調控 —— 表征工程通過優化模型中間層隱藏狀態的傳遞規律,能夠在不改變模型參數的前提下顯著提升任務適應性(Wei et al., 2022)。當前主流方法中…

ComfyUI中運行Wan 2.1工作流,電影級視頻,兼容Mac Windows

魔當(LM Downloader)是一個大模型應用下載工具 &#xff0c;目前 魔當 已經支持ComfyUI下載Wan 2.1視頻模型。 魔當下載地址 https://seemts.com/ 先看生成效果 原始圖片&#xff0c;你可以保存到自己電腦上測試 生成視頻&#xff1a; 推薦提示詞&#xff1a; A futurist…