如何計算Transformer 相關模型的參數量呢?
先回憶一下Transformer模型論文《Attention is all your need》中的兩個圖。
設Transformer模型的層數為N,每個Transformer層主要由self-attention 和 Feed Forward組成。設self-attention模塊的head個數為 n h e a d n_{head} nhead?,每一個head對應的維度為 d h e a d d_{head} dhead?,self-attention輸出維度為 d m o d e l = n heads ? d head d_{model}= n_\text{heads}\cdot d_\text{head} dmodel?=nheads??dhead?。我們可以得到一個Transformer層的參數量為 12 d m o d e l 2 + 13 d m o d e l 12 d_{model}^2 + 13 d_{model} 12dmodel2?+13dmodel?,具體如下:
-
self-attention塊的模型參數有Q、K、V的權重矩陣 W Q 、 W K 、 W V W_Q、W_K 、W_V WQ?、WK?、WV?和偏置,輸出矩陣 W O W_O WO?及其偏置。這4個權重矩陣的大小為 [ d m o d e l , d m o d e l ] [d_{model}, d_{model}] [dmodel?,dmodel?],4個偏置的大小為 [ d m o d e l ] [d_{model}] [dmodel?],所以self-attention塊的參數量為 4 d m o d e l 2 + 4 d m o d e l 4 d_{model}^2 + 4 d_{model} 4dmodel2?+4dmodel?。
-
Feed Forward塊一般由2個線性層組成,第一個線性層將維度從 d m o d e l d_{model} dmodel? 映射成 4 d m o d e l 4d_{model} 4dmodel?, 其權重矩陣 W 1 W_1 W1?的大小為 [ d m o d e l , 4 d m o d e l ] [d_{model}, 4d_{model}] [dmodel?,4dmodel?] ,其偏置的大小為 [ 4 d m o d e l ] [4d_{model}] [4dmodel?]。 第二個線性層將維度從 4 d m o d e l 4d_{model} 4dmodel? 映射成 d m o d e l d_{model} dmodel?,其權重矩陣 W 2 W_2 W2?的大小為 [ 4 d m o d e l , d m o d e l ] [4d_{model}, d_{model}] [4dmodel?,dmodel?] ,其偏置的大小為 [ d m o d e l ] [d_{model}] [dmodel?]。所以Feed Forward的參數量為 8 d m o d e l 2 + 5 d m o d e l 8 d_{model}^2 + 5 d_{model} 8dmodel2?+5dmodel?。
-
self-attention 和 Feed Forward都跟隨著layer normalization,它有兩個可訓練模型參數,形狀都是 [ d m o d e l ] [d_{model}] [dmodel?]。所以2個layer normalization的參數量為 4 d m o d e l 4 d_{model} 4dmodel?。
除了Transformer層之外的參數有:
- 詞embedding矩陣的參數量,embedding的維度通常等于 d m o d e l d_{model} dmodel?,設詞表的大小為V,則詞embedding的參數量為 V d m o d e l Vd_{model} Vdmodel?。
- 位置向量相關,有些位置向量表示方式需要學習參數。
所以N層Transformer模型的可訓練模型參數量為 N ( 12 d m o d e l 2 + 13 d m o d e l ) + V d m o d e l N(12 d_{model}^2 + 13 d_{model}) + Vd_{model} N(12dmodel2?+13dmodel?)+Vdmodel?。當 d m o d e l d_{model} dmodel?較大時,可以忽略一次項,模型參數量近似為 12 N d m o d e l 2 12 N d_{model}^2 12Ndmodel2?。
最后試驗一下模型參數估計量與論文是否對的上,下表是GPT3和LLaMA的計算對比,可以發現數量級是可以對的上的,因為我們忽略了一次項,所以具體數據與論文不一致。
模型名 | 實際參數量 | n l a y e r n_{layer} nlayer? | d m o d e l d_{model} dmodel? | n h e a d n_{head} nhead? | d h e a d d_{head} dhead? | 估計參數量 |
---|---|---|---|---|---|---|
GPT-3 | 175B | 96 | 12288 | 96 | 128 | 173946175488 |
LLaMA 6.7B | 6.7B | 32 | 4096 | 32 | 128 | 6442450944 |
LLaMA 13.0B | 13.0B | 40 | 5120 | 40 | 128 | 12582912000 |
LLaMA 32.5B | 32.5B | 60 | 6656 | 52 | 128 | 31897681920 |
LLaMA 65.2B | 65.2B | 80 | 8192 | 64 | 128 | 64424509440 |
參考資料
-
Transformer 論文(模型圖來自論文)、GPT3的論文等
-
整理過程中參考的blog: 1. 知乎用戶回旋托馬斯x 的文章,除了計算量外,還算了計算量、中間激活等 , 2 transformer 參數量計算, 3 flops 計算, 4 transformers 參數量計算公式
-
transfomers 庫如何得到參數量