論文閱讀:speculative decoding

Fast Inference from Transformers via Speculative Decoding

論文地址:https://arxiv.org/pdf/2211.17192

speculative sampling

為了從分布 p ( x ) p(x) p(x) 中采樣,我們實際上是從分布 q ( x ) q(x) q(x) 中采樣 x x x,如果 q ( x ) ≤ p ( x ) q(x) \leq p(x) q(x)p(x),則保留該樣本;如果 q ( x ) > p ( x ) q(x) > p(x) q(x)>p(x),則以概率 1 ? p ( x ) q ( x ) 1 - \frac{p(x)}{q(x)} 1?q(x)p(x)? 拒絕該樣本,并重新從調整后的分布 p ′ ( x ) = norm ( max ? ( 0 , p ( x ) ? q ( x ) ) ) p'(x) = \text{norm}(\max(0, p(x)-q(x))) p(x)=norm(max(0,p(x)?q(x))) 中采樣。對于任何分布 p ( x ) p(x) p(x) q ( x ) q(x) q(x),以及以此方式采樣的 x x x,確實有 x ~ p ( x ) x \sim p(x) xp(x)

給定通過在條件前綴上運行 M q M_q Mq? 獲得的分布 q ( x ) q(x) q(x),我們可以采樣一個標記 x 1 ~ q ( x ) x_1 \sim q(x) x1?q(x)。然后,我們通過在前綴上運行 M p M_p Mp? 來計算分布 p ( x ) p(x) p(x),同時并行地推測性地計算下一個標記 x 2 x_2 x2? 的分布,即在前綴上追加 x 1 x_1 x1? 后運行 M p M_p Mp?。一旦兩項計算都完成,我們就按上述方式處理:如果 x 1 x_1 x1? 被拒絕,我們丟棄 x 2 x_2 x2? 的計算,并從調整后的分布中重新采樣 x 1 x_1 x1?;如果 x 1 x_1 x1? 被接受,我們就保留兩個標記。算法 1 將這一想法推廣為一次采樣 1 到 γ + 1 \gamma + 1 γ+1 個標記。
運行算法

分析

有幾個證明需要注意一下:

單次算法期望能生成的token
  1. 單次算法期望能生成的token數量服從幾何分布,但是求和項是有限制的,這里推導下?

  2. ??接受率β的定義??
    設目標模型分布為 p(x),草稿模型分布為 q(x)。草稿模型生成的單個token被目標模型接受的概率為:

β = ∑ x min ? ( q ( x ) , p ( x ) ) \beta = \sum_x \min\left(q(x), p(x)\right) β=x?min(q(x),p(x))

  1. ??拒絕率α的定義??

α = 1 ? β = 1 ? ∑ x min ? ( p ( x ) , q ( x ) ) x \alpha = 1 - \beta = 1 - \sum_x \min(p(x), q(x)) x α=1?β=1?x?min(p(x),q(x))x

  • 假設每個token的接受事件獨立且同分布(i.i.d.),草稿模型一次生成 K 個token:

  • ??首次拒絕發生在位置 r?? 的概率為:

    P ( r ) = ( 1 ? β ) β r ? 1 ( 1 ≤ r ≤ K ) P(r) = (1-\beta) \beta^{r-1} \quad (1 \leq r \leq K) P(r)=(1?β)βr?1(1rK)

    所有token均被接受?? 的概率為: β K \beta^K βK

  • 綜上期望能生成的token數量為:

    γ = ∑ r = 1 K r ? P ( r ) ? 拒絕前生成的token + K ? β K ? 全接受時生成K個token \gamma = \underbrace{\sum_{r=1}^K r \cdot P(r)}_{\text{拒絕前生成的token}} + \underbrace{K \cdot \beta^K}_{\text{全接受時生成K個token}} γ=拒絕前生成的token r=1K?r?P(r)??+全接受時生成Ktoken K?βK??

代入 P ( r ) P(r) P(r) 后展開:

γ = ∑ r = 1 K r ? ( 1 ? β ) β r ? 1 + K β K \gamma = \sum_{r=1}^K r \cdot (1-\beta) \beta^{r-1} + K \beta^K γ=r=1K?r?(1?β)βr?1+KβK

  1. 幾何級數求和?

幾何級數求和公式為:

∑ r = 1 K r β r ? 1 \sum_{r=1}^K r \beta^{r-1} r=1K?rβr?1 求和處理:

  • ?令 S = ∑ r = 1 K β r ? 1 S = \sum_{r=1}^K \beta^{r-1} S=r=1K?βr?1?:

S = 1 + β + β 2 + ? + β K ? 1 = 1 ? β K 1 ? β S = 1 + \beta + \beta^2 + \cdots + \beta^{K-1} = \frac{1-\beta^K}{1-\beta} S=1+β+β2+?+βK?1=1?β1?βK?

  • ??對 S S S 求導??:

∑ r = 1 K r β r ? 1 = d d β ( ∑ r = 0 K β r ) = d d β ( 1 ? β K + 1 1 ? β ) = 1 ? ( K + 1 ) β K + K β K + 1 ( 1 ? β ) 2 \sum_{r=1}^K r \beta^{r-1} = \frac{d}{d\beta} \left( \sum_{r=0}^K \beta^r \right) = \frac{d}{d\beta} \left( \frac{1-\beta^{K+1}}{1-\beta} \right) = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} r=1K?rβr?1=dβd?(r=0K?βr)=dβd?(1?β1?βK+1?)=(1?β)21?(K+1)βK+KβK+1?

  • ??代入γ表達式??:

γ = ( 1 ? β ) ? 1 ? ( K + 1 ) β K + K β K + 1 ( 1 ? β ) 2 + K β K = 1 ? ( K + 1 ) β K + K β K + 1 1 ? β + K β K \gamma = (1-\beta) \cdot \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} + K\beta^K = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{1-\beta} + K\beta^K γ=(1?β)?(1?β)21?(K+1)βK+KβK+1?+KβK=1?β1?(K+1)βK+KβK+1?+KβK

  • 化簡??:

γ = 1 ? β K 1 ? β \gamma = \frac{1 - \beta^K}{1-\beta} γ=1?β1?βK?

??物理意義??:

  • K → ∞ K \to \infty K時, γ → 1 1 ? β = 1 α \gamma \to \frac{1}{1-\beta} = \frac{1}{\alpha} γ1?β1?=α1?(理想無限長草稿)。
  • 例如 β \beta β = 0.8` 時, γ max = 5 \gamma_{\text{max}} = 5 γmax?=5,即平均每次生成5個token。

得證

Walltime的時間優化

??定理 3.8??:算法 1 在總運行時間上的預期改進因子為
‘ 1 ? α γ + 1 ( 1 ? α ) ( γ c + 1 ) ‘ `\frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)}` (1?α)(γc+1)1?αγ+1?

??證明??:
記運行目標模型 M p M_p Mp? ??單步??的成本為 T T T
算法 1 的??單次運行成本??為 T c γ + T Tc\gamma + T Tcγ+T(其中 c γ T c\gamma T cγT用于運行近似模型 M q M_q Mq? γ \gamma γ 次, T T T 用于運行 M p M_p Mp? 一次)。
根據單次算法期望能生成的token算法推導,單次運行??平均生成 token 數量??為 1 ? α γ + 1 1 ? α \dfrac{1 - \alpha^{\gamma + 1}}{1 - \alpha} 1?α1?αγ+1?
因此,使用算法 1 生成單個 token 的??總體預期成本??為:
( c γ + 1 ) ( 1 ? α ) 1 ? α γ + 1 T ‘ \frac{(c\gamma + 1)(1 - \alpha)}{1 - \alpha^{\gamma + 1}}T` 1?αγ+1(cγ+1)(1?α)?T
由于標準解碼算法生成單個 token 的成本為 T
比較可得上述改進因子。?
(注:符號 “?” 表示證明結束)


關鍵術語說明:

英文術語中文翻譯符號含義
walltime總運行時間-算法從啟動到結束的時鐘時間
expected improvement factor預期改進因子-優化后時間開銷的縮減比例
cost per step單步成本 T T T目標模型 M p M_p Mp? 推理一個 token 的時間
approximation model近似模型 M q M_q Mq?快速但低精度的草稿模型
tokens標記(Token)-模型生成的基本文本單位
rejection rate拒絕率 α \alpha α草稿模型 M q M_q Mq? 的 token 被目標模型 M p M_p Mp? 拒絕的概率
γ \gamma γ生成長度 γ \gamma γ草稿模型單次運行的 token 生成數
cost ratio成本比 c c c M q M_q Mq? M p M_p Mp? 的單步時間比值( 0 < c < 1 0 < c < 1 0<c<1

公式解析:

  1. ??改進因子??
    1 ? α γ + 1 ( 1 ? α ) ( γ c + 1 ) \frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)} (1?α)(γc+1)1?αγ+1?
  • ??分子?? 1 ? α γ + 1 1 - \alpha^{\gamma+1} 1?αγ+1:草稿模型連續生成 \gamma 個 token 均未被拒絕的概率補償
  • ??分母?? ( 1 ? α ) (1-\alpha) (1?α):單 token 接受率, γ c + 1 \gamma c + 1 γc+1:草稿+驗證的總時間成本

該值 ??>1?? 時表示加速,值越大加速效果越顯著

  1. ??單 token 成本公式??
    ( c γ + 1 ) ( 1 ? α ) 1 ? α γ + 1 T \frac{(c\gamma+1)(1-\alpha)}{1-\alpha^{\gamma+1}}T 1?αγ+1(cγ+1)(1?α)?T
  • ??分子?? ( c γ + 1 ) ( 1 ? α ) T (c\gamma+1)(1-\alpha)T (cγ+1)(1?α)T:草稿生成+驗證的實際計算量
  • ??分母?? 1 ? α γ + 1 1-\alpha^{\gamma+1} 1?αγ+1:有效 token 產出的概率加權
操作數計算

操作數的計算量也是類似的,直接貼結論了

( 1 ? α ) ( γ c ^ + γ + 1 ) 1 ? α γ + 1 \frac{(1-\alpha)(\gamma \hat{c}+\gamma+1)}{1-\alpha^{\gamma+1}} 1?αγ+1(1?α)(γc^+γ+1)?

采樣和原分布的等價性證明

參考https://arxiv.org/pdf/2302.01318
其中需要一步代換證明下面兩個公式等價:

原始公式

第一個公式:
= 1 ? ∑ x ′ min ? ( p ( x ′ ) , q ( x ′ ) ) =1-\sum_{x^{\prime}}\min\left(p\left(x^{\prime}\right),q\left(x^{\prime}\right)\right) =1?x?min(p(x),q(x))

第二個公式:
= ∑ x ′ max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) =\sum_{x^{\prime}}\max\left(0,q\left(x^{\prime}\right)-p\left(x^{\prime}\right)\right) =x?max(0,q(x)?p(x))

推導步驟

步驟 1: 應用 min 函數的恒等式

對于任何兩個實數 a a a b b b,都存在以下恒等關系:
min ? ( a , b ) = a ? max ? ( 0 , a ? b ) \min(a,b) = a - \max(0, a - b) min(a,b)=a?max(0,a?b)

b = p ( x ′ ) b = p(x') b=p(x) a = q ( x ′ ) a = q(x') a=q(x),得到:
min ? ( p ( x ′ ) , q ( x ′ ) ) = q ( x ′ ) ? max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) \min(p(x'),q(x')) = q(x') - \max(0, q(x') - p(x')) min(p(x),q(x))=q(x)?max(0,q(x)?p(x))

步驟 2: 代入第一個公式

將恒等式代入原始公式:
1 ? ∑ x ′ min ? ( p ( x ′ ) , q ( x ′ ) ) = 1 ? ∑ x ′ [ q ( x ′ ) ? max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) ] \begin{aligned} &1 - \sum_{x^{\prime}} \min(p(x'),q(x')) \\ &= 1 - \sum_{x^{\prime}} \left[ q(x') - \max(0, q(x') - p(x')) \right] \end{aligned} ?1?x?min(p(x),q(x))=1?x?[q(x)?max(0,q(x)?p(x))]?

步驟 3: 拆分求和運算

將求和符號分配到表達式內部:
= 1 ? [ ∑ x ′ p ( x ′ ) ? ∑ x ′ max ? ( 0 , p ( x ′ ) ? q ( x ′ ) ) ] = 1 - \left[ \sum_{x^{\prime}} p(x') - \sum_{x^{\prime}} \max(0, p(x') - q(x')) \right] =1?[x?p(x)?x?max(0,p(x)?q(x))]
= 1 ? ∑ x ′ q ( x ′ ) + ∑ x ′ max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) = 1 - \sum_{x^{\prime}} q(x') + \sum_{x^{\prime}} \max(0, q(x') - p(x')) =1?x?q(x)+x?max(0,q(x)?p(x))

步驟 4: 應用概率分布性質

因為 p p p q q q 都是概率分布函數,滿足:
∑ x ′ p ( x ′ ) = 1 和 ∑ x ′ q ( x ′ ) = 1 \sum_{x^{\prime}} p(x') = 1 \quad \text{和} \quad \sum_{x^{\prime}} q(x') = 1 x?p(x)=1x?q(x)=1

代入表達式:
= 1 ? 1 + ∑ x ′ max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) = 1 - 1 + \sum_{x^{\prime}} \max(0, q(x') - p(x')) =1?1+x?max(0,q(x)?p(x))
= ∑ x ′ max ? ( 0 , q ( x ′ ) ? p ( x ′ ) ) = \sum_{x^{\prime}} \max(0, q(x') - p(x')) =x?max(0,q(x)?p(x))

得證

Reference

https://arxiv.org/pdf/2211.17192

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/86819.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/86819.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/86819.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

java操作word里的表格

依賴&#xff1a; <dependency><groupId>com.techCoLtd</groupId><artifactId>aspose-words-16.4.0-jdk16</artifactId><classifier>jdk16</classifier> </dependency>/*** 刪除表格及表格的行* throws Exception*/ private s…

單鏈表經典算法題之分割鏈表

給定一個頭結點和一個值x&#xff0c;是鏈表中所有小于x的值都在x前面 typedef struct ListNode ListNode; struct ListNode* partition(struct ListNode* head, int x) { //思路一&#xff1a;在原鏈表上進行修改 //思路二&#xff1a;創建新鏈表&#xff0c;使用哨兵位&…

Modbus TCP轉DeviceNet網關連接ABB變頻器配置案例

某工廠需要將支持Modbus TCP協議的上位機控制系統&#xff08;如PLC或SCADA&#xff09;與支持DeviceNet協議的變頻器&#xff08;如ABB ACS880、施耐德ATV320等&#xff09;進行通信。為實現協議轉換&#xff0c;采用開疆智能Modbus TCP轉DeviceNet網關KJ-DVCZ-MTCPS作為中間設…

【力扣 簡單 C++】206. 反轉鏈表

目錄 題目 解法一&#xff1a;迭代 解法二&#xff1a;遞歸 題目 待添加 解法一&#xff1a;迭代 class Solution { private:ListNode* reverse(ListNode* head){ListNode* newHead {};while (head){ListNode* nextNode {head->next};head->next newHead;newHead …

計算機視覺之三維重建(深入淺出SfM與SLAM核心算法)—— 1. 攝像機幾何

文章目錄 1. 針孔相機1.1. 針孔成像1.2. 光圈對成像的影響 2. 透視投影相機2.1. 透鏡成像2.2. 失焦2.3. 徑向畸變2.4. 透視投影的性質 3. 世界坐標系到像素坐標系的變換4. 其它相機模型4.1. 弱透視投影攝像機4.2. 正交投影攝像機4.3. 各種攝像機模型的應用場合 課程視頻鏈接&am…

第十三節:第七部分:Stream流的中間方法、Stream流的終結方法

Stream流常見的中間方法 Stream流常見的終結方法 代碼 學生類&#xff08;代碼一與代碼二共涉及到的類&#xff09; package com.itheima.day28_Stream;import java.util.Objects;public class Student implements Comparable<Student> {private String name;private i…

深入理解 Go 中的字節序(Endianness)檢測代碼

深入理解 Go 中的字節序&#xff08;大小端&#xff09;檢測代碼 在計算機系統中&#xff0c;字節序&#xff08;Endianness&#xff09; 是指多字節數據類型&#xff08;如 int16、int32 等&#xff09;在內存中的存儲順序。Go 語言標準庫提供了對大端&#xff08;Big-endian&…

JAVA:RabbitMQ 消息持久化機制的技術指南

?? 1、簡述 在使用 RabbitMQ 構建可靠消息系統時,消息丟失是必須避免的問題。為此,RabbitMQ 提供了消息持久化機制(Message Durability),可以保障在 Broker 異常宕機后數據不會丟失。 本篇博客將從原理出發,結合 Spring Boot 實戰講解如何正確實現 RabbitMQ 消息持久…

tabs頁簽嵌套表格,切換表格保存數據不變并回勾

需求&#xff1a;點擊左邊的tab頁簽&#xff0c;請求右側表格數據&#xff1b;如果返回的接口數據存在taskuser字段并不為null&#xff0c;那么按照這個字段去回勾數據。如果存在數據&#xff0c;但與后面所勾選的數據項不同&#xff0c;按照后面勾選的為主。 <el-tabs tab-…

Java Kafka消費者

基礎 Java Kafka消費者主要通過以下核心類實現&#xff1a; KafkaConsumer&#xff1a;消費者的核心類&#xff0c;用于創建消費者對象進行數據消費1ConsumerConfig&#xff1a;獲取各種配置參數&#xff0c;如果不配置就使用默認值1ConsumerRecord&#xff1a;每條數據都要封…

Git操作問題及解決方案-記錄5

Git操作問題及解決方案 問題一&#xff1a;本地更改與遠程更新沖突 問題描述 當本地文件有未提交的更改&#xff0c;同時遠程倉庫也有更新時&#xff0c;執行git pull會導致沖突。 $ git pull origin main error: Your local changes to the following files would be overw…

一[3]、ubuntu18.04環境 利用 yolov8 訓練開源列車數據集,并實現列車軌道檢測

一、開源車載數據集地址 (7 封私信) 軌道交通數據集-OSDaR23: Open Sensor Data for Rail 2023 - 知乎 二、參考資料 https://zhuanlan.zhihu.com/p/692608487 YOLOv8訓練自己的數據集-CSDN博客 https://download.csdn.net/blog/column/12710137/140991739

C語言數據結構筆記5:Keil 編譯器優化行為_malloc指針內存分配問題

記錄倆個keil5 STM32 的c語言編程中 &#xff0c;編譯器優化行為 和 指針內存分配問題。 目錄 關閉Keil 編譯器優化行為&#xff1a; malloc指針內存分配問題 多層嵌套的結構體&#xff1a; 用指針取值&#xff1a; 發現問題&#xff1a; 解決問題&#xff1a; 示例代碼 關閉Ke…

每日八股文6.12

每日八股-6.12 計算機網絡1.當我們在瀏覽器中輸入一個 URL 并按下回車后&#xff0c;到頁面最終顯示出來&#xff0c;這中間都發生了哪些關鍵步驟&#xff1f;2.請簡述一下 JWT&#xff08;JSON Web Tokens&#xff09;的原理和校驗機制3.DNS 是如何進行域名解析的&#xff1f;…

什么是云計算的邊緣原生應用?

關于作者&#xff1a;John Bradshaw阿卡邁公司歐洲、中東和非洲地區云計算技術與戰略總監 當談及云計算時&#xff0c;人們往往會聯想到那些坐落于國際大都會核心地帶的大型數據中心集群&#xff0c;這些設施作為數字時代的重要樞紐&#xff0c;承載著海量數據處理任務。盡管這…

Linux常用命令速查與面試高頻命令總結

&#x1f427; Linux常用命令速查與面試高頻命令總結 本文旨在幫助初學者快速掌握 Linux 的常用命令&#xff0c;同時為即將參加技術面試的朋友們提供一份高頻命令清單和實用技巧。 &#x1f530; 一、基礎命令&#xff1a;熟練使用命令行從這里開始 這些是你在 Linux 中最常用…

基礎測試工具使用經驗

背景 vtune&#xff0c;perf, nsight system等基礎測試工具&#xff0c;都是用過的&#xff0c;但是沒有記錄&#xff0c;都逐漸忘了。所以寫這篇博客總結記錄一下&#xff0c;只要以后發現新的用法&#xff0c;就記得來編輯補充一下 perf 比較基礎的用法&#xff1a; 先改這…

淺談DaemonSet

1. DaemonSet 概述 ?定義?&#xff1a;DaemonSet 確保 Kubernetes 集群的每個節點上運行一個 Pod 實例。?特性?&#xff1a; 每個節點上只有一個 Pod 實例。新節點加入集群時&#xff0c;會自動在新節點上創建 Pod。舊節點被刪除時&#xff0c;其上的 Pod 會被回收。 2.…

計算機系統(6)

◆指令尋址方式&#xff1a; 順序尋址方式&#xff1a;執行一段程序時&#xff0c;是一條指令接著一條指令的順序執行。 跳躍尋址方式:下一條指令的地址碼不是由程序計數器給出&#xff0c;而是由本條指令直接給出。程序跳躍后&#xff0c;按新的指令地址開始順序執行。因此&…

基于服務器使用 apt 安裝、配置 Nginx

&#x1f9fe; 一、查看可安裝的 Nginx 版本 首先&#xff0c;你可以運行以下命令查看可用版本&#xff1a; apt-cache madison nginx-core輸出示例&#xff1a; nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…