【TVM 教程】使用 TVM 部署框架預量化模型

本文介紹如何將深度學習框架量化的模型加載到 TVM。預量化模型的導入是 TVM 中支持的量化之一。有關 TVM 中量化的更多信息,參閱?此處。

這里演示了如何加載和運行由 PyTorch、MXNet 和 TFLite 量化的模型。加載后,可以在任何 TVM 支持的硬件上運行編譯后的量化模型。

首先,導入必要的包:

from PIL import Image
import numpy as np
import torch
from torchvision.models.quantization import mobilenet as qmobilenetimport tvm
from tvm import relay
from tvm.contrib.download import download_testdata

定義運行 demo 的輔助函數:

def get_transform():import torchvision.transforms as transformsnormalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])return transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),normalize,])def get_real_image(im_height, im_width):img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"img_path = download_testdata(img_url, "cat.png", module="data")return Image.open(img_path).resize((im_height, im_width))def get_imagenet_input():im = get_real_image(224, 224)preprocess = get_transform()pt_tensor = preprocess(im)return np.expand_dims(pt_tensor.numpy(), 0)def get_synset():synset_url = "".join(["https://gist.githubusercontent.com/zhreshold/","4d0b62f3d01426887599d4f7ede23ee5/raw/","596b27d23537e5a1b5751d2b0481ef172f58b539/","imagenet1000_clsid_to_human.txt",])synset_name = "imagenet1000_clsid_to_human.txt"synset_path = download_testdata(synset_url, synset_name, module="data")with open(synset_path) as f:return eval(f.read())def run_tvm_model(mod, params, input_name, inp, target="llvm"):with tvm.transform.PassContext(opt_level=3):lib = relay.build(mod, target=target, params=params)runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.device(target, 0)))runtime.set_input(input_name, inp)runtime.run()return runtime.get_output(0).numpy(), runtime

從標簽到類名的映射,驗證模型的輸出是否合理:

synset = get_synset()

用貓的圖像進行演示:

inp = get_imagenet_input()

部署量化的 PyTorch 模型?

首先演示如何用 PyTorch 前端加載由 PyTorch 量化的深度學習模型。

參考?PyTorch 靜態量化教程,了解量化的工作流程。

用下面的函數來量化 PyTorch 模型。此函數采用浮點模型,并將其轉換為 uint8。這個模型是按通道量化的。

def quantize_model(model, inp):model.fuse_model()model.qconfig = torch.quantization.get_default_qconfig("fbgemm")torch.quantization.prepare(model, inplace=True)# Dummy calibrationmodel(inp)torch.quantization.convert(model, inplace=True)

從 torchvision 加載預量化、預訓練的 Mobilenet v2 模型?

之所以選擇 mobilenet v2,是因為該模型接受了量化感知訓練,而其他模型則需要完整的訓練后校準。

qmodel = qmobilenet.mobilenet_v2(pretrained=True).eval()

輸出結果:

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /workspace/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth0%|          | 0.00/13.6M [00:00<?, ?B/s]44%|####4     | 6.03M/13.6M [00:00<00:00, 63.2MB/s]89%|########8 | 12.1M/13.6M [00:00<00:00, 61.4MB/s]
100%|##########| 13.6M/13.6M [00:00<00:00, 66.0MB/s]

量化、跟蹤和運行 PyTorch Mobilenet v2 模型?

量化和 jit 的詳細信息可參考 PyTorch 網站上的教程。

pt_inp = torch.from_numpy(inp)
quantize_model(qmodel, pt_inp)
script_module = torch.jit.trace(qmodel, pt_inp).eval()with torch.no_grad():pt_result = script_module(pt_inp).numpy()

輸出結果:

/usr/local/lib/python3.7/dist-packages/torch/ao/quantization/observer.py:179: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.reduce_range will be deprecated in a future release of PyTorch."
/usr/local/lib/python3.7/dist-packages/torch/ao/quantization/observer.py:1126: UserWarning: must run observer before calling calculate_qparams.                                    Returning default scale and zero pointReturning default scale and zero point "

使用 PyTorch 前端將量化的 Mobilenet v2 轉換為 Relay-QNN?

PyTorch 前端支持將量化的 PyTorch 模型,轉換為具有量化感知算子的等效 Relay 模塊。將此表示稱為 Relay QNN dialect。

若要查看量化模型是如何表示的,可以從前端打印輸出。

可以看到特定于量化的算子,例如 qnn.quantize、qnn.dequantize、qnn.requantize 和 qnn.conv2d 等。

input_name = "input"  # 對于 PyTorch 前端,輸入名稱可以是任意的。
input_shapes = [(input_name, (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
# print(mod) # 打印查看 QNN IR 轉儲

編譯并運行 Relay 模塊?

獲得量化的 Relay 模塊后,剩下的工作流程與運行浮點模型相同。詳細信息請參閱其他教程。

在底層,量化特定的算子在編譯之前,會被降級為一系列標準 Relay 算子。

target = "llvm"
tvm_result, rt_mod = run_tvm_model(mod, params, input_name, inp, target=target)

輸出結果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead."target_host parameter is going to be deprecated. "

比較輸出標簽?

可看到打印出相同的標簽。

pt_top3_labels = np.argsort(pt_result[0])[::-1][:3]
tvm_top3_labels = np.argsort(tvm_result[0])[::-1][:3]print("PyTorch top3 labels:", [synset[label] for label in pt_top3_labels])
print("TVM top3 labels:", [synset[label] for label in tvm_top3_labels])

輸出結果:

PyTorch top3 labels: ['tiger cat', 'Egyptian cat', 'tabby, tabby cat']
TVM top3 labels: ['tiger cat', 'Egyptian cat', 'tabby, tabby cat']

但由于數字的差異,通常原始浮點輸出不應該是相同的。下面打印 mobilenet v2 的 1000 個輸出中,有多少個浮點輸出值是相同的。

print("%d in 1000 raw floating outputs identical." % np.sum(tvm_result[0] == pt_result[0]))

輸出結果:

154 in 1000 raw floating outputs identical.

測試性能?

以下舉例說明如何測試 TVM 編譯模型的性能。

n_repeat = 100  # 為使測試更準確,應選取更大的數值
dev = tvm.cpu(0)
print(rt_mod.benchmark(dev, number=1, repeat=n_repeat))

輸出結果:

Execution time summary:mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)90.3752      90.2667      94.6845      90.0629       0.6087

備注

推薦這種方法的原因如下:

  • 測試是在 C++ 中完成的,因此沒有 Python 開銷大
  • 包括幾個準備工作
  • 可用相同的方法在遠程設備(Android 等)上進行分析。

備注

如果硬件對?INT8 整數的指令沒有特殊支持,量化模型與 FP32 模型速度相近。如果沒有?INT8 整數的指令,TVM 會以 16 位進行量化卷積,即使模型本身是 8 位。

對于 x86,在具有 AVX512 指令集的 CPU 上可實現最佳性能。這種情況 TVM 對給定 target 使用最快的可用 8 位指令,包括對 VNNI 8 位點積指令(CascadeLake 或更新版本)的支持。

此外,以下一般技巧對 CPU 性能的提升同樣適用:

  • 將環境變量 TVM_NUM_THREADS 設置為物理 core 的數量
  • 為硬件選擇最佳 target,例如 “llvm -mcpu=skylake-avx512” 或 “llvm -mcpu=cascadelake”(未來會有更多支持 AVX512 的 CPU)

部署量化的 MXNet 模型?

待更新

部署量化的 TFLite 模型?

待更新

腳本總運行時長: (1 分 7.374 秒)

下載 Python 源代碼:deploy_prequantized.py

下載 Jupyter Notebook:deploy_prequantized.ipynb

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:
http://www.pswp.cn/web/44330.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/44330.shtml
英文地址,請注明出處:http://en.pswp.cn/web/44330.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Linux】常見指令收官權限理解

tar指令 上一篇博客已經介紹了zip/unzip指令&#xff0c;接下來我們來看一下另一個關于壓縮和解壓的指令&#xff1a;tar指令tar指令&#xff1a;打包/解包&#xff0c;不打開它&#xff0c;直接看內容 關于tar的指令有太多了&#xff1a; tar [-cxtzjvf] 文件與目錄 ...…

C++運行時類型識別

目錄 C運行時類型識別A.What&#xff08;什么是運行時類型識別RTTI&#xff09;B.Why&#xff08;為什么需要RTTI&#xff09;C.dynamic_cast運算符Why&#xff08;dynamic_cast運算符的作用&#xff09;How&#xff08;如何使用dynamic_cast運算符&#xff09; D.typeid運算符…

【Scrapy】 Scrapy 爬蟲框架

準我快樂地重飾演某段美麗故事主人 飾演你舊年共尋夢的戀人 再去做沒流著情淚的伊人 假裝再有從前演過的戲份 重飾演某段美麗故事主人 飾演你舊年共尋夢的戀人 你縱是未明白仍夜深一人 穿起你那無言毛衣當跟你接近 &#x1f3b5; 陳慧嫻《傻女》 Scrapy 是…

各地戶外分散視頻監控點位,如何實現遠程集中實時監看?

公司業務涉及視頻監控項目承包搭建&#xff0c;此前某個項目需求是為某林業公司提供視頻監控解決方案&#xff0c;需要實現各地視頻攝像頭的集中實時監看&#xff0c;以防止國家儲備林的盜砍、盜伐行為。 公司原計劃采用運營商專線連接各個視頻監控點位&#xff0c;實現遠程視…

跟著李沐學AI:線性回歸

引入 買房出價需要對房價進行預測。 假設1&#xff1a;影響房價的關鍵因素是臥室個數、衛生間個數和居住面積&#xff0c;記為x1、x2、x3。 假設2&#xff1a;成交價是關鍵因素的加權和 。權重和偏差的實際值在后面決定。 拓展至一般線性模型&#xff1a; 給定n維輸入&…

MySQL 9.0 正式發行Innovation創新版已支持向量

從 MySQL 8.1 開始&#xff0c;官方啟用了新的版本模型&#xff1a;MySQL 創新版 (Innovation) 和長期支持版 (LTS)。 根據介紹&#xff0c;兩者的質量都已達到可用于生產環境級別。區別在于&#xff1a; 如果希望嘗試最新的功能和改進&#xff0c;并喜歡與最新技術保持同步&am…

怎樣在 C 語言中實現棧?

&#x1f345;關注博主&#x1f397;? 帶你暢游技術世界&#xff0c;不錯過每一次成長機會&#xff01; &#x1f4d9;C 語言百萬年薪修煉課程 通俗易懂&#xff0c;深入淺出&#xff0c;匠心打磨&#xff0c;死磕細節&#xff0c;6年迭代&#xff0c;看過的人都說好。 文章目…

動手學深度學習(Pytorch版)代碼實踐 -循環神經網絡-55循環神經網絡的從零開始實現和簡潔實現

55循環神經網絡的實現 1.從零開始實現 import math import torch from torch import nn from torch.nn import functional as F from d2l import torch as d2l import matplotlib.pyplot as plt import liliPytorch as lp# 讀取H.G.Wells的時光機器數據集 batch_size, num_ste…

開發個人Ollama-Chat--7 服務部署

開發個人Ollama-Chat–7 服務部署 服務部署 go-ChatGPT項目涉及的中間件服務較多&#xff0c;以下部署文件目錄&#xff1a; |-- chat-api | |-- etc | | -- config.yaml | -- logs |-- chat-rpc | |-- etc | | -- config.yaml | -- logs |-- docker-compos…

ElasticSearch第一天

學習目標&#xff1a; 能夠理解ElasticSearch的作用能夠安裝ElasticSearch服務能夠理解ElasticSearch的相關概念能夠使用Postman發送Restful請求操作ElasticSearch能夠理解分詞器的作用能夠使用ElasticSearch集成IK分詞器能夠完成es集群搭建 第一章 ElasticSearch簡介 1.1 什么…

windows 中的 Nsight Systems 通過ssh 鏈接分析 Linux 中的cuda程序性能

1&#xff0c;Linux 環境 安裝 ssh-server $ sudo apt install openssh-server 安裝較新版本的 cuda sdk 下載cuda-samples github repo 編輯修改 ssh 配置&#xff1a; $ sudo vim /etc/ssh/sshd_config 刪除相關注釋&#xff0c;修改后如下&#xff1a; Port 22 Addres…

只會vue的前端開發工程師是不是不能活了?最近被一個flutter叼了

**Vue與Flutter&#xff1a;前端開發的新篇章** 在前端開發的世界里&#xff0c;Vue.js和Flutter無疑是兩顆璀璨的明星。Vue以其輕量級、易上手的特點吸引了大量前端開發者的青睞&#xff0c;而Flutter則以其跨平臺、高性能的優勢迅速崛起。那么&#xff0c;對于只會Vue的前端…

【深度學習基礎】環境搭建 linux系統下安裝pytorch

目錄 一、anaconda 安裝二、創建pytorch1. 創建pytorch環境&#xff1a;2. 激活環境3. 下載安裝pytorch包4. 檢查是否安裝成功 一、anaconda 安裝 具體的安裝說明可以參考我的另外一篇文章【環境搭建】Linux報錯bash: conda: command not found… 二、創建pytorch 1. 創建py…

OceanBase:引領下一代分布式數據庫技術的前沿

OceanBase的基本概念 定義和特點 OceanBase是一款由螞蟻金服開發的分布式關系數據庫系統&#xff0c;旨在提供高性能、高可用性和強一致性的數據庫服務。它結合了關系數據庫和分布式系統的優勢&#xff0c;適用于大規模數據處理和高并發業務場景。其核心特點包括&#xff1a; …

【考研數學】25張宇強化36講測評及強化階段注意事項

張宇新版36講創新真的很大&#x1f979; 引入了很多張宇老師認為對大家解題幫助很大的技巧和知識點&#xff0c;但是也有人認為是多余的。 張宇老師新版36講第一講就講了整整8個小時&#xff01;&#x1f62d; 大家想想&#xff0c;自己有那個時間去吃透36講嗎&#xff1f;如果…

python調用阿里云匯率接口

整體請求流程 介紹&#xff1a; 本次解析通過阿里云云市場的云服務來實現程序中對貨幣匯率實時監控&#xff0c;首先需要準備選擇一家可以提供匯率查詢的商品。 https://market.aliyun.com/apimarket/detail/cmapi00065831#skuyuncode5983100001 步驟1: 選擇商品 如圖點擊…

debian 12 Install

debian 前言 Debian是一個基于Linux內核的自由和開放源代碼操作系統&#xff0c;由全球志愿者組成的Debian項目維護和開發。該項目始于1993年&#xff0c;由Ian Murdock發起&#xff0c;旨在創建一個完整的、基于Linux的自由軟件操作系統。 debian download debian 百度網盤…

分布式應用系統設計:即時消息系統

即時消息(IM)系統&#xff0c;涉及&#xff1a;站內消息系統 組件如下&#xff1b; 客戶端&#xff1a; WEB頁面&#xff0c;IM桌面客戶端。通過WebSocket 跟ChatService后端服務連接 Chat Service&#xff1a; 提供WebSocket接口&#xff0c;并保持跟“客戶端”狀態的維護。…

會聲會影分割音頻怎么不能用 會聲會影分割音頻方法 會聲會影視頻制作教程 會聲會影下載免費中文版2023

將素材中的音頻分割出來&#xff0c;對聲音部分進行單獨編輯&#xff0c;是剪輯過程中的常用操作。會聲會影視頻剪輯軟件在分割音頻后&#xff0c;還可以對聲音素材進行混音編輯、音頻調節、添加音頻濾鏡等操作。有關會聲會影分割音頻怎么不能用&#xff0c;會聲會影分割音頻方…

如何快速制作您的數據可視化大屏?

數據大屏可視化主要就是借助圖形&#xff0c;利用生動、直觀的形式展示出數據信息的具體數值&#xff0c;使得使用者短時間內更加直觀的接受到大量信息。數據大屏以直觀、高度視覺沖擊力的方式向受眾揭示數據背后隱藏的規律&#xff0c;傳達數據價值。其以圖形化的形式呈現數據…