Pytorch的C++接口實踐

Pytorch1.1版本已經提供了相對穩定的c++接口,網上也有了眾多的資料供大家參考,進行c++的接口的初步嘗試。

可以按照對應的選項下載,下面我們要說的是:

如何利用已經編譯好的官方libtorch庫和其他的opencv庫等聯合編寫應用?

其實很簡單,大概的步驟有三步:

第一步:在python環境下將模型導出為jit的模型

第二步:編寫對應的c++ inference 程序。

第三步:直接在VS上(已經成功實驗VS2015,高版本的應該也可以)配置相應的libtorch環境,主要是:

dll路徑:?

PATH=H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\lib%3bD:\opencv\build\x64\vc14\bin%3b$(PATH)? 相應地去修改即可,不需要在PC的path環境下加入libtorch的路徑,而是在這里加更加簡單。

include路徑:

H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\include\torch\csrc\api\include;H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\include;D:\opencv\build\include\opencv2;D:\opencv\build\include\opencv;D:\opencv\build\include;%(AdditionalIncludeDirectories)

主要是加粗線那兩個。

注意一定要去掉SDL的檢查項,否則會出現錯誤警告。

lib路徑:

H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\lib;D:\opencv\build\x64\vc14\lib;%(AdditionalLibraryDirectories)

詳細的工程見:https://download.csdn.net/download/xiamentingtao/11486608

這里我們主要改編自:《Win10+VS2017+PyTorch(libtorch) C++ 基本應用》

主要代碼參考:?https://github.com/zhpmatrix/load-pytorch-model-with-c-

一些 常見的問題:

1. opencv的mat讀入libtorch

根據我的實踐,這里的最佳寫法是:

src = imread(s, cv::IMREAD_COLOR);  //讀圖// 圖像預處理 注意需要和python訓練時的預處理一致
int org_w = src.cols;
int org_h = src.rows;torch::Tensor img_tensor = torch::from_blob(src.data, { org_h, org_w,3 }, torch::kByte); //將cv::Mat轉成tensor,大小為448,448,3
img_tensor = img_tensor.permute({ 2, 0, 1 });  //調換順序變為torch輸入的格式 3,448,448
img_tensor = img_tensor.toType(torch::kFloat32).div_(255);

注意要先將uint8的圖像先讀入,再轉換成float型。

2. Tensor 轉換成cv::Mat

cv::Mat input(img_tensor.size(1), img_tensor.size(2), CV_32FC1, img_tensor.data<float>());

注意這里一定是CV_32FC1而不是CV_32FC3

另外的方式見:https://discuss.pytorch.org/t/convert-torch-tensor-to-cv-mat/42751/2

torch::Tensor out_tensor = module->forward(inputs).toTensor();
assert(out_tensor.device().type() == torch::kCUDA);
out_tensor=out_tensor.squeeze().detach().permute({1,2,0});
out_tensor=out_tensor.mul(255).clamp(0,255).to(torch::kU8);
out_tensor=out_tensor.to(torch::kCPU);
cv::Mat resultImg(512, 512,CV_8UC3);
std::memcpy((void*)resultImg.data,out_tensor.data_ptr(),sizeof(torch::kU8)*out_tensor.numel());

3. model的輸出處理

如果只有一個返回值,可以直接轉tensor:auto outputs = module->forward(inputs).toTensor();如果有多個返回值,需要先轉tuple:auto outputs = module->forward(inputs).toTuple();
torch::Tensor out1 = outputs->elements()[0].toTensor();
torch::Tensor out2 = outputs->elements()[1].toTensor();

4.Tracing fails because of “parameter sharing”?

看這個案例:https://discuss.pytorch.org/t/help-tracing-fails-because-of-parameter-sharing/40324

其中的部分代碼如上,問題就出現在這些畫框的地方,主要是這里初始化重復使用了相同的模塊進行賦值,例如self.encoder與self.conv1。

解決的辦法就是在構造slef.conv1時,對self.encoder[0]加入deepcopy修飾。

即:

from copy import deepcopy
self.conv1 = nn.Sequential(deepcopy(self.encoder[0]),deepcopy(self.relu),deepcopy(self.encoder[2]),deepcopy(self.relu))

參考:https://github.com/pytorch/pytorch/issues/8392#issuecomment-431863763

5. 關于python導出模型的問題

如果訓練的pytorch模型保存在cpu上,想在測試時使用gpu模式,則我們需要設置python端保存模型在gpu上,然后才能c++上使用gpu測試。

主要的方法就是:

    checkpoint = torch.load(model_path, map_location="cuda:0")  #very important# create modelmodel = TheModelClass(*args, **kwargs)model.load_state_dict(checkpoint)model.to(device)model.eval()x = torch.rand(1, 3, 448, 448)x = x.to(device)  # very importanttraced_script_module = torch.jit.trace(model.model, x)traced_script_module.save("**.pt")

然后才能在c++上使用gpu模式,方法為:

    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);module->to(at::kCUDA);assert(module != nullptr);std::cout << "ok\n";// 建立一個輸入,維度為(1,3,224,224),并移動至cudastd::vector<torch::jit::IValue> inputs;inputs.push_back(torch::ones({1, 3, 224, 224}).to(at::kCUDA));// Execute the model and turn its output into a tensor.at::Tensor output = module->forward(inputs).toTensor();

參考:

?

pytorch跨設備保存和加載模型(變量類型(cpu/gpu)不匹配原因之一)

https://pytorch.org/tutorials/beginner/saving_loading_models.html

https://blog.csdn.net/IAMoldpan/article/details/85057238

參考文獻:

1.利用Pytorch的C++前端(libtorch)讀取預訓練權重并進行預測

2.Pytorch的C++端(libtorch)在Windows中的使用

3.?https://pytorch.org/tutorials/advanced/cpp_frontend.html

4.?https://zhpmatrix.github.io/2019/03/01/c++-with-pytorch/

5.?Windows使用C++調用Pytorch1.0模型

6.?用cmake構建基于qt5,opencv,libtorch項目

7.?c++調用pytorch模型并使用GPU進行預測?(較好的例子)

8.?Ptorch 與libTorch 使用過程中問題記錄

9.?c++ load pytorch 的數據轉換

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/258502.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/258502.shtml
英文地址,請注明出處:http://en.pswp.cn/news/258502.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

一次慘痛的裝機經歷

最近不小心把我的聯想一體機電腦系統搞壞了&#xff0c;就不得不重裝系統&#xff0c;之前的系統是win7&#xff0c;于是開始的時候想著直接裝win10&#xff0c;升級一下系統。但是裝的過程中總是卡在了win10的正在準備系統中&#xff0c;進度環不轉了。后來轉了多次都不行&…

unity讓對象作為參數_unity-container – 一個unity容器可以將自身的引用作為構造函數參數傳遞嗎?...

簡短的答案是肯定的。當您使用Resolve方法時&#xff0c;這應該自動傳遞。例如&#xff1a;IUnityContainer container new UnityContainer();var something container.Resolve();另外&#xff0c;如果您想查看&#xff0c;這與Prism(CodePlex)使用的技術相同。更新增加測試&…

KnockoutJS + My97DatePicker

如何將Knockoutjs和其他腳本庫結合使用&#xff1f;這里給出一個Knockoutjs與my97datepicker配合使用的例子&#xff0c;例子中使用了ko的自定義綁定功能&#xff1a; ko.bindingHandlers.my97DatePicker {init: function (element, valueAccessor) {$(element).on(click, fun…

HttpClient v4.5 簡單抓取主頁數據

由于工作原因&#xff0c;需要每隔半小時刷新一些網頁&#xff0c;并查看上面的數據是否有更新。這件事能否自動化進行呢&#xff1f;查找了下Java相關的資料&#xff0c;蹦出一個關鍵詞&#xff1a;HttpClient。 HttpClient是常用Http客戶端庫&#xff0c;相關的資料也不少&am…

matlab局部放大的圖中圖畫法

【親測有效】 在作圖過程中&#xff0c;如果想將局部信息展示出來并且畫在同一張圖中&#xff0c;一般的MATLAB作圖法就比較拙計了&#xff0c;好在MATLAB還是很強大的&#xff0c;當然&#xff0c;除了不能當女朋友之外 .... ╮(╯▽╰)╭ function showdetail()% 在當前的ax…

進入Python世界——Python基礎知識

本文通過實例練習Python基礎語法, python版本2.7 # -*- coding: utf-8 -*- import randomimport re import requests from bs4 import BeautifulSoup# 爬取糗事百科的首頁內容 def qiushibaike():content requests.get(http://www.qiushibaike.com/).contentsoup BeautifulS…

db2 版本發布歷史_數據庫各廠商的發展歷史(2. DB2 of IBM)

如若轉載&#xff0c;請務必注明出處&#xff0c;iihero 2008.9.26于CSDN1973年&#xff0c;IBM研究中心啟動System R項目&#xff0c;為DB2的誕生打下良好基礎。System R 是 IBM 研究部門開發的一種產品&#xff0c;這種原型語言促進了技術的發展并最終在1983年將 DB2 帶到了商…

android---簡單的通訊錄

遺留問題:獲取頭像及其他信息 利用adapter和Cursor來獲取聯系人的姓名和手機號,重在復習之前學過的內容加深自己的理解. 其中需要注意的部分: 1.adapter中的getview的優化問題,用到tag這一屬性 2.onBackPressed()返回方法的重寫,使得程序更加人性化 下面是主要代碼 1.adapte…

win phone 獲取并且處理回車鍵事件

參考自&#xff1a;http://www.cnblogs.com/mohe/archive/2013/03/18/2966540.html 實用場景,比如輸入帳號和密碼啦,輸入搜索關鍵字啦.protected override void OnKeyDown(KeyEventArgs e) {if (e.Key Key.Enter){MessageBox.Show("我是windows phone 回車鍵"); …

【2020年】最新中國科學院大學學位論文寫作規范

最近在完成國科大博士論文寫作的時候&#xff0c;有一些心得體會&#xff0c;特此總結下來&#xff0c;以饗讀者&#xff0c;尤其是可愛的學弟學妹們。需要注意的是&#xff0c; 以下僅僅是我自己的心得而已&#xff0c;僅供參考。 1. 首先推薦大家使用國科大的Latex模板&…

談談Java基礎數據類型

Java的基本數據類型 類型意義取值boolean布爾值true或falsebyte8位有符號整型-128~127short16位有符號整型-pow(2,15)~pow(2,15)-1int32位有符號整型-pow(2,31)~pow(2,31)-1long64位有符號整型-pow(2,63)~pow(2,63)-1float32位浮點數IEEE754標準單精度浮點數double64位浮點數IE…

用fft對信號進行頻譜分析實驗報告_示波器上的頻域分析利器,Spectrum View測試分析...

簡介&#xff1a;【Spectrum View技術文章系列】從基礎篇開始&#xff0c;講述利用示波器上的Spectrum View功能觀測多通道信號頻譜分析正文&#xff1a;示波器和頻譜儀都是電子測試測量中必不可少的測試設備&#xff0c;分別用于觀察信號的時域波形和頻譜。時域波形是信號最原…

DataTable RowFilter 過濾數據

用Rowfilter加入過濾條件 eg&#xff1a; string sql "select Name,Age,Sex from UserInfo"; DataTable dt DataAccess.GetDataTable(sql);//外部方法&#xff08;通過一條查詢語句返回一個DataTable&#xff09; dt.DefaultView.RowFilter "Sex女"; dt…

platform_device與platform_driver

做Linux方面也有三個多月了&#xff0c;對代碼中的有些結構一直不是非常明確&#xff0c;比方platform_device與platform_driver一直分不清關系。在網上搜了下&#xff0c;做個總結。兩者的工作順序是先定義platform_device -> 注冊 platform_device->&#xff0c;再定義…

復盤caffe安裝

最近因之前的服務器上的caffe奔潰了&#xff0c;不得已重新安裝這一古老的深度學習框架&#xff0c;之前也嘗試了好幾次&#xff0c;每次都失敗&#xff0c;這次總算是成功了&#xff0c;因此及時地總結一下。 以下安裝的caffe主要是針對之前虹膜分割和鞏膜分割所需的caffe版本…

HP P2000 RAID-5兩塊盤離線的數據恢復報告

1. 故障描述本案例是HP P2000的存儲vmware exsi虛擬化平臺&#xff0c;由RAID-5由10塊lT硬盤組成&#xff0c;其中6號盤是熱備盤&#xff0c;由于故障導致RAID-5磁盤陣列的兩塊盤掉線&#xff0c;表現為兩塊硬盤亮黃燈。 經用戶方維護人員檢測&#xff0c;故障硬盤應為物理故障…

微智魔盒騙局_微智魔盒官宣

原標題&#xff1a;微智魔盒官宣微智魔盒官方宣傳視頻微達國際集團創建于2011年&#xff0c;是一家堅持創新的集科研、產銷、服務為一體的智能化產業平臺&#xff0c;致力于國際領先的專注人工智能領域的產業投資、項目孵化、教育培訓&#xff0c;并提供終極解決方案。集團創新…

瑞柏匡丞_移動互聯的發展現狀與未來

互聯網作為人類文明史上最偉大、最重要的科技發明之一&#xff0c;發展到今天&#xff0c;用翻天覆地來形容并不過分。而作為傳統互聯網的延伸和演進方向&#xff0c;移動互聯網更是在近兩年得到了迅猛的發展。如今&#xff0c;越來越多的用戶得以通過高速的移動網絡和強大的智…

android 進程間通信數據(一)------parcel的起源

關于parcel&#xff0c;我們先來講講它的“父輩” Serialize。 Serialize 是java提供的一套序列化機制。但是為什么要序列化&#xff0c;怎么序列化&#xff0c;序列化是怎么做到的&#xff0c;我們將在本文探討下。 一&#xff1a;java 中的serialize 關于Serialize這個東東&a…

為什么torch.nn.Linear的表達形式為y=xA^T+b而不是常見的y=Ax+b?

今天看代碼&#xff0c;對比了常見的公式表達與代碼的表達&#xff0c;發覺torch.nn.Linear的數學表達與我想象的有點不同&#xff0c;于是思索了一番。 眾多周知&#xff0c;torch.nn.Linear作為全連接層&#xff0c;將下一層的每個結點與上一層的每一節點相連&#xff0c;用…