機器學習模型在C++平臺的部署

一、概述

??機器學習模型的訓練通常在Python環境下完成,而現實生產環境的復雜性和多樣性使得模型的部署成為一個值得關注的重點。不同應用場景下有不同適應的實現方式,這里主要介紹通過一種通用中間格式——ONNX(Open Neural Network Exchange),來實現機器學習模型在C++平臺的部署。

二、步驟

??s1. Python環境中安裝onnxruntime、skl2onnx工具模塊;

??s2. Python環境中訓練機器學習模型;

??s3. 將訓練好的模型保存為.onnx格式的模型文件;

??s4. C++環境中安裝Microsoft.ML.OnnxRuntime程序包;
(Visual Studio 2022中可通過項目->管理NuGet程序包完成快捷安裝)

??S5. C++環境中加載模型文件,完成功能開發。

三、示例

??使用 Python 訓練一個線性回歸模型并將其導出為 ONNX 格式的文件,在C++環境下完成對模型的部署和推理。

1.Python訓練和導出

(環境:Python 3.11,scikit-learn 1.6.1,onnxruntime 1.22.0,skl2onnx 1.19.1)

import numpy as np
import onnxruntime as ort
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType# 生成示例數據
X, y = make_regression(n_samples=100, n_features=5, random_state=42)# 訓練線性回歸模型
model = LinearRegression()
model.fit(X, y)# 定義輸入格式
initial_type = [('input', FloatTensorType([None, 5]))]# 轉換模型為 ONNX 格式
onnx_model = convert_sklearn(model, initial_types=initial_type)# 保存 ONNX 模型
with open("linear_regression.onnx", "wb") as f:f.write(onnx_model.SerializeToString())print("\n模型已保存為: linear_regression.onnx\n")# 測試導出的模型
ort_session = ort.InferenceSession("linear_regression.onnx")
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name# 創建一個測試樣本
test_input = np.array([0.1, 0.2, 0.3, 0.4, 0.5]).reshape(1,5).astype(np.float32)# 運行推理
results = ort_session.run([output_name], {input_name: test_input})print(f"測試輸入: {test_input}")
print(f"預測結果: {results[0]}")

在這里插入圖片描述

2. C++ 部署和推理

(環境:C++ 14,Microsoft.ML.OnnxRuntime 1.22.0)

#include <iostream>
#include <vector>
#include <string>
#include <memory>
#include <onnxruntime_cxx_api.h>int main() {// 初始化環境Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXExample");// 初始化會話選項Ort::SessionOptions session_options;session_options.SetIntraOpNumThreads(1);session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);// 加載模型std::wstring model_path = L"linear_regression.onnx";Ort::Session session(env, model_path.c_str(), session_options);// 獲取輸入信息Ort::AllocatorWithDefaultOptions allocator;size_t num_inputs = session.GetInputCount();size_t num_outputs = session.GetOutputCount();// 假設只有一個輸入和一個輸出if (num_inputs != 1 || num_outputs != 1) {std::cerr << "模型必須有且僅有一個輸入和一個輸出" << std::endl;return 1;}// 獲取輸入名稱、類型和形狀std::string input_name = session.GetInputNameAllocated(0, allocator).get();Ort::TypeInfo input_type_info = session.GetInputTypeInfo(0);auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();ONNXTensorElementDataType input_type = input_tensor_info.GetElementType();std::vector<int64_t> input_dims = input_tensor_info.GetShape();// 獲取輸出名稱std::string output_name = session.GetOutputNameAllocated(0, allocator).get();// 創建輸入數據std::vector<float> input_data = { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f };size_t input_size = 5;// 創建輸入張量std::vector<int64_t> input_shape = { 1, static_cast<int64_t>(input_size) };auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_data.data(),input_data.size(), input_shape.data(), 2);// 驗證輸入張量是否為張量if (!input_tensor.IsTensor()) {std::cerr << "創建的輸入不是張量類型" << std::endl;return 1;}// 運行模型std::vector<const char*> input_names = { input_name.c_str() };std::vector<const char*> output_names = { output_name.c_str() };std::vector<Ort::Value> outputs = session.Run(Ort::RunOptions{ nullptr },input_names.data(),&input_tensor,1,output_names.data(),1);// 獲取輸出結果float* output_data = outputs[0].GetTensorMutableData<float>();Ort::TensorTypeAndShapeInfo output_info = outputs[0].GetTensorTypeAndShapeInfo();std::vector<int64_t> output_dims = output_info.GetShape();// 輸出結果std::cout << "輸入數據: ";for (float val : input_data) {std::cout << val << " ";}std::cout << std::endl;std::cout << "預測結果: ";for (size_t i = 0; i < output_info.GetElementCount(); ++i) {std::cout << output_data[i] << " ";}std::cout << std::endl;return 0;
}

在這里插入圖片描述



End.

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/90403.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/90403.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/90403.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

保姆級安裝 Ruby 環境下載及安裝教程, RubyInstaller下載及安裝教程

一、下載安裝 RubyInstaller 1.打開 RubyInstaller 官網&#xff1a;https://rubyinstaller.org/ 點擊跳轉, 官網界面如下圖&#xff1a; 點擊下載最新的 RubyDevkit 版本&#xff08;如 RubyDevkit 3.4.X (x64) &#xff09;。如下圖所示&#xff1a; 注意點&#xff1a;如果…

SQL 一鍵生成 Go Struct!支持字段注釋、類型映射、結構體命名規范

SQL 一鍵生成 Go Struct&#xff01;支持字段注釋、類型映射、結構體命名規范 在 Golang 開發中&#xff0c;尤其是操作數據庫時&#xff0c;我們經常會遇到這種場景&#xff1a; ? 拿到數據庫建表 SQL&#xff0c;卻要手動寫 Go struct? 字段幾十個、類型復雜&#xff0c;…

Web 前端框架選型:React、Vue 和 Angular 的對比與實踐

Web 前端框架選型&#xff1a;React、Vue 和 Angular 的對比與實踐 選擇前端框架就像選擇一個長期合作伙伴。錯誤的選擇可能會讓你的項目在未來幾年內背負沉重的技術債務&#xff0c;而正確的選擇則能讓開發效率飛速提升。 經過多年的項目實踐&#xff0c;我發現很多新人在框架…

C# 值拷貝、引用拷貝、淺拷貝、深拷貝

值拷貝定義&#xff1a;直接復制變量的值&#xff0c;適用于基本數據類型&#xff08;如int, float, char等&#xff09;。在 C# 中&#xff0c;值類型&#xff08;基本數據類型和結構體&#xff09;默認使用值拷貝。特點&#xff1a;創建原始值的完全獨立副本&#xff0c;修改…

深度學習圖像分類數據集—百種鳥類識別分類

該數據集為圖像分類數據集&#xff0c;適用于ResNet、VGG等卷積神經網絡&#xff0c;SENet、CBAM等注意力機制相關算法&#xff0c;Vision Transformer等Transformer相關算法。 數據集信息介紹&#xff1a;525種鳥類識別分類 訓練數據集總共有84635張圖片&#xff0c;每個文件夾…

零基礎 “入坑” Java--- 八、類和對象(一)

文章目錄一、初識面向對象二、類的定義和使用1.認識類2.類的定義格式三、類的實例化四、this引用五、對象的構造及初始化1.有關初始化2.構造方法3.就地初始化一、初識面向對象 Java是一門純面向對象的語言&#xff08;OOP&#xff09;&#xff0c;在面向對象的世界里&#xff…

數字孿生技術引領UI前端設計新篇章:智能物聯網的深度集成

hello寶子們...我們是艾斯視覺擅長ui設計、前端開發、數字孿生、大數據、三維建模、三維動畫10年經驗!希望我的分享能幫助到您!如需幫助可以評論關注私信我們一起探討!致敬感謝感恩!一、引言&#xff1a;數字孿生與物聯網的共生革命在智能設備爆發式增長的今天&#xff0c;傳統…

代碼審計-shiro漏洞分析

一、關于shiro介紹 簡單講&#xff0c;shiro是apache旗下的一個Java安全框架&#xff0c;輕量級簡單易上手&#xff0c;框架提供很多功能接口&#xff0c;常見的身份認證 、權限認證、會話管理、Remember 記住功能、加密等等。 二、漏洞分析 1.CVE-2019-12422-shiro550 漏洞原理…

EF提高性能(查詢禁用追蹤)(關閉延遲加載)

EF默認是支持延遲加載的&#xff0c;在加載一個表的數據時&#xff0c;會把關聯表的數據一并加載&#xff0c;這樣會影響性能。 一般建議關閉延遲加載可以提高EF加載的性能。還有其他方法提高性能&#xff08;查詢禁用追蹤&#xff09; 如果要實現延遲加載&#xff0c;必須滿足…

Leetcode+JAVA+貪心III

134.加油站在一條環路上有 n 個加油站&#xff0c;其中第 i 個加油站有汽油 gas[i] 升。你有一輛油箱容量無限的的汽車&#xff0c;從第 i 個加油站開往第 i1 個加油站需要消耗汽油 cost[i] 升。你從其中的一個加油站出發&#xff0c;開始時油箱為空。給定兩個整數數組 gas 和 …

Qt信號與槽機制及動態調用

Qt信號與槽機制及動態調用一、信號與槽1、Qt信號與槽機制概述2、信號與槽的基本使用3、信號與槽的特性4、使用Lambda表達式作為槽5、信號與槽的參數傳遞6、注意事項二、動態調用機制1、基本用法2、示例代碼3、帶參數的調用4、返回值處理5、信號與槽的動態連接6、動態方法調用7、…

K8s系列之:Kubernetes 的 OLM

K8s系列之:Kubernetes 的 OLM 什么是 Kubernetes 的 OLM什么是Kubernetes中的OperatorOLM 的功能OLM 的核心組件OLM優勢OLM 的工作原理OLM 與 OperatorHub 的關系OLM示例場景什么是CRDoperator 和 CRD的關系為什么需要 CRD 和 OperatorCRD定義資源類型DebeziumServer如何使用d…

前端-HTML-day2

目錄 1、無序列表 2、有序列表 3、定義列表 4、表格-基本使用 5、表格-結構標簽 6、表格-合并單元格 7、表單-input基本使用 8、表單-input占位文本 9、表單-單選框 10、表單-上傳多個文件 11、表單-多選框 12、表單-下拉菜單 13、表單-文本域 14、表單-label標簽…

兩種方式清除已經保存的git賬號密碼

方式一隨便選擇一個文件夾&#xff0c;然后鼠標右鍵-》TortoiseGit ->設置選擇已保存的數據-》認證數據-》清除-》點擊確定方式二 控制面板\用戶帳戶\憑據管理器-》windows憑據普通憑據-》找到git信息-》選擇刪除

Using Spring for Apache Pulsar:Message Production

1. Pulsar Template在Pulsar生產者端&#xff0c;Spring Boot自動配置提供了一個用于發布記錄的PulsarTemplate。該模板實現了一個名為PulsarOperations的接口&#xff0c;并提供了通過其合約發布記錄的方法。這些send API方法有兩類&#xff1a;send和sendAsync。send方法通過…

CSS揭秘:10.平行四邊形

前置知識&#xff1a;基本的css變形一、平行四邊形 要實現一個平行四邊形&#xff0c;可以使用CSS的skew變形屬性來傾斜元素。 transform: skewX(-45deg);圖-1顯示容器和內容都出現了傾斜&#xff0c;該如何解決這個問題&#xff1f; 二、嵌套方案 我們通過將內容嵌套 div 并使…

深度學習 必然用到的 線性代數知識

把標量到張量、點積到范數全串起來&#xff0c;幫你從 0 → 1 搭建 AI 數學底座 &#x1f680; 1 標量&#xff1a;深度學習的最小單元 標量 就是一維空間里的“點”&#xff0c;只有大小沒有方向。例如溫度 52 F、學習率 0.001。 記號&#xff1a;普通小寫 x&#xff1b;域&am…

OpenGL ES 紋理以及紋理的映射

文章目錄開啟紋理創建紋理綁定紋理生成紋理紋理坐標圖像配置線性插值重復效果限制拉伸完整代碼在 Android OpenGL ES 中使用紋理&#xff08;Texture&#xff09;可以顯著提升圖形渲染的質量和效率。以下是使用紋理的主要好處&#xff1a; 增強視覺真實感 紋理可以將復雜的圖像…

從金字塔到個性化路徑:AI 正在重新定義學習方式

幾十年來&#xff0c;我們的教育系統始終遵循著一條熟悉的路線&#xff1a; 從小學、初中、高中&#xff0c;再到大學和研究生。這條標準化的路徑&#xff08;K-12 到研究所&#xff09;結構清晰&#xff0c;卻也緩慢。但在當今這個信息爆炸、知識快速更新、個性化需求高漲的時…

產品經理崗位職責拆解

以下是產品經理崗位職責的詳細分解表&#xff0c;涵蓋工作內容、核心動作及輸出成果&#xff1a;崗位職責具體工作內容輸出成果1. 日常版本迭代管理需求分析及PRD產出協調資源推動產品上線- 收集業務/用戶需求&#xff0c;分析可行性及優先級- 撰寫PRD文檔&#xff0c;明確功能…