模型參數量衡量單位
M:百萬(Million)
B:十億(Billion)
1 B = 1000 M 1B = 1000M 1B=1000M
參數存儲精度
模型參數是固定的,但是一個參數所表示多少字節不一定,需要看這個參數以什么樣的精度去存儲。
- 單精度浮點數(FP32):每個參數占用4字節(32位),提供較高的數值精度。
- 半精度浮點數(FP16):每個參數占用2字節(16位),可以節省存儲空間和計算資源,但精度有所降低。
- 8位整數(INT8):每個參數占用1字節(8位),主要用于量化模型,進一步減少存儲和計算開銷,但精度顯著降低。
- 雙精度浮點數(FP64):每個參數占用8字節(64位),提供最高精度,但存儲和計算成本也最高。
參數所占顯存
參數顯存 = 參數數量 × 每個參數的字節數(B)
這里的 B 指的是字節
總顯存 = 參數顯存 + 激活值顯存 + 梯度顯存 + 優化器狀態顯存
在使用 checkpoint 進行推理的時候,主要計算參數顯存。
舉例:
一個 7b 參數的模型,參數存儲精度為 float16,那么:
- 總參數個數: 7 ? 10 9 7 * 10^9 7?109
- 一個參數所占字節數: 16 / 8 = 2 ( B ) 16 / 8 = 2(B) 16/8=2(B)
- 參數所占總字節數,即參數顯存: 7 ? 10 9 ? 2 = 14 ? 10 9 ( B ) = 14 ? 10 9 / 1024 / 1024 / 1024 ≈ 14 ( G ) 7 * 10^9 * 2 = 14*10^9(B)= 14*10^9 / 1024 / 1024 / 1024 ≈ 14(G) 7?109?2=14?109(B)=14?109/1024/1024/1024≈14(G)
簡單來看,如果是float16,參數顯存就是 參數量*2;如果是 float32,參數顯存就是 參數量*4;如果是int8,參數顯存就是 參數量*1。