原理
KV Cache的本質就是避免重復計算,把需要重復計算的結果進行緩存,生成式模型的新的token的產生需要用到之前的所有token的 K , V K,V K,V,在計算注意力的時候是當前的 Q Q Q和所有的 K , V K,V K,V來進行計算,所以是緩存 K , V K,V K,V。
由于Causal Mask的存在,前面已經生成的token不需要與后面的token產生attention,也就是用不到前面token的 Q Q Q,用的上前面token的 K , V K,V K,V,具體的公式如下:
a t t 1 ( Q , K , V ) = s o f t m a x ( Q 1 K 1 T D ) V 1 att_1(Q,K,V)=softmax(\frac{Q_1K_1^T}{\sqrt{D}})V_1 att1?(Q,K,V)=softmax(D?Q1?K1T??)V1?
a t t 2 ( Q , K , V ) = s o f t m a x ( Q 2 K 1 T D ) V 1 + s o f t m a x ( Q 2 K 2 T D ) V 2 att_2(Q,K,V)=softmax(\frac{Q_2K_1^T}{\sqrt{D}})V_1+softmax(\frac{Q_2K_2^T}{\sqrt{D}})V_2 att2?(Q,K,V)=softmax(D?Q2?K1T??)V1?+softmax(D?Q2?K2T??)V2?
a t t 3 ( Q , K , V ) = s o f t m a x ( Q 3 K 1 T D ) V 1 + s o f t m a x ( Q 3 K 2 T D ) V 2 + s o f t m a x ( Q 3 K 3 T D ) V 3 att_3(Q,K,V)=softmax(\frac{Q_3K_1^T}{\sqrt{D}})V_1+softmax(\frac{Q_3K_2^T}{\sqrt{D}})V_2+softmax(\frac{Q_3K_3^T}{\sqrt{D}})V_3 att3?(Q,K,V)=softmax(D?Q3?K1T??)V1?+softmax(D?Q3?K2T??)V2?+softmax(D?Q3?K3T??)V3?
可以看出, K , V K,V K,V存在重復計算的情況,因此可以進行Cache。
KV Cache只適用于Decoder架構,因為有Causal Mask的存在,如果是Encoder,處理的是輸入序列,是一次性完成整個序列attention的計算,并不像Decoder一樣有自左向右的重復性的計算,Encoder由于其一次性和并行性,用不上KV-Cache,而解碼器由于其自回歸性,KV Cache是很有用的。