一、我們先來回憶一下在transformer中KV在哪里出現過,都有什么作用?
α的計算過程:
這里引入三個向量:
圖中的q為Query,用來匹配key值
圖中的k為key,用來被Query匹配
圖中的Value,是用來被進行加權平均的
由
這一步我們知道α就是K與Q的匹配程度,匹配程度越高則權重越大。
Wq、Wk、Wv這三個參數矩陣都需要從訓練數據中學習
二、為什么要使用KV緩存
使用KV緩存是為減少生成token時候的矩陣運算。
? ? ? ? 因為在transformer中文本是逐個token生成的,每次新的預測會基于之前生成的所有token的上下文信息,這種對順序數據的依賴會減慢生成過程,因為每次預測下一個token都需要重新處理序列中所有之前的token。
? ? ? ? 比說我們要預測第100個token,那么模型必須使用前面99個token的信息,這就需要對這些token做矩陣運算,而這個矩陣運算是非常耗時的。所以KV緩存就是為了減少這種耗時的矩陣運算,在推理過程中會把鍵和值放在緩存中,這樣模型就可以在后續生成token的時候,直接訪問緩存,而不需要重新計算。
三、KV緩存具體是怎么實現的?
這兩張圖分別是有緩存和沒有緩存的情況
因為是第一個token,所以有沒有緩存計算過程沒有差別
? ? ? 接下來到第二個token時,可以看到紫色標出的就是緩存下來的key和value,在沒有緩存的情況下KV都要重新計算。如果做了緩存就只需要把歷史的KV拿出來,同時只計算最新的那個token的KV再拼接成一個大矩陣就行了。
對比一下,有緩存的計算量明顯減少了一半
那后面的token一樣,每次歷史計算過的鍵和值就不用重新計算了,這樣就極大減少了self attention 的計算量,從序列長度的二次方直接變成了線性