Fast Inference from Transformers via Speculative Decoding
論文地址:https://arxiv.org/pdf/2211.17192
speculative sampling
為了從分布 p ( x ) p(x) p(x) 中采樣,我們實際上是從分布 q ( x ) q(x) q(x) 中采樣 x x x,如果 q ( x ) ≤ p ( x ) q(x) \leq p(x) q(x)≤p(x),則保留該樣本;如果 q ( x ) > p ( x ) q(x) > p(x) q(x)>p(x),則以概率 1 ? p ( x ) q ( x ) 1 - \frac{p(x)}{q(x)} 1?q(x)p(x)? 拒絕該樣本,并重新從調整后的分布 p ′ ( x ) = norm ( max ? ( 0 , p ( x ) ? q ( x ) ) ) p'(x) = \text{norm}(\max(0, p(x)-q(x))) p′(x)=norm(max(0,p(x)?q(x))) 中采樣。對于任何分布 p ( x ) p(x) p(x) 和 q ( x ) q(x) q(x),以及以此方式采樣的 x x x,確實有 x ~ p ( x ) x \sim p(x) x~p(x)。
給定通過在條件前綴上運行 M q M_q Mq? 獲得的分布 q ( x ) q(x) q(x),我們可以采樣一個標記 x 1 ~ q ( x ) x_1 \sim q(x) x1?~q(x)。然后,我們通過在前綴上運行 M p M_p Mp? 來計算分布 p ( x ) p(x) p(x),同時并行地推測性地計算下一個標記 x 2 x_2 x2? 的分布,即在前綴上追加 x 1 x_1 x1? 后運行 M p M_p Mp?。一旦兩項計算都完成,我們就按上述方式處理:如果 x 1 x_1 x1? 被拒絕,我們丟棄 x 2 x_2 x2? 的計算,并從調整后的分布中重新采樣 x 1 x_1 x1?;如果 x 1 x_1 x1? 被接受,我們就保留兩個標記。算法 1 將這一想法推廣為一次采樣 1 到 γ + 1 \gamma + 1 γ+1 個標記。
分析
有幾個證明需要注意一下:
單次算法期望能生成的token
-
單次算法期望能生成的token數量服從幾何分布,但是求和項是有限制的,這里推導下?
-
??接受率β的定義??
設目標模型分布為p(x)
,草稿模型分布為q(x)
。草稿模型生成的單個token被目標模型接受的概率為:
β = ∑ x min ? ( q ( x ) , p ( x ) ) \beta = \sum_x \min\left(q(x), p(x)\right) β=x∑?min(q(x),p(x))
- ??拒絕率α的定義??
α = 1 ? β = 1 ? ∑ x min ? ( p ( x ) , q ( x ) ) x \alpha = 1 - \beta = 1 - \sum_x \min(p(x), q(x)) x α=1?β=1?x∑?min(p(x),q(x))x
-
假設每個token的接受事件獨立且同分布(i.i.d.),草稿模型一次生成
K
個token: -
??首次拒絕發生在位置
r
?? 的概率為:P ( r ) = ( 1 ? β ) β r ? 1 ( 1 ≤ r ≤ K ) P(r) = (1-\beta) \beta^{r-1} \quad (1 \leq r \leq K) P(r)=(1?β)βr?1(1≤r≤K)
所有token均被接受?? 的概率為: β K \beta^K βK
-
綜上期望能生成的token數量為:
γ = ∑ r = 1 K r ? P ( r ) ? 拒絕前生成的token + K ? β K ? 全接受時生成K個token \gamma = \underbrace{\sum_{r=1}^K r \cdot P(r)}_{\text{拒絕前生成的token}} + \underbrace{K \cdot \beta^K}_{\text{全接受時生成K個token}} γ=拒絕前生成的token r=1∑K?r?P(r)??+全接受時生成K個token K?βK??
代入 P ( r ) P(r) P(r) 后展開:
γ = ∑ r = 1 K r ? ( 1 ? β ) β r ? 1 + K β K \gamma = \sum_{r=1}^K r \cdot (1-\beta) \beta^{r-1} + K \beta^K γ=r=1∑K?r?(1?β)βr?1+KβK
- 幾何級數求和?
幾何級數求和公式為:
對 ∑ r = 1 K r β r ? 1 \sum_{r=1}^K r \beta^{r-1} ∑r=1K?rβr?1 求和處理:
- ?令 S = ∑ r = 1 K β r ? 1 S = \sum_{r=1}^K \beta^{r-1} S=∑r=1K?βr?1?:
S = 1 + β + β 2 + ? + β K ? 1 = 1 ? β K 1 ? β S = 1 + \beta + \beta^2 + \cdots + \beta^{K-1} = \frac{1-\beta^K}{1-\beta} S=1+β+β2+?+βK?1=1?β1?βK?
- ??對 S S S 求導??:
∑ r = 1 K r β r ? 1 = d d β ( ∑ r = 0 K β r ) = d d β ( 1 ? β K + 1 1 ? β ) = 1 ? ( K + 1 ) β K + K β K + 1 ( 1 ? β ) 2 \sum_{r=1}^K r \beta^{r-1} = \frac{d}{d\beta} \left( \sum_{r=0}^K \beta^r \right) = \frac{d}{d\beta} \left( \frac{1-\beta^{K+1}}{1-\beta} \right) = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} ∑r=1K?rβr?1=dβd?(∑r=0K?βr)=dβd?(1?β1?βK+1?)=(1?β)21?(K+1)βK+KβK+1?
- ??代入γ表達式??:
γ = ( 1 ? β ) ? 1 ? ( K + 1 ) β K + K β K + 1 ( 1 ? β ) 2 + K β K = 1 ? ( K + 1 ) β K + K β K + 1 1 ? β + K β K \gamma = (1-\beta) \cdot \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} + K\beta^K = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{1-\beta} + K\beta^K γ=(1?β)?(1?β)21?(K+1)βK+KβK+1?+KβK=1?β1?(K+1)βK+KβK+1?+KβK
- 化簡??:
γ = 1 ? β K 1 ? β \gamma = \frac{1 - \beta^K}{1-\beta} γ=1?β1?βK?
??物理意義??:
- 當 K → ∞ K \to \infty K→∞時, γ → 1 1 ? β = 1 α \gamma \to \frac{1}{1-\beta} = \frac{1}{\alpha} γ→1?β1?=α1?(理想無限長草稿)。
- 例如 β \beta β = 0.8` 時, γ max = 5 \gamma_{\text{max}} = 5 γmax?=5,即平均每次生成5個token。
得證
Walltime的時間優化
??定理 3.8??:算法 1 在總運行時間上的預期改進因子為
‘ 1 ? α γ + 1 ( 1 ? α ) ( γ c + 1 ) ‘ `\frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)}` ‘(1?α)(γc+1)1?αγ+1?‘
??證明??:
記運行目標模型 M p M_p Mp? ??單步??的成本為 T T T。
算法 1 的??單次運行成本??為 T c γ + T Tc\gamma + T Tcγ+T(其中 c γ T c\gamma T cγT用于運行近似模型 M q M_q Mq? γ \gamma γ 次, T T T 用于運行 M p M_p Mp? 一次)。
根據單次算法期望能生成的token算法推導,單次運行??平均生成 token 數量??為 1 ? α γ + 1 1 ? α \dfrac{1 - \alpha^{\gamma + 1}}{1 - \alpha} 1?α1?αγ+1?。
因此,使用算法 1 生成單個 token 的??總體預期成本??為:
( c γ + 1 ) ( 1 ? α ) 1 ? α γ + 1 T ‘ \frac{(c\gamma + 1)(1 - \alpha)}{1 - \alpha^{\gamma + 1}}T` 1?αγ+1(cγ+1)(1?α)?T‘
由于標準解碼算法生成單個 token 的成本為 T
,
比較可得上述改進因子。?
(注:符號 “?” 表示證明結束)
關鍵術語說明:
英文術語 | 中文翻譯 | 符號 | 含義 |
---|---|---|---|
walltime | 總運行時間 | - | 算法從啟動到結束的時鐘時間 |
expected improvement factor | 預期改進因子 | - | 優化后時間開銷的縮減比例 |
cost per step | 單步成本 | T T T | 目標模型 M p M_p Mp? 推理一個 token 的時間 |
approximation model | 近似模型 | M q M_q Mq? | 快速但低精度的草稿模型 |
tokens | 標記(Token) | - | 模型生成的基本文本單位 |
rejection rate | 拒絕率 | α \alpha α | 草稿模型 M q M_q Mq? 的 token 被目標模型 M p M_p Mp? 拒絕的概率 |
γ \gamma γ | 生成長度 | γ \gamma γ | 草稿模型單次運行的 token 生成數 |
cost ratio | 成本比 | c c c | M q M_q Mq? 與 M p M_p Mp? 的單步時間比值( 0 < c < 1 0 < c < 1 0<c<1) |
公式解析:
- ??改進因子??
1 ? α γ + 1 ( 1 ? α ) ( γ c + 1 ) \frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)} (1?α)(γc+1)1?αγ+1?
- ??分子?? 1 ? α γ + 1 1 - \alpha^{\gamma+1} 1?αγ+1:草稿模型連續生成
\gamma
個 token 均未被拒絕的概率補償 - ??分母?? ( 1 ? α ) (1-\alpha) (1?α):單 token 接受率, γ c + 1 \gamma c + 1 γc+1:草稿+驗證的總時間成本
該值 ??>1?? 時表示加速,值越大加速效果越顯著
- ??單 token 成本公式??
( c γ + 1 ) ( 1 ? α ) 1 ? α γ + 1 T \frac{(c\gamma+1)(1-\alpha)}{1-\alpha^{\gamma+1}}T 1?αγ+1(cγ+1)(1?α)?T
- ??分子?? ( c γ + 1 ) ( 1 ? α ) T (c\gamma+1)(1-\alpha)T (cγ+1)(1?α)T:草稿生成+驗證的實際計算量
- ??分母?? 1 ? α γ + 1 1-\alpha^{\gamma+1} 1?αγ+1:有效 token 產出的概率加權
操作數計算
操作數的計算量也是類似的,直接貼結論了
( 1 ? α ) ( γ c ^ + γ + 1 ) 1 ? α γ + 1 \frac{(1-\alpha)(\gamma \hat{c}+\gamma+1)}{1-\alpha^{\gamma+1}} 1?αγ+1(1?α)(γc^+γ+1)?
采樣和原分布的等價性證明
參考https://arxiv.org/pdf/2302.01318
其中需要一步代換證明下面兩個公式等價:
原始公式
第一個公式:
= 1 ? ∑ x ′ min ? ( p ( x ′ ) , q ( x ′ ) ) =1-\sum_{x^{\prime}}\min\left(p\left(x^{\prime}\right),q\left(x^{\prime}\right)\right) =1?x′∑?min(p(x′),q(x′))
第二個公式:
= ∑ x ′ max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) =\sum_{x^{\prime}}\max\left(0,q\left(x^{\prime}\right)-p\left(x^{\prime}\right)\right) =x′∑?max(0,q(x′)?p(x′))
推導步驟
步驟 1: 應用 min 函數的恒等式
對于任何兩個實數 a a a 和 b b b,都存在以下恒等關系:
min ? ( a , b ) = a ? max ? ( 0 , a ? b ) \min(a,b) = a - \max(0, a - b) min(a,b)=a?max(0,a?b)
令 b = p ( x ′ ) b = p(x') b=p(x′), a = q ( x ′ ) a = q(x') a=q(x′),得到:
min ? ( p ( x ′ ) , q ( x ′ ) ) = q ( x ′ ) ? max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) \min(p(x'),q(x')) = q(x') - \max(0, q(x') - p(x')) min(p(x′),q(x′))=q(x′)?max(0,q(x′)?p(x′))
步驟 2: 代入第一個公式
將恒等式代入原始公式:
1 ? ∑ x ′ min ? ( p ( x ′ ) , q ( x ′ ) ) = 1 ? ∑ x ′ [ q ( x ′ ) ? max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) ] \begin{aligned} &1 - \sum_{x^{\prime}} \min(p(x'),q(x')) \\ &= 1 - \sum_{x^{\prime}} \left[ q(x') - \max(0, q(x') - p(x')) \right] \end{aligned} ?1?x′∑?min(p(x′),q(x′))=1?x′∑?[q(x′)?max(0,q(x′)?p(x′))]?
步驟 3: 拆分求和運算
將求和符號分配到表達式內部:
= 1 ? [ ∑ x ′ p ( x ′ ) ? ∑ x ′ max ? ( 0 , p ( x ′ ) ? q ( x ′ ) ) ] = 1 - \left[ \sum_{x^{\prime}} p(x') - \sum_{x^{\prime}} \max(0, p(x') - q(x')) \right] =1?[x′∑?p(x′)?x′∑?max(0,p(x′)?q(x′))]
= 1 ? ∑ x ′ q ( x ′ ) + ∑ x ′ max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) = 1 - \sum_{x^{\prime}} q(x') + \sum_{x^{\prime}} \max(0, q(x') - p(x')) =1?x′∑?q(x′)+x′∑?max(0,q(x′)?p(x′))
步驟 4: 應用概率分布性質
因為 p p p 和 q q q 都是概率分布函數,滿足:
∑ x ′ p ( x ′ ) = 1 和 ∑ x ′ q ( x ′ ) = 1 \sum_{x^{\prime}} p(x') = 1 \quad \text{和} \quad \sum_{x^{\prime}} q(x') = 1 x′∑?p(x′)=1和x′∑?q(x′)=1
代入表達式:
= 1 ? 1 + ∑ x ′ max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) = 1 - 1 + \sum_{x^{\prime}} \max(0, q(x') - p(x')) =1?1+x′∑?max(0,q(x′)?p(x′))
= ∑ x ′ max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) = \sum_{x^{\prime}} \max(0, q(x') - p(x')) =x′∑?max(0,q(x′)?p(x′))
得證
Reference
https://arxiv.org/pdf/2211.17192