PyTorch導出JIT模型并用C++ API libtorch調用

PyTorch導出JIT模型并用C++ API libtorch調用

本文將介紹如何將一個 PyTorch 模型導出為 JIT 模型并用 PyTorch 的 C++API libtorch運行這個模型。

Step1:導出模型

首先我們進行第一步,用 Python API 來導出模型,由于本文的重點是在后面的部署階段,因此,模型的訓練就不進行了,直接對 torchvision 中自帶的 ResNet50 進行導出。在實際應用中,大家可以對自己訓練好的模型進行導出。

# export_jit_model.py
import torch
import torchvision.models as modelsmodel = models.resnet50(pretrained=True)
model.eval()example_input = torch.rand(1, 3, 224, 224)jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')

導出 JIT 模型的方式有兩種:trace 和 script。

我們采用 torch.jit.trace 的方式來導出 JIT 模型,這種方式會根據一個輸入將模型跑一遍,然后記錄下執行過程。這種方式的問題在于對于有分支判斷的模型不能很好的應對,因為一個輸入不能覆蓋到所有的分支。但是在我們 ResNet50 模型中不會遇到分支判斷,因此這里是合適的。關于兩種導出 JIT 模型的方式各自優劣不是本文的中斷,以后會再寫一篇來分析。

在我們的工程目錄 demo 下運行上面的 export_jit_model.py ,會得到一個 JIT 模型件:resnet50_jit.pth

Step 2:安裝libtorch

接下來我們要安裝 PyTorch 的 C++ API:libtorch。這一步很簡單,直接下載官方預編譯的文件并解壓即可:

wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip

也解壓在我們的工程目錄 demo 下即可。

Step 3:安裝OpenCV

用 Python 或 C++ 做圖像任務,OpenCV 是經常用到的。如果還沒有安裝的讀者可以參考如下在工程目錄 demo 下進行安裝,構建的過程可能會比較久。已經安裝的讀者可跳過此步驟,一會兒在 CMakeLists.txt 文件中正確地指定本機的 OpenCV 地址即可。

git clone --branch 3.4 --depth 1 https://github.com/opencv/opencv.git
mkdir demo/build && cd demo/build
cmake ..
make -j 6

Step 4:準備測試圖像并用Python測試

我們先準備一張小貓的圖像,并用 PyTorch ResNet50 模型正常跑一下,一會兒與我們 C++ 模型運行的結果對比來驗證 C++ 模型是否被正確的部署。

kitten.jpg

在這里插入圖片描述

寫一個腳本用 PyTorch 運行一下模型:

# pytorch_test.pyimport torchvision.models as models
from torchvision.transforms import transforms
import torch
from PIL import Image# normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
all_transforms = transforms.Compose([transforms.Resize(224),transforms.ToTensor()])# normalize])model = models.resnet50(pretrained=True)
model.eval()img = Image.open('kitten.jpg').convert('RGB')
img_tensor = all_transforms(img).unsqueeze(dim=0)
pred = model(img_tensor).squeeze(dim=0)
print(torch.argmax(pred).item())

輸出結果是:282。通過查看ImageNet 1K 類別名與索引的對應關系,可以看到,結果為 tiger cat,模型預測正確。一會兒我們看一下部署后的 C++ 模型是否能正確輸出結果 282。

Step 5:準備cpp源文件

我們下面準備一會要執行的 cpp 源文件,第一次使用 libtorch 的讀者可以先借鑒下面的文件。

這里有幾個點要說一下,不注意可能會犯錯:

  1. cv::imread() 默認讀取為三通道BGR,需要進行B/R通道交換,這里采用 cv::cvtColor()實現。

  2. 圖像尺寸需要調整到 224×224224\times 224224×224,通過 cv::resize() 實現。

  3. opencv讀取的圖像矩陣存儲形式:H x W x C, 但是pytorch中 Tensor的存儲為:N x C x H x W, 因此需要進行變換,就是np.transpose()操作,這里使用tensor.permut()實現,效果是一樣的。

  4. 數據歸一化,采用 tensor.div(255) 實現。

// test_model.cpp
#include <vector>#include <torch/torch.h>
#include <torch/script.h>#include <opencv2/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>int main(int argc, char* argv[]) {// 加載JIT模型auto module = torch::jit::load(argv[1]);// 加載圖像auto image = cv::imread(argv[2], cv::ImreadModes::IMREAD_COLOR);cv::Mat image_transfomed;cv::resize(image, image_transfomed, cv::Size(224, 224));cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);// 圖像轉換為Tensortorch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols, 3},torch::kByte);tensor_image = tensor_image.permute({2, 0, 1});// tensor_image = tensor_image.toType(torch::kFloat);tensor_image = tensor_image.div(255.);// tensor_image = tensor_image.sub(0.5);// tensor_image = tensor_image.div(0.5);tensor_image = tensor_image.unsqueeze(0);// 運行模型torch::Tensor output = module.forward({tensor_image}).toTensor();// 結果處理int result = output.argmax().item<int>();std::cout << "The classifiction index is: " << result << std::endl;return 0;
}

Step 6:構建運行驗證

我們先來寫一下 CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(resnet50)find_package(Torch REQUIRED PATHS ./libtorch)
find_package(OpenCV REQUIRED)add_executable(resnet50  test_model.cpp)
target_link_libraries(resnet50 "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")set_property(TARGET resnet50  PROPERTY CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

現在我們的工程目錄 demo 下有以下文件:

CMakeLists.txt  export_jit_model.py  kitten.jpg  libtorch  pytorch_test.py  resnet50_jit.pth  test_model.cpp

然后開始用 CMake 構建工程:

mkdir build && cd build
OpenCV_DIR=[YOUR_PATH_TO_OPENCV]/opencv/build cmake ..
make

整個過程沒有報錯的話我們就已經構建完成了,會得到一個可執行文件 resnet50 在工程目錄 demo 下。

接下來我們執行,并驗證運行結果是否與 PyTorch 的結果一致:

./build/resnet50 resnet50_jit.pth kitten.jpg

輸出:

The classifiction index is: 282

運行成功并且結果正確。

Ref:

https://www.jianshu.com/p/7cddc09ca7a4

https://blog.csdn.net/cxx654/article/details/115916275

https://zhuanlan.zhihu.com/p/370455320

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532594.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532594.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532594.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

sqli-lab--writeup(7~10)文件輸出,時間布爾盲注

前置知識點&#xff1a; 1、outfile是將檢索到的數據&#xff0c;保存到服務器的文件內&#xff1a; 格式&#xff1a;select * into outfile “文件地址” 示例&#xff1a; mysql> select * into outfile ‘f:/mysql/test/one’ from teacher_class; 2、文件是自動創建…

樹莓派4B (aarch64) 安裝PyTorch 1.8 的可行方案

樹莓派4B (aarch64) 安裝PyTorch 1.8 的可行方案 最終可行方案 試了一堆方案&#xff08;源碼編譯、Fast.ai的安裝文件等&#xff09;之后&#xff0c;終于找到一個可行的方案。是在 PyTorch 官方討論社區的一個帖子中找到的&#xff0c;在回復中一個大佬給出了自己在2021年1…

sqli-lab———writeup(11~17)

less11 用戶名提交單引號顯示sql語法錯誤&#xff0c;故存在sql注入 根據單引號報錯&#xff0c;在用戶名和密碼任意行輸入 萬能密碼&#xff1a;‘ or 11# 輸入后username語句為&#xff1a;SELECT username, password FROM users WHERE username or 11; 雙引號 password語…

深入理解Python中的全局解釋鎖GIL

深入理解Python中的全局解釋鎖GIL 轉自&#xff1a;https://zhuanlan.zhihu.com/p/75780308 注&#xff1a;本文為蝸牛學院資深講師卿淳俊老師原創&#xff0c;首發自公眾號https://mp.weixin.qq.com/s/TBiqbSCsjIbNIk8ATky-tg&#xff0c;如需轉載請私聊我處獲得授權并注明出處…

sqli-lab————Writeup(18~20)各種頭部注入

less18 基于錯誤的用戶代理&#xff0c;頭部POST注入 admin admin 登入成功&#xff08;進不去重置數據庫&#xff09; 顯示如下 有user agent參數&#xff0c;可能存在注入點 顯示版本號&#xff1a; 爆庫&#xff1a;User-Agent:and extractvalue(1,concat(0x7e,(select …

Python GIL

轉自&#xff1a;https://blog.csdn.net/weixin_41594007/article/details/79485847 Python GIL 在進行GIL講解之前&#xff0c;我們可以先回顧一下并行和并發的區別&#xff1a; 并行&#xff1a;多個CPU同時執行多個任務&#xff0c;就好像有兩個程序&#xff0c;這兩個程序…

sqli-lab——Writeup21~38(各種過濾繞過WAF和)

Less-21 Cookie Injection- Error Based- complex - string ( 基于錯誤的復雜的字符型Cookie注入) base64編碼&#xff0c;單引號&#xff0c;報錯型&#xff0c;cookie型注入。 本關和less-20相似&#xff0c;只是cookie的uname值經過base64編碼了。 登錄后頁面&#xff1a;…

Libtorch報錯:terminate called after throwing an instance of ‘c10::Error‘ what():isTensor()INTERNAL ASS

Libtorch報錯&#xff1a;terminate called after throwing an instance of ‘c10::Error’ what(): isTensor() INTERNAL ASSERT FAILED 報錯 問題出現在筆者想要將 yolov5 通過 PyTorch 的 C 接口 Libtorch 部署到樹莓派上。 完整報錯信息&#xff1a; terminate called …

sqli-lab——Writeup(38~over)堆疊等......

知識點&#xff1a; 1.堆疊注入原理&#xff08;stacked injection&#xff09; 在SQL中&#xff0c;分號&#xff08;;&#xff09;是用來表示一條sql語句的結束。試想一下我們在 ; 結束一個sql語句后繼續構造下一條語句&#xff0c;會不會一起執行&#xff1f;因此這個想法…

mysql常規使用(建立,增刪改查,視圖索引)

目錄 1.數據庫建立 2.增刪改查 3.視圖建立&#xff1a; 1.數據庫建立 mysql> mysql> show databases; ----------------------------------- | Database | ----------------------------------- | information_schema | | ch…

php操作mysql數據庫

phpmyadmin phpadmin是一個mysql圖形化管理工具&#xff0c;是一款實用php開發的mysql苦戶端軟件&#xff0c;基于web跨平臺的管理系統&#xff0c;支持簡體中文&#xff0c;官網&#xff1a;www.phpmyadmin.net可以下載免費最新版。提供圖形化操作界面&#xff0c;完成對mysq…

C:C++ 函數返回多個參數

C/C 函數返回多個參數 轉自&#xff1a;https://blog.csdn.net/onlyou2030/article/details/48174461 筆者是 Python 入門的&#xff0c;一直很困惑 C/C 中函數如何返回多個參數。 如果一個函數需要返回多個參數&#xff0c;可以采用以下兩種方法&#xff1a; 傳引用或指針…

sql預編譯

一.數據庫預編譯起源: 數據庫接受sql語句,需要解析和制定執行,中間需要花費一段時間. 有時候同一語句可能會多次執行, 那么就會造成資源的浪費 如何減少編譯執行的時間 ? 就有了預編譯,預編譯是將這類語句提前用占位符替代,一次編譯,多次執行. 預編譯后的執行代碼會被緩存下來…

C++中智能指針的原理、使用、實現

C中智能指針的原理、使用、實現 轉自&#xff1a;https://www.cnblogs.com/wxquare/p/4759020.html 1 智能指針的作用 C程序設計中使用堆內存是非常頻繁的操作&#xff0c;堆內存的申請和釋放都由程序員自己管理。程序員自己管理堆內存可以提高了程序的效率&#xff0c;但是…

Xctf練習sql注入--supersqli

三種方法 方法一 1 回顯正常 1’回顯不正常,報sql語法錯誤 1’ -- 回顯正常,說明有sql注入點,應該是字符型注入(# 不能用) 1’ order by 3 -- 回顯失敗,說明有2個注入點 1’ union select 1,2 -- 回顯顯示過濾語句: 1’; show databases -- 爆數據庫名 -1’; show tables …

深拷貝與淺拷貝、值語義與引用語義對象語義 ——以C++和Python為例

深拷貝與淺拷貝、值語義與引用語義/對象語義 ——以C和Python為例 值語義與引用語義&#xff08;對象語義&#xff09; 本小節參考自&#xff1a;https://www.cnblogs.com/Solstice/archive/2011/08/16/2141515.html 概念 在任何編程語言中&#xff0c;區分深淺拷貝的關鍵都…

一次打卡軟件的實戰滲透測試

直接打卡抓包, 發現有疑似企業網站,查ip直接顯示以下頁面 直接顯示了后臺安裝界面…就很有意思 探針和phpinfo存在 嘗試連接mysql失敗 fofa掃描為阿里云服務器 找到公司官網使用nmap掃描,存在端口使用onethink 查詢onethink OneThink是一個開源的內容管理框架&#xff0c;…

C++中類的拷貝控制

C中類的拷貝控制 轉自&#xff1a;https://www.cnblogs.com/ronny/p/3734110.html 1&#xff0c;什么是類的拷貝控制 當我們定義一個類的時候&#xff0c;為了讓我們定義的類類型像內置類型&#xff08;char,int,double等&#xff09;一樣好用&#xff0c;我們通常需要考下面…

centos7ubuntu搭建Vulhub靶場(推薦Ubuntu)

這里寫目錄標題一.前言總結二.成功操作&#xff1a;三.出現報錯&#xff1a;四.vulhub使用正文&#xff1a;一.前言總結二.成功操作&#xff1a;三.出現報錯&#xff1a;四.vulhub使用看完點贊關注不迷路!!!! 后續繼續更新優質安全內容!!!!!一.前言總結 二.成功操作&#xff1…

使用 PyTorch 數據讀取,JAX 框架來訓練一個簡單的神經網絡

使用 PyTorch 數據讀取&#xff0c;JAX 框架來訓練一個簡單的神經網絡 本文例程部分主要參考官方文檔。 JAX簡介 JAX 的前身是 Autograd &#xff0c;也就是說 JAX 是 Autograd 升級版本&#xff0c;JAX 可以對 Python 和 NumPy 程序進行自動微分。可以通過 Python的大量特征…