1字節=8bit
16float=2字節
模型后面的xxb的單位是字節。
1b 字節≈ 0.93G,這個是以8bit運行,4bit減半,16bit(float)加倍,32bit(double)炒雞加倍。
剩下的是小頭,需要參數計算:
- s:最大序列長度(輸入中的令牌數量)
- b:批大小
- h:模型的隱藏維度
- a:注意頭的數量
對于整個層
總內存需求總計為11sbh + 5as2b(來自注意力塊)+ 19sbh(來自MLP塊)+ 4sbh(來自LN)
。
每層激活內存消耗= 34 sbh + 5as2b
小頭一般遠小于10G。
所以比如llama7b,只需要7*0.93≈9G,再加10,內存19G就可以(實際會更少,因為小頭遠低于10G),注意這個是以8bit運行,4bit減半,16bit(float)加倍,32bit(double)炒雞加倍。
感謝博客:https://developer.aliyun.com/article/1496103
感謝github: