如何使用C++調用Pytorch模型進行推理測試:使用libtorch庫

如何使用C++調用Pytorch模型進行推理測試:使用libtorch庫

目錄

      • 如何使用C++調用Pytorch模型進行推理測試:使用libtorch庫
        • 一、環境準備
          • 1,linux:以ubuntu 22.04系統為例
            • 1. 準備CUDA和CUDNN
            • 2. 準備C++環境
            • 3, 下載libtorch文件
            • 4, 編寫測試libtorch是否安裝成功
          • 2, windows: 以win10系統為例
            • 1, 準備CUDA和CUDNN
            • 2,準備C++編譯環境
            • 3,下載安裝libtorch
            • 4. 注意事項
          • 二、C++代碼封裝Pytorch模型測試:以resnet-18分類為例
          • 1, 安裝opencv用于讀取圖像
          • 2,用python導出訓練好的pytorch模型
          • 3,編寫C++代碼測試

一、環境準備
1,linux:以ubuntu 22.04系統為例
1. 準備CUDA和CUDNN

有兩種方式配置cuda和cudnn,一種是在系統環境安裝,可以參考:深度學習環境配置——ubuntu安裝CUDA與CUDNN

還有一種是在conda虛擬環境使用cudatoolkit-dev包,具體可以參考:Installing-and-Test-PyTorch-C-API-on-Ubuntu-with-GPU-enabled

我選擇的方式是在系統環境安裝cuda12.1和cudnn8.9.2。

可使用如下命令查看是否安裝成功:

NVCC -V
cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2

image-20240625103837610

2. 準備C++環境

安裝gcc, cmake和GLIBC,用apt install即可

可使用如下命令是否查看是否安裝成功:

gcc --version
cmake --version
ldd --version

image-20240625103749911

3, 下載libtorch文件

去pytoch官網https://pytorch.org/下載即可:

image-20240625103946244

可使用如下命令下載并解壓:

wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.3.1%2Bcu121.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.3.1+cu121.zip

將libtorch路徑配置到path變量:

vim ~/.bashrc

最后一行加入:

export LD_LIBRARY_PATH=/path/to/libtorch/lib:$LD_LIBRARY_PATH

注意將/path/to/libtorch替換為實際的path,我這里是/mnt/data1/zq/libtorch

查看是否成功:

source ~/.bashrc
echo $LD_LIBRARY_PATH

image-20240625110447696

4, 編寫測試libtorch是否安裝成功

創建main.cpp文件,內容如下:

#include <torch/torch.h>
#include <iostream>int main() {if (torch::cuda::is_available()) {std::cout << "CUDA is available! Running on GPU." << std::endl;// 創建一個隨機張量并將其移到GPU上torch::Tensor tensor_gpu = torch::rand({2, 3}).cuda();std::cout << "Tensor on GPU:\n" << tensor_gpu << std::endl;} else {std::cout << "CUDA not available! Running on CPU." << std::endl;// 創建一個隨機張量并保持在CPU上torch::Tensor tensor_cpu = torch::rand({2, 3});std::cout << "Tensor on CPU:\n" << tensor_cpu << std::endl;}return 0;
}

編譯和運行

創建CMakeLists.txt文件,內容如下:

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(test_project)# Setting the C++ standard to C++17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)# If additional compiler flags are needed
add_compile_options(-Wall -Wextra -pedantic)# Setting the location of LibTorch
set(Torch_DIR "/path/to/libtorch/share/cmake/Torch")
find_package(Torch REQUIRED)# Specify the name of the executable and the corresponding source file
add_executable(test_project main.cpp)# Linking LibTorch libraries
target_link_libraries(test_project "${TORCH_LIBRARIES}")# Set the output directory for the executable
set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/bin)

/path/to/libtorch替換為實際的path

編譯并測試:

mkdir build
cd build
cmake ..
make 

編譯完成之后,應該會出現一個bin目錄,其中有一個test_project文件,直接運行即可看到輸出。

image-20240625111448917

出現CUDAFloatType說明,libtorch的GPU版本安裝成功。

2, windows: 以win10系統為例
1, 準備CUDA和CUDNN

可參考:Windows10下CUDA與cuDNN的安裝

2,準備C++編譯環境

這一步需要配置cmake, mingw。可參考:Windows 配置 C/C++ 開發環境

建議直接安裝Visual Studio這個IDE,可參考:Windows libtorch C++部署GPU版

3,下載安裝libtorch

參考這個視頻:

win10系統上LibTorch的安裝和使用(cuda10.1版本)

一個很水的LibTorch教程(1)

4. 注意事項

windows環境我沒有做測試,不保證一定可以成功。linux環境是親自測試的,保證可以復現

二、C++代碼封裝Pytorch模型測試:以resnet-18分類為例
1, 安裝opencv用于讀取圖像

需要使用opencv來讀取圖像數據,可通過如下命令安裝:

sudo apt install libopencv-dev
dpkg -l | grep libopencv # 查看是否安裝成功
2,用python導出訓練好的pytorch模型

在將PyTorch模型應用于C++環境之前,需要將其轉換為TorchScript。這可以通過兩種方式實現:tracingscripting。可以通過如下代碼導出訓練好的ResNet-18模型:

import torch
import torchvision# 加載預訓練的模型
model = torchvision.models.resnet18(pretrained=True)# 將模型設置為評估模式
model.eval()# 創建一個示例輸入
example_input = torch.rand(1, 3, 224, 224)  # 模型輸入的大小# 使用tracing導出模型
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("resnet18.pt")
3,編寫C++代碼測試

創建main.cpp文件,內容如下:

#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <iostream>
#include <filesystem>// Function to transform image to tensor
torch::Tensor transform_image(const cv::Mat& image) {cv::Mat img_transformed;cv::cvtColor(image, img_transformed, cv::COLOR_BGR2RGB);cv::resize(img_transformed, img_transformed, cv::Size(224, 224));img_transformed.convertTo(img_transformed, CV_32FC3, 1.0/255);auto img_tensor = torch::from_blob(img_transformed.data, {img_transformed.rows, img_transformed.cols, 3}, torch::kFloat);img_tensor = img_tensor.permute({2, 0, 1});img_tensor = torch::data::transforms::Normalize<torch::Tensor>({0.485, 0.456, 0.406}, {0.229, 0.224, 0.225})(img_tensor);img_tensor = img_tensor.unsqueeze(0);return img_tensor;
}// Load the model and classify an image
void classify_image(const std::string& model_path, const std::string& image_path) {// Load the modeltorch::jit::script::Module model = torch::jit::load(model_path);model.eval(); // Switch to evaluation mode// Load and transform the imagecv::Mat image = cv::imread(image_path, cv::IMREAD_COLOR);if (image.empty()) {std::cerr << "Could not read the image: " << image_path << std::endl;return;}torch::Tensor tensor_image = transform_image(image);// Perform inferencetorch::Tensor output = model.forward({tensor_image}).toTensor();int64_t pred = output.argmax(1).item<int64_t>();std::cout << "The image is classified as class index: " << pred << std::endl;
}int main(int argc, char* argv[]) {std::string model_path = "resnet18.pt"; // Default model pathstd::string image_path = "default_image.jpg"; // Default image path// 從命令行接受兩個參數, 分別作為model_path和image_pathif (argc >= 3) {model_path = argv[1];image_path = argv[2];} else {std::cout << "Using default model and image paths." << std::endl;}classify_image(model_path, image_path);return 0;
}

創建CMakeLists.txt,內容如下:

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(ImageClassification)set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)# 設置LibTorch的位置, /path/to/libtorch替換為實際路徑
set(Torch_DIR "/path/to/libtorch/share/cmake/Torch")
find_package(Torch REQUIRED)find_package(OpenCV REQUIRED)add_executable(ImageClassification main.cpp)
target_link_libraries(ImageClassification "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")

編譯并運行:

mkdir build && cd build
cmake ..
make

在build目錄下會出現ImageClassification這個可執行文件,直接運行傳入model_path和image_path即可。

image-20240625114911739

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/40137.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/40137.shtml
英文地址,請注明出處:http://en.pswp.cn/web/40137.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

期權學習必看圣書:《3小時快學期權》要在哪里看?

今天帶你了解期權學習必看圣書&#xff1a;《3小時快學期權》要在哪里看&#xff1f;《3小時快學期權》是一本關于股票期權基礎知識的書籍。 它旨在通過簡明、易懂的語言和實用的案例&#xff0c;讓讀者在短時間內掌握股票期權的基本概念、操作方法和投資策略。通過這本書&…

img的onload事件不觸發

var img new Image(); img.src "圖片地址" img.addEventListener(load, function() {// 加載完后的處理 }, false);上面的代碼&#xff0c;可能存在設置addEventListener之前&#xff0c;圖片已經加載完了&#xff0c;onload事件不處罰。 調換一下img.src和img.add…

Linux系統(CentOS)安裝Mysql5.7.x

安裝準備&#xff1a; Linux系統(CentOS)添加防火墻、iptables的安裝和配置 請訪問地址&#xff1a;https://blog.csdn.net/esqabc/article/details/140209894 1&#xff0c;下載mysql安裝文件&#xff08;mysql-5.7.44為例&#xff09; 選擇Linux通用版本64位&#xff08;L…

算力互聯網網絡架構;SRV6;智享WAN

目錄 算力互聯網網絡架構 SRV6 主要特點 應用場景 結論 G-SRV6 多層次網絡切片 智享WAN 一、定義與背景 二、關鍵技術 三、應用場景與優勢 四、發展現狀與未來展望 智能算力網絡成為智能經濟時代代表性數字基礎設施 算力互聯網網絡架構 為構建算力互聯網這個前瞻性…

SQLAlchemy配置連接多個數據庫

1.定義配置項 首先定義兩個數據庫的配置信息 # PostgreSQL database configuration DB_USERNAMEpostgres DB_PASSWORDpassord DB_HOST127.0.0.1 DB_PORT5432 DB_DATABASEtest# mysql database configuration DB_USERNAME_MYSQLroot DB_PASSWORD_MYSQLpassword DB_HOST_MYSQL127…

后端之路——阿里云OSS云存儲

一、何為阿里云OSS 全名叫“阿里云對象存儲OSS”&#xff0c;就是云存儲&#xff0c;前端發文件到服務器&#xff0c;服務器不用再存到本地磁盤&#xff0c;可以直接傳給“阿里云OSS”&#xff0c;存在網上。 二、怎么用 大體邏輯&#xff1a; 細分的話就是&#xff1a; 1、準…

Rust: Fury高性能序列化庫嘗試

在序列化庫中&#xff0c;傳統的有Json,XML&#xff0c;性能好的有thrift&#xff0c;protobuf等。據說Fury官網的介紹&#xff0c;Fury性能要遠遠好于protobuf&#xff0c;且不象protobuf還需要定義IDL&#xff0c;非常輕便&#xff0c;隨取隨用。 今天來嘗試一下。 一、carg…

gitlab每日備份以及restore

gitlab服務有非常簡潔的每日備份命令&#xff0c; 從production的gitlab的每日備份中restore到backup環境也非常方便。 一、Production gitlab每日備份 1. Production gitlab環境上編寫腳本 cat /root/gitlab_bak.shgitlab-rake gitlab:backup:create > /var/opt/gitl…

JavaSE (Java基礎):面向對象(下)

8.7 多態 什么是多態&#xff1f; 即同一方法可以根據發送對象的不同而采用多種不同的方式。 一個對象的實際類型是確定的&#xff0c;但可以指向對象的引用的類型有很多。在句話我是這樣理解的&#xff1a; 在實例中使用方法都是根據他最開始將類實例化最左邊的類型來定的&…

消息中間件ApacheKafka在windows簡單安裝

一.背景 之前公司需要API網關管理軟件ApacheShenYu&#xff0c;我相信把調用的記錄都存到一個數據庫。他支持日志推送到kafka&#xff0c;所以&#xff0c;我準備嘗試一下通過kafka接收調用的日志信息。第一步&#xff0c;當然是安裝kafka了。 二.ApacheKafka的下載 打開下載…

【C++】 解決 C++ 語言報錯:Memory Leak

文章目錄 引言 內存泄漏&#xff08;Memory Leak&#xff09;是 C 編程中常見且嚴重的內存管理問題之一。當程序分配了內存而沒有正確釋放&#xff0c;導致內存無法被重新利用時&#xff0c;就會發生內存泄漏。這種錯誤會導致程序占用越來越多的內存&#xff0c;最終可能導致系…

關于人情世故的小討論

大家好&#xff0c;我是阿趙。 ??最近國內籃球界內出了不少事情&#xff0c;讓人對籃球這項運動產生了很多疑問。 ??去年的CUBA&#xff0c;擁有全國最好生源的清華大學居然輸給了連985 、211都不是的廣東工業大學。作為廣工的畢業生&#xff0c;我知道廣工的籃球一直都很強…

Unity PC和Android端的數據存儲和讀取

使用Resource&#xff1a; 提示&#xff1a;使用resouce打包后會被壓縮進.resources文件中&#xff0c;意味著它是只讀文件&#xff0c;且必須使用resouce.load加載&#xff1a; /// <summary>/// 全平臺使用/// </summary>/// <typeparam name"T"&g…

論文學習——動態多目標優化的一種新的分位數引導的對偶預測策略

論文題目&#xff1a;A novel quantile-guided dual prediction strategies for dynamic multi-objective optimization 動態多目標優化的一種新的分位數引導的對偶預測策略&#xff08;Hao Sun a,b, Anran Cao a,b, Ziyu Hu a,b, Xiaxia Li a,b, Zhiwei Zhao c&#xff09;In…

“免費”的可視化大屏案例分享-智慧園區綜合管理平臺

一.智慧園區是什么&#xff1f; 智慧園區是一種融合了新一代信息與通信技術的先進園區發展理念。它通過迅捷信息采集、高速信息傳輸、高度集中計算、智能事務處理和無所不在的服務提供能力&#xff0c;實現了園區內及時、互動、整合的信息感知、傳遞和處理。這樣的園區旨在提高…

自定義注解-手機號驗證注解

注解 package com.XX.assess.annotation;import com.XX.assess.util.MobileValidator;import javax.validation.Constraint; import javax.validation.Payload; import java.lang.annotation.*;/*** 手機號校驗注解* @author super*/ @Retention(RetentionPolicy.RUNTIME) @Ta…

正確使用Pytorch Geometric打開Cora(Planetoid)數據集

文章目錄 關于報錯&#xff08;"Cannot connect to host"&#xff09;解決方法 關于報錯&#xff08;“Cannot connect to host”&#xff09; 我們在使用PyG調用Planetoid數據集的時候&#xff0c;常會碰到如下報錯&#xff1a; 解決方法就是手動下載這個數據集。…

在 AWS Lambda 中使用 Flask 應用

本文將介紹如何在 AWS Lambda 中創建和部署一個使用 Flask 框架的應用。 1. 創建 Lambda 函數 首先,在 AWS Lambda 控制臺創建一個新的函數,命名為 ??flask-app??。 2. 準備 Flask 層 為了在 Lambda 中使用 Flask,我們需要創建一個包含 Flask 庫的層。按照以下步驟操…

java中如何使用ffmpeg命令來實現視頻編碼轉換

在Java中使用FFmpeg命令來進行視頻編碼轉換&#xff0c;可以通過調用系統命令來執行FFmpeg命令。下面是一個使用FFmpeg進行視頻轉碼的示例代碼&#xff1a; import java.io.BufferedReader; import java.io.InputStreamReader;public class FFmpegVideoConverter {public stat…

前端播放RTSP視頻流,使用FLV請求RTSP視頻流播放(Vue項目,在Vue中使用插件flv.js請求RTSP視頻流播放)

簡述&#xff1a;在瀏覽器中請求 RTSP 視頻流并進行播放時&#xff0c;直接使用原生的瀏覽器 API 是行不通的&#xff0c;因為它們不支持 RTSP 協議。為了解決這個問題&#xff0c;開發者通常會選擇使用像 flv.js 這樣的庫&#xff0c;它專為在瀏覽器中播放 FLV 和其他流媒體格…