WHAT:量化感知訓練(Quantization-Aware Training, QAT) 是一種在模型訓練階段引入量化誤差的技術。
它的核心思想是:通過在前向傳播時插入“偽量化節點”引入量化誤差,將權重和激活模擬為低精度(如 int8)格式,同時仍然使用高精度(如 float32)進行反向傳播和參數更新,使得模型在訓練時適應量化誤差的存在,從而在實際部署時保證性能。偽量化節點一般通過連續的量化與反量化(如float32量化->int8反量化->float32)引入量化誤差。
WHY: 量化技術可以降低模型的大小和計算復雜度,提高模型在移動設備或嵌入式系統等資源受限環境中的運行效率。
以將權重從 float32
(32位)量化為 int8
(8位)為例
存儲空間:理想情況下將縮小至原來的 1/4(理論上限)
推理速度:在支持低精度加速的硬件(如 ARM CPU、DSP、TPU、NPU)上推理速度通常可提升 2 ~ 4 倍,但實際加速比依賴于具體平臺、模型結構和量化方式。
HOW:
1. 準備工作
在正式進行qat訓練之前需要兩個步驟。首先,使用標準的訓練方法預訓練一個模型,以獲得較好的權重和量化起點;其次,準備一個完全支持qat的模型結構,由于某些模塊(例如多頭自注意力機制)在qat框架并不原生支持(如 tensorflow.model_optimization
?并不支持 MultiHeadAttention
的自動量化),這些模塊在qat階段需要手動實現或替換為可量化版本,而不是直接調用tensorflow等寫好的包,以確保量化代碼能識別這些參數并正確插入偽量化節點并進行量化訓練。
2. 訓練過程(以tensorflow為例)
step1: 輸入激活(float32)
↓
step2: 偽量化權重(float32->量化->int8->反量化->float32)引入量化誤差
↓
step3: 前向計算
↓
step4: 偽量化輸出(float32->量化->int8->反量化->float32)引入激活誤差
↓
step5: 反向傳播,遇到偽量化節點使用STE(Straight Through Estimator)傳遞梯度
【待補充】
為了實現自定義的 QAT 訓練,最推薦也最快速的方法之一,就是通過為每一層顯式命名的方式進行標記。這也是 TensorFlow 官方推薦的做法。
在 QAT 訓練開始前,我們通常會逐層遍歷模型,使用 annotate_layer
對需要量化的層打上標記,并通過 clone_function
將模型復制一遍。
然后,使用 quantize_apply()
對復制后的模型進行包裝,此操作會根據指定的量化方案,在所有標記過的層中插入對應的偽量化節點(包括權重和激活)
接下來,只需像普通模型一樣調用 compile()
和 fit()
,即可進入標準的訓練流程啦!