【ONNX量化實戰】使用ONNX Runtime進行靜態量化

目錄

    • 什么是量化
    • 量化實現的原理
    • 實戰
      • 準備數據
      • 執行量化
    • 驗證量化
    • 結語

什么是量化

量化是一種常見的深度學習技術,其目的在于將原始的深度神經網絡權重從高位原始位數被動態縮放至低位目標尾數。例如從FP32(32位浮點)量化值INT8(8位整數)權重。
FP32到INT8的整數縮放示意圖,來源NVIDIA
這么做的目的,是為了在不影響神經網絡的精度為前提下,減少模型運行時的內存消耗,提升推理系統整體的吞吐量。

量化實現的原理

量化的實現本質上以一種基于動態縮放的數值運算,因此在量化中,有幾個重要的參數:

  • 縮放系數 ( s c a l e ) (scale) scale:用于表述從高位縮放至低位的縮放系數,如果沒有它,量化就不存在了
  • x f x_f xf?:代表輸入的浮點高位值,一般是FP32或者FP64的輸入值

那么,如何計算縮放系數 ( s c a l e ) (scale) scale呢?
首先,我們要找出輸入值的最大值,原因是我們要找出整個輸入的的量化范圍,即:從哪里結束量化?因此,你可以使用 a m a x ( x f ) amax(x_f) amax(xf?)來算出 x f x_f xf?的極大值:

a m a x ( x f ) = m a x ( a b x ( x f ) ) amax(x_f) = max(abx(x_f)) amax(xf?)=max(abx(xf?))

找到最大值后,現在你要思考你需要多少位的量化,但通常在此之前,你需要算出你的輸入數據最大可以容納多少位數據:
n b i t = 2 ? a m a x ( x f ) n_{bit} = 2 *amax(x_f) nbit?=2?amax(xf?)
在確定了這個值之后,除以你預期你的位數所能承載的最大數據量,就得到了縮放系數:
s c a l e = n b i t / p o w ( 2 , t b i t ) scale = n_{bit} / pow(2,t_{bit}) scale=nbit?/pow(2,tbit?)

好了,現在你有了兩個量化過程中最重要的參數了,接下來就可以開始正式計算量化的結果了:
x q = C l i p ( R o u n d ( x f / s c a l e ) ) x_q = Clip(Round(x_f / scale)) xq?=Clip(Round(xf?/scale))

首先,我們需要將現有位數的輸入除以我們得到的縮放系數,即得到了目標位數的浮點數據,但別忘了:我們在量化時通常是為了將浮點值操作量化為整數值操作,因此需要將其取整為整數。

那么? C l i p Clip Clip在做什么?因為我們不希望我們量化后結果的范圍超出了目標位數的極大和極小值,因此使用 C l i p Clip Clip來裁切目標值為指定位數的極大和極小值。以INT8為例,則應該是:
x q = C l i p ( R o u n d ( x f / s c a l e ) , m i n = ? 128 , m a x = 127 ) x_q = Clip(Round(x_f/scale),min=-128,max=127) xq?=Clip(Round(xf?/scale),min=?128,max=127)
量化裁切示意圖-來源NVIDIA

實戰

說完了原理,我們該如何在ONNX中使用靜態量化呢?
在這里,我們需要使用onnxruntime庫來完成這個量化操作:

pip install onnxruntime-[target-ep]

其中,target-ep代表你期望模型在哪個類型的計算設備運行,如:

  • CUDA-GPU:則是pip install onnxruntime-gpu
  • DirectML:pip install onnxruntime-directml

準備數據

在準備數據時,我們不能像以前那樣直接使用Dict[str,ndarray]的方式來調用靜態量化,而是需要使用校準數據讀取方式來讀取:

from onnxruntime.quantization import (quantize_static, CalibrationDataReader,QuantType, QuantFormat, CalibrationMethod
)
from typing import *
# 創建一個DummpyDataReader類來繼承CalibrationDataReader類
class DummyDataReader(CalibrationDataReader):def __init__(self, calibration_dataset:List[Dict[str,np.ndarray]]):self.dataset:List[Dict[str,np.ndarray]] = calibration_datasetself.enum_data:Any = None# 重載get_next迭代函數def get_next(self):if self.enum_data is None:self.enum_data:Iterator = iter(self.dataset)return next(self.enum_data, None)

接下來我們就可以準備輸入數據了:

# 這里以Hubert Wav2Vec模型進行數據讀取(1,audio_length)
# 采樣率為16000Hz
import numpy as np
audio = np.load("./input.npy")
inputs = [{"feats": audio.astype(np.float32)},
]

執行量化

接下來我們就可以調用onnxruntime為我們提供的quantize_static函數了,在我們的實例中,會使用到如下的參數:

  • model_input [str]:輸入的模型位置,通常為FP32的ONNX模型權重
  • model_output [str]:量化權重保存的位置
  • calibration_data_reader [CalibrationDataReader]:我們剛剛創建的校準數據讀取類
  • quant_format [enum]:量化的格式,對于我們的實例中,使用QDQ(Quantize => Dequantize),即顯示量化和反量化格式,因為我們不希望自己手動去算量化,對吧?事實證明使用這個模式的情況ONNXRuntime會自動幫你料理 s c a l e scale scale和零值點的計算,以及后續的反量化等。
  • activation_type [enum]:指定模型內部相關的激活函數使用什么數據類型來完成計算,在我們的例子中,QINT8相對合適,因為Wav2Vec是從音頻中來提取特征表述,因此有符號比無符號效果會好很多。
  • weight_type [enum]:指定模型的權重是以什么數據類型來保存的,通常來說,如果你使用的是quantize_dynamic時,ONNXRuntime為了考慮兼容性,默認只會為你量化權重,而不會去管激活函數的量化。
  • calibrate_method [enum]:校準方法,指定在反量化階段以什么方式來完成數據校準,ONNXRuntime支持下述的校準方式:
    • MinMax:極大極小值,這種校準方式適合基于特征表述的神經網絡,如視覺模型,向量機
    • Entropy:基于熵,這種校準方式更適合于不確定性量化,即模型復雜度高,無法直接觀測模型內部數據變化的神經網絡,例如Transformer。適合處理高維度數據,對于我們這次示例中的Hubert十分有效,因為Hubert最終輸出的特征向量大小是 ( b × n × 768 ) (b \times n \times 768) b×n×768
    • Percentile:基于百分位的數據校準模式,可以顯著降低因量化產生的干擾值,但缺點就是容易***一刀切***,進而丟失數據
    • Distribution:基于分布的數據校準模式,當你看到***分布***這兩字兒,你大概心里也應該有個譜了:沒錯,它是基于數據在FP32狀態下的分布狀態來進行對應比例的縮放校準的,而這也正是它的問題所在,即每進行一次校準時都有參考來FP32狀態下的數據分布從而計算出INT8下可能的數據分布,因此對于時延要求不大的任務:如Diffusion可以用這類校準。

接下來我們就可以調用quantize_static()來執行靜態量化了:

quantize_static(model_input="./hubert.onnx",model_output="./hubert_int8.onnx",calibration_data_reader=reader,quant_format=QuantFormat.QDQ,activation_type=QuantType.QInt8,weight_type=QuantType.QInt8,calibrate_method=CalibrationMethod.Entropy
)

之后你會看到這樣的日志:

Collecting tensor data and making histogram ...
Finding optimal threshold for each tensor using 'entropy' algorithm ...
Number of tensors : 712
Number of histogram bins : 128 (The number may increase depends on the data it collects)
Number of quantized bins : 128

這在說明ONNXRuntime正在計算每個張量的最佳閾值和分布大小。

驗證量化

接下來我們就可以正常讀取這些模型寫模型來看看不用位數下的輸出精度了:

import onnxruntime as ort
# 加載FP32模型
model_fp32 = ort.InferenceSession("./hubert.onnx")
# 加載FP16模型
model_fp16 = ort.InferenceSession("./hubert_fp16.onnx")
# 加載INT8模型
model_int8 = ort.InferenceSession("./hubert_int8.onnx")
# 預測FP32
fp32_result = model_fp32.run(None,input_feed={"feats": audio.astype(np.float32)}
)
# 預測FP16
fp16_result = model_fp16.run(None,input_feed={"feats": audio.astype(np.float16)}
)
# 預測INT8
int8_result = model_int8.run(None,input_feed={"feats": audio.astype(np.float32)}
)# 繪制圖像
import matplotlib.pyplot as plt
fig, ax = plt.subplots(3, 1, figsize=(8,6))ax[0].plot(fp32_result[0][0, 0, :], label="FP32")
ax[0].set_title("FP32 Output")ax[1].plot(fp16_result[0][0, 0, :], label="FP16")
ax[1].set_title("FP16 Output")ax[2].plot(int8_result[0][0, 0, :], label="INT8")
ax[2].set_title("INT8 Output")for a in ax:a.legend()a.grid()plt.tight_layout()
plt.show()

輸出圖像如下:
請添加圖片描述
從圖像也可以很明顯的看出來:INT8的數據分布會更發散,雖然ONNXRuntime已經幫我們完成了反量化這一步驟。而FP16相比INT8則好看許多,雖然在浮點上位上少了很多表示位,但精度依然還是在線的,這也是量化時要權衡的問題:速度和精度,哪個對你的場景更重要?

結語

量化是一把雙刃劍,雖然可以對比原來的推理環境實現大幅度的性能提升,但速度提升的代價就是精度的明顯下降,因此在執行量化操作一定要權衡利弊,是否量化真的對你的場景真的很重要?你的任務是否真的很依賴那點兒因為降低精度而換回來的速度?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/85554.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/85554.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/85554.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【量子計算】格羅弗算法

文章目錄 🔍 一、算法原理與工作機制? 二、性能優勢:二次加速的體現🌐 三、應用場景?? 四、局限性與挑戰🔮 五、未來展望💎 總結 格羅弗算法(Grover’s algorithm)是量子計算領域的核心算法之…

C++ 互斥量

在 C 中,互斥量(std::mutex)是一種用于多線程編程中保護共享資源的機制,防止多個線程同時訪問某個資源,從而避免數據競爭(data race)和不一致的問題。 🔒 一、基礎用法:s…

CSS Content符號編碼大全

資源寶整理分享:?https://www.httple.net? 前端開發中常用的特殊符號查詢工具,包含Unicode編碼和HTML實體編碼,方便開發者快速查找和使用各種符號。支持基本形狀、箭頭、數學符號、貨幣符號等多種分類。 前端最常用符號 圖標形狀十進制十…

RPC常見問題回答

項目流程和架構設計 1.服務端的功能: 1.提供rpc調用對應的函數 2.完成服務注冊 服務發現 上線/下線通知 3.提供主題的操作 (創建/刪除/訂閱/取消訂閱) 消息的發布 2.服務的模塊劃分 1.網絡通信模塊 net 底層套用的moude庫 2.應用層通信協議模塊 1.序列化 反序列化數…

【JavaEE】(3) 多線程2

一、常見的鎖策略 1、樂觀鎖和悲觀鎖 悲觀鎖:預測鎖沖突的概率較高。在鎖中加阻塞操作。樂觀鎖:預測鎖沖突的概率較低。使用忙等/版本號等,不產生阻塞。 2、輕量級鎖和重量級鎖 重量級鎖:加鎖的開銷較大,線程等待鎖…

創客匠人服務體系解析:知識 IP 變現的全鏈路賦能模型

在知識服務行業深度轉型期,創客匠人通過 “工具 陪跑 圈層” 的三維服務體系,構建了從 IP 定位到商業變現的完整賦能鏈條。這套經過 5 萬 知識博主驗證的模型,不僅解決了 “內容生產 - 流量獲取 - 用戶轉化” 的實操難題,更推動…

國產ARM/RISCV與OpenHarmony物聯網項目(六)SF1節點開發

一、終端節點功能設計 1. 功能說明 終端節點設計的是基于鴻蒙操作系統的 TCP 服務器程序,用于監測空氣質量并提供遠程控制功能。與之前的光照監測程序相比,這個程序使用 E53_SF1 模塊(煙霧 / 氣體傳感器),主要功能包…

Plotly圖表全面使用指南 -- Displaying Figures in Python

文中內容僅限技術學習與代碼實踐參考,市場存在不確定性,技術分析需謹慎驗證,不構成任何投資建議。 在 Python 中顯示圖形 使用 Plotly 的 Python 圖形庫顯示圖形。 顯示圖形 Plotly的Python圖形庫plotly.py提供了多種顯示圖形的選項和方法…

getx用法詳細解析以及注意事項

源碼地址 在 Flutter 中,Get 是來自 get 包的一個輕量級、功能強大的狀態管理與路由框架,常用于: 狀態管理路由管理依賴注入(DI)Snackbar / Dialog / BottomSheet 管理本地化(多語言) 下面是 …

深度學習:人工神經網絡基礎概念

本文目錄: 一、什么是神經網絡二、如何構建神經網絡三、神經網絡內部狀態值和激活值 一、什么是神經網絡 人工神經網絡(Artificial Neural Network, 簡寫為ANN)也簡稱為神經網絡(NN),是一種模仿…

Unity2D 街機風太空射擊游戲 學習記錄 #12環射道具的引入

概述 這是一款基于Unity引擎開發的2D街機風太空射擊游戲,筆者并不是游戲開發人,作者是siki學院的涼鞋老師。 筆者只是學習項目,記錄學習,同時也想幫助他人更好的學習這個項目 作者會記錄學習這一期用到的知識,和一些…

網站如何啟用HTTPS訪問?本地內網部署的https網站怎么在外網打開?

在互聯網的世界里,數據安全已經成為了每個網站和用戶都不得不面對的問題。近期,網絡信息泄露事件頻發,讓越來越多的網站開始重視起用戶數據的安全性,因此啟用HTTPS訪問成為了一個熱門話題。作為一名網絡安全專家,我希望…

計算機網絡-----詳解網絡原理TCP/IP(上)

文章目錄 📕1. UDP協議??1.1 UDP的特點??1.2 基于UDP的應用層協議 📕2. TCP協議??2.1 TCP協議段格式??2.2 TCP協議特點之確認應答??2.3 TCP協議特點之超時重傳??2.4 TCP協議特點之連接管理??2.5 TCP協議特點之滑動窗口??2.6 TCP協議特點…

Lora訓練

一種大模型高效訓練方式&#xff08;PEFT&#xff09; 目標&#xff1a; 訓練有限的ΔW&#xff08;權重更新矩陣&#xff09; ΔW為低秩矩陣→ΔWAB&#xff08;其中A的大小為dr, B的大小為rk&#xff0c;且r<<min(d,k)&#xff09;→ 原本要更新的dk參數量大幅度縮減…

藍牙 5.0 新特性全解析:傳輸距離與速度提升的底層邏輯(面試寶典版)

藍牙技術自 1994 年誕生以來,已經經歷了多次重大升級。作為當前主流的無線通信標準之一,藍牙 5.0 在 2016 年發布后,憑借其顯著的性能提升成為了物聯網(IoT)、智能家居、可穿戴設備等領域的核心技術。本文將深入解析藍牙 5.0 在傳輸距離和速度上的底層技術邏輯,并結合面試…

Minio使用https自簽證書

自簽證書參考&#xff1a;window和ubuntu自簽證書_windows 自簽證書-CSDN博客 // certFilePath: 直接放在 resources 目錄下 或者可以自定實現讀取邏輯 // 讀取的是 .crt 證書文件public static OkHttpClient createTrustingOkHttpClient(String certFilePath) throws Excep…

汽車前縱梁焊接總成與沖壓件的高效自動化三維檢測方案

汽車主體結構件上存在很多安裝位&#xff0c;為保證汽車裝配時的準確性&#xff0c;主體結構件需要進行全方位的尺寸和孔位置精度檢測&#xff0c;以確保裝配線的主體結構件質量合格。 前縱梁焊接總成是車身框架的核心承載部件&#xff0c;焊接總成由多片鈑金沖壓件焊接組成&a…

F接口基礎.go

前言&#xff1a;接口是一組方法的集合&#xff0c;它定義了一個類型應該具備哪些行為&#xff0c;但不關心具體怎么實現這些行為。一個類型只要實現了接口中定義的所有方法&#xff0c;那么它就實現了這個接口。這種實現是隱式的&#xff0c;不需要顯式聲明。 目錄 接口的定…

cartographer官方指導文件說明---第3章 cartographer前端算法流程介紹

cartographer官方指導文件說明 第3章 cartographer前端算法流程介紹 3.1 Scan Match掃描匹配 掃描匹配&#xff08;Scan Matching&#xff09;是 Cartographer 中實現局部SLAM的核心技術&#xff0c;它通過優化算法將當前激光掃描數據對齊到子圖地圖中。下面從計算過程、數學…

汽車整車廠如何用數字孿生系統打造“透明車間”

隨著工業4.0時代的發展&#xff0c;數字孿生技術已成為現代制造業的重要利器。特別是在汽車整車廠&#xff0c;通過數字孿生系統的應用&#xff0c;能夠有效打造一個“透明車間”&#xff0c;實現生產過程的全面可視化與實時監控&#xff0c;提高生產效率&#xff0c;降低成本&…