論文地址:https://arxiv.org/pdf/2402.17764.pdf
相關博客
【自然語言處理】BitNet b1.58:1bit LLM時代
【自然語言處理】【長文本處理】RMT:能處理長度超過一百萬token的Transformer
【自然語言處理】【大模型】MPT模型結構源碼解析(單機版)
【自然語言處理】【大模型】ChatGLM-6B模型結構代碼解析(單機版)
【自然語言處理】【大模型】BLOOM模型結構源碼解析(單機版)
一、BitNet
? BitNet采用了與Transformer基本一致的模型架構,僅將標準矩陣乘法層換成了BitLinear
,其他組件仍然是高精度的。BitLinear
主要是包含的操縱:權重量化、激活量化以及LayerNorm。
? 權重量化。通過減均值實現0中心化,然后用sign實現二值化。假設全精度權重為 W ∈ R n × m W\in\mathcal{R}^{n\times m} W∈Rn×m,則二值量化過程為
W ~ = Sign ( W ? α ) (1) \widetilde{W}=\text{Sign}(W-\alpha) \tag{1} \\ W =Sign(W?α)(1)
Sign ( W i j ) = { + 1 , if W i j > 0 ? 1 , if W i j ≤ 0 (2) \text{Sign}(W_{ij})=\begin{cases} +1,&&\text{if}\;W_{ij}>0 \\ -1,&&\text{if}\;W_{ij}\leq 0 \\ \end{cases} \tag{2} \\ Sign(Wij?)={+1,?1,??ifWij?>0ifWij?≤0?(2)
α = 1 n m ∑ i j W i j (3) \alpha=\frac{1}{nm}\sum_{ij}W_{ij} \tag{3} \\ α=nm1?ij∑?Wij?(3)
? 激活量化。使用absmax的方式將激活量化至b-bit。具體的實現方式是乘以 Q b Q_b Qb?再除以輸入矩陣的最大絕對值,從而將激活縮放至 [ ? Q b , Q b ] ( Q b = 2 b ? 1 ) [-Q_b,Q_b](Q_b=2^{b-1}) [?Qb?,Qb?](Qb?=2b?1),即
x ~ = Quant ( x ) = Clip ( x × Q b γ , ? Q b + ? , Q b ? ? ) (4) \tilde{x}=\text{Quant}(x)=\text{Clip}(x\times\frac{Q_b}{\gamma},-Q_b+\epsilon,Q_b-\epsilon) \tag{4}\\ x~=Quant(x)=Clip(x×γQb??,?Qb?+?,Qb???)(4)
Clip ( x , a , b ) = max ? ( a , min ? ( b , x ) ) , γ = ∥ x ∥ ∞ (5) \text{Clip}(x,a,b)=\max(a,\min(b,x)),\quad\gamma=\parallel x\parallel_\infty \tag{5} \\ Clip(x,a,b)=max(a,min(b,x)),γ=∥x∥∞?(5)
其中 ? \epsilon ?是防止裁剪時溢出的小浮點數。
? 對于非線性函數之前的激活值則采用不同的量化方式,通過減輕最小值的方式將其縮放至 [ 0 , Q b ] [0,Q_b] [0,Qb?],從而保證所有值均為非負:
x ~ = Quant ( x ) = Clip ( ( x ? η ) × Q b γ , ? , Q b ? ? ) , η = min ? i , j x i j (6) \tilde{x}=\text{Quant}(x)=\text{Clip}((x-\eta)\times\frac{Q_b}{\gamma},\epsilon,Q_b-\epsilon),\quad\eta=\min_{i,j}x_{ij}\tag{6} \\ x~=Quant(x)=Clip((x?η)×γQb??,?,Qb???),η=i,jmin?xij?(6)
? LayerNorm。在對激活值量化前,為了保證量化后的方差穩定,采用了SubLN
。
? BitLinear
的完成計算過程為
y = W ~ x ~ = W ~ Quant ( LN ( x ) ) × β γ Q b (7) y=\widetilde{W}\tilde{x}=\widetilde{W}\text{Quant}(\text{LN}(x))\times\frac{\beta\gamma}{Q_b}\tag{7} \\ y=W x~=W Quant(LN(x))×Qb?βγ?(7)
LN ( x ) = x ? E ( x ) Var ( x ) + ? , β = 1 n m ∥ W ∥ 1 (8) \text{LN}(x)=\frac{x-E(x)}{\sqrt{\text{Var}(x)+\epsilon}},\quad\beta=\frac{1}{nm}\parallel W\parallel_1 \tag{8} \\ LN(x)=Var(x)+??x?E(x)?,β=nm1?∥W∥1?(8)
二、BitNet b1.58
? BitNet b1.58在BitNet的基礎上做了一些修改。
? 權重量化。采用absmean的方式將權重約束在 { ? 1 , 0 , 1 } \{-1,0,1\} {?1,0,1}中,而BitNet則將權重約束為二值 { ? 1 , 1 } \{-1,1\} {?1,1}。具體來說,先使用平均絕對值來縮放權重,然后通過舍入的方式轉換為 { ? 1 , 0 , 1 } \{-1,0,1\} {?1,0,1}:
W ~ = RoundClip ( W γ + ? , ? 1 , 1 ) (9) \widetilde{W}=\text{RoundClip}(\frac{W}{\gamma+\epsilon},-1,1)\tag{9} \\ W =RoundClip(γ+?W?,?1,1)(9)
RoundClip ( x , a , b ) = max ? ( a , min ? ( b , round ( x ) ) ) (10) \text{RoundClip}(x,a,b)=\max(a,\min(b,\text{round}(x)))\tag{10} \\ RoundClip(x,a,b)=max(a,min(b,round(x)))(10)
γ = 1 n m ∑ i j ∣ W i j ∣ (11) \gamma=\frac{1}{nm}\sum_{ij}|W_{ij}|\tag{11} \\ γ=nm1?ij∑?∣Wij?∣(11)
? 激活量化。同BitNet一樣,但是對于非線性函數前的激活不再量化至 [ 0 , Q b ] [0,Q_b] [0,Qb?],而是都量化至 [ ? Q b , Q b ] [-Q_b,Q_b] [?Qb?,Qb?]。
? 此外,為了能夠方便于開源軟件兼容,整體結構采用類似LLaMA的結構。具體來說,使用RMSNorm、SwiGLU、RoPE并移除所有偏置。
三、實驗
1. 困惑度
? BitNet b1.58在3B大小時,困惑度與LLaMA相匹配,但是速度快2.71倍且顯存使用減少3.55倍。當BitNet b1.58大小為3.9B時,速度快2.4倍且顯存減少3.32倍,并且效果顯著優于LLaMA 3B。
2. 下游任務
? 隨著模型尺寸的增加,BitNet b1.58和LLaMA在下游任務上的差距逐步縮小。在尺寸達到3B時,BitNet b.158能夠與全精度相匹配。
3. 顯存和延時
? 隨著模型尺寸的增加,BitNet b1.58的速度優勢和顯存優勢會更加明顯。
4. 能耗
? 矩陣乘法是LLM中能耗最高的部分。BitNet b1.58主要是INT8的加法計算,而LLaMA則是由FP16加法和乘法組成。在7nm芯片上,BitNet b1.58能夠節約71.4倍的計算能耗。隨著模型尺寸的增加,BitNet b1.58在能耗方面會越來越高效。
5. 吞吐
? 相同機器下,BitNet b1.58的batch size是LLaMA LLM的11倍,吞吐則是8.9倍。