【PyTorch】圖像多分類部署

如果需要在獨立于訓練腳本的新腳本中部署模型,這種情況模型和權重在內存中不存在,因此需要構造一個模型類的對象,然后將存儲的權重加載到模型中。

加載模型參數,驗證模型的性能,并在測試數據集上部署模型

from torch import nn
from torchvision import models# 定義一個resnet18模型,不使用預訓練參數
model_resnet18 = models.resnet18(pretrained=False)
# 獲取模型的全連接層的輸入特征數
num_ftrs = model_resnet18.fc.in_features
# 定義分類的類別數
num_classes=10
# 將全連接層的輸出特征數改為分類的類別數
model_resnet18.fc = nn.Linear(num_ftrs, num_classes)import torch 
path2weights="./models/resnet18_pretrained.pt"
# 加載預訓練的ResNet18模型權重
model_resnet18.load_state_dict(torch.load(path2weights))
# 將ResNet-18模型設置為評估模式
model_resnet18.eval();
# 檢查CUDA是否可用
if torch.cuda.is_available():# 如果可用,將設備設置為CUDAdevice = torch.device("cuda")# 將模型移動到CUDA設備上model_resnet18=model_resnet18.to(device)def deploy_model(model,dataset,device, num_classes=10,sanity_check=False):# 獲取數據集的長度len_data=len(dataset)# 初始化輸出張量y_out=torch.zeros(len_data,num_classes)# 初始化真實標簽張量y_gt=np.zeros((len_data),dtype="uint8")# 將模型移動到指定設備model=model.to(device)# 初始化時間列表elapsed_times=[]with torch.no_grad():for i in range(len_data):# 獲取數據集中的一個樣本x,y=dataset[i]# 將真實標簽存入張量y_gt[i]=y# 記錄開始時間start=time.time()    # 將輸入數據傳入模型進行預測yy=model(x.unsqueeze(0).to(device))# 將預測結果存入張量y_out[i]=torch.softmax(yy,dim=1)# 計算預測時間elapsed=time.time()-start# 將預測時間存入列表elapsed_times.append(elapsed)# 如果進行完整性檢查,則跳出循環if sanity_check is True:break# 計算平均預測時間inference_time=np.mean(elapsed_times)*1000# 打印平均預測時間print("average inference time per image on %s: %.2f ms " %(device,inference_time))# 返回預測結果和真實標簽return y_out.numpy(),y_gt
from torchvision import datasets
import torchvision.transforms as transforms# 數據轉換
data_transformer = transforms.Compose([transforms.ToTensor()])path2data="./data"# 加載數據
test0_ds=datasets.STL10(path2data, split='test', download=True,transform=data_transformer)
print(test0_ds.data.shape)

from sklearn.model_selection import StratifiedShuffleSplit# 創建StratifiedShuffleSplit對象,設置分割次數為1,測試集大小為0.2,隨機種子為0
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)# 獲取test0_ds的索引
indices=list(range(len(test0_ds)))# 獲取test0_ds的標簽
y_test0=[y for _,y in test0_ds]# 對索引和標簽進行分割
for test_index, val_index in sss.split(indices, y_test0):# 打印測試集和驗證集的索引print("test:", test_index, "val:", val_index)# 打印測試集和驗證集的大小print(len(val_index),len(test_index))

from torch.utils.data import Subset# 從test0_ds中選取val_index索引的子集,賦值給val_ds
val_ds=Subset(test0_ds,val_index)
# 從test0_ds中選取test_index索引的子集,賦值給test_ds
test_ds=Subset(test0_ds,test_index)
# 定義均值
mean=[0.4467106, 0.43980986, 0.40664646]
# 定義標準差
std=[0.22414584,0.22148906,0.22389975]
# 定義一個名為test0_transformer的變量,用于將一系列的圖像變換操作組合在一起
test0_transformer = transforms.Compose([# 將圖像轉換為Tensor類型transforms.ToTensor(),# 對圖像進行歸一化操作,使用mean和std作為均值和標準差transforms.Normalize(mean, std),])   
# 將test0_transformer賦值給test0_ds的transform屬性
test0_ds.transform=test0_transformer
import time
import numpy as np# 調用deploy_model函數,傳入model_resnet18,val_ds,device和sanity_check參數,返回y_out和y_gt
y_out,y_gt=deploy_model(model_resnet18,val_ds,device=device,sanity_check=False)
# 打印y_out和y_gt的形狀
print(y_out.shape,y_gt.shape)

from sklearn.metrics import accuracy_score# 將y_out中的最大值索引賦值給y_pred
y_pred = np.argmax(y_out,axis=1)
# 打印y_pred和y_gt的形狀
print(y_pred.shape,y_gt.shape)# 計算并打印y_pred和y_gt的準確率
acc=accuracy_score(y_pred,y_gt)
print("accuracy: %.2f" %acc)

?

# 部署模型,得到預測結果和真實標簽
y_out,y_gt=deploy_model(model_resnet18,test_ds,device=device)# 取出預測結果中概率最大的類別
y_pred = np.argmax(y_out,axis=1)# 計算準確率
acc=accuracy_score(y_pred,y_gt)# 打印準確率
print(acc)

from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
np.random.seed(1)# 定義一個函數,用于顯示圖像
def imshow(inp, title=None):# 定義圖像的均值和標準差mean=[0.4467106, 0.43980986, 0.40664646]std=[0.22414584,0.22148906,0.22389975]# 將圖像從tensor轉換為numpy數組,并轉置inp = inp.numpy().transpose((1, 2, 0))# 將均值和標準差轉換為numpy數組mean = np.array(mean)std = np.array(std)# 將圖像的像素值進行歸一化inp = std * inp + mean# 將像素值限制在0和1之間inp = np.clip(inp, 0, 1)# 顯示圖像plt.imshow(inp)# 如果有標題,則顯示標題if title is not None:plt.title(title)# 暫停0.001秒plt.pause(0.001) # 定義網格大小
grid_size=16
# 隨機生成4個索引
rnd_inds=np.random.randint(1,len(test_ds),grid_size)
# 打印隨機生成的索引
print("image indices:",rnd_inds)# 根據索引獲取對應的圖像和標簽
x_grid_test=[test_ds[i][0] for i in rnd_inds]
y_grid_test=[(y_pred[i],y_gt[i]) for i in rnd_inds]# 將圖像轉換為網格
x_grid_test=utils.make_grid(x_grid_test, nrow=4, padding=2)
# 打印網格的形狀
print(x_grid_test.shape)# 設置圖像的大小
plt.rcParams['figure.figsize'] = (10, 10)
# 顯示網格
imshow(x_grid_test,y_grid_test)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/96765.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/96765.shtml
英文地址,請注明出處:http://en.pswp.cn/web/96765.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

FS950R08A6P2B 雙通道汽車級IGBT模塊Infineon英飛凌 電子元器件核心解析

一、核心解析:FS950R08A6P2B 是什么?1. 電子元器件類型FS950R08A6P2B 是英飛凌(Infineon) 推出的一款 950A/800V 雙通道汽車級IGBT模塊,屬于功率半導體模塊。它采用 EasyPACK 2B 封裝,集成多個IGBT芯片和二…

【系列文章】Linux中的并發與競爭[05]-互斥量

【系列文章】Linux中的并發與競爭[05]-互斥量 該文章為系列文章:Linux中的并發與競爭中的第5篇 該系列的導航頁連接: 【系列文章】Linux中的并發與競爭-導航頁 文章目錄【系列文章】Linux中的并發與競爭[05]-互斥量一、互斥鎖二、實驗程序的編寫2.1驅動…

TensorRT 10.13.3: Limitations

Limitations Shuffle-op can not be transformed to no-op for perf improvement in some cases. For the NCHW32 format, TensorRT takes the third-to-last dimension as the channel dimension. When a Shuffle-op is added like [N, ‘C’, H, 1] -> [‘N’, C, H], the…

Python與Go結合

Python與Go結合的方法Python和Go可以通過多種方式結合使用,通常采用跨語言通信或集成的方式。以下是幾種常見的方法:使用CFFI或CGO進行綁定Python可以通過CFFI(C Foreign Function Interface)調用Go編寫的庫,而Go可以通…

C++ 在 Visual Studio Release 模式下,調試運行與直接運行 EXE 的區別

前言 在 Visual Studio (以下簡稱 VS) 中開發 C 項目時,我們常常需要在 Debug 和 Release 兩種構建模式之間切換。Debug 模式適合開發和調試,而 Release 模式則針對生產環境,進行代碼優化以提升性能。然而,即使在 Release 模式下&…

南京方言數據集|300小時高質量自然對話音頻|專業錄音棚采集|方言語音識別模型訓練|情感計算研究|方言保護文化遺產數字化|語音情感識別|方言對話系統開發

引言與背景 隨著人工智能技術的快速發展,語音識別和自然語言處理領域對高質量方言數據的需求日益增長。南京方言作為江淮官話的重要分支,承載著豐富的地域文化和語言特色,在語言學研究和方言保護方面具有重要價值。本數據集精心采集了300小時…

基于LSTM深度學習的電動汽車電池荷電狀態(SOC)預測

基于LSTM深度學習的電動汽車電池荷電狀態(SOC)預測 摘要 電動汽車(EV)的普及對電池管理系統(BMS)提出了極高的要求。電池荷電狀態(State of Charge, SOC)作為BMS最核心的參數之一&am…

Golang語言之數組、切片與子切片

一、數組先記住數組的核心特點:盒子大小一旦定了就改不了(長度固定),但盒子里的東西能換(元素值可變)。就像你買了個能裝 3 個蘋果的鐵皮盒,想多裝 1 個都不行,但里面的蘋果可以換成…

速通ACM省銅第四天 賦源碼(G-C-D, Unlucky!)

目錄 引言: G-C-D, Unlucky! 題意分析 邏輯梳理 代碼實現 結語: 引言: 因為今天打了個ICPC網絡賽,導致坐牢了一下午,沒什么時間打題目了,就打了一道題,所以,今天我們就只講一題了&…

數據鏈路層總結

目錄 (一)以太網(IEEE 802.3) (1)以太網的幀格式 (2)幀協議類型字段 ①ARP協議 (橫跨網絡層和數據鏈路層的協議) ②RARP協議 (二&#xff…

Scala 新手實戰三案例:從循環到條件,搞定基礎編程場景

Scala 新手實戰三案例:從循環到條件,搞定基礎編程場景 對 Scala 新手來說,單純記語法容易 “學完就忘”,而通過小而精的實戰案例鞏固知識點,是掌握語言的關鍵。本文精選三個高頻基礎場景 ——9 乘 9 乘法口訣表、成績等…

java學習筆記----標識符與變量

1.什么是標識符?Java中變量、方法、類等要素命名時使用的字符序列,稱為標識符。 技巧:凡是自己可以起名字的地方都叫標識符。 比如:類名、方法名、變量名、包名、常量名等 2.標識符的命名規則由26個英文字母大小寫,0-9,或$組成 數字不可以開…

AI產品經理面試寶典第93天:Embedding技術選型與場景化應用指南

1. Embedding技術演進全景解析 1.1 稀疏向量:關鍵詞匹配的基石 1.1.1 問:請說明稀疏向量的適用場景及技術特點 答:稀疏向量適用于關鍵詞精確匹配場景,典型實現包括TF-IDF、BM25和SPLADE。其技術特征表現為50,000+高維向量且95%以上位置為零值,通過余弦或點積計算相似度…

【Mermaid.js】從入門到精通:完美處理節點中的空格、括號和特殊字符

文章標簽: Mermaid, Markdown, 前端開發, 數據可視化, 流程圖 文章摘要: 你是否在使用 Mermaid.js 繪制流程圖時,僅僅因為節點文本里加了一個空格或括號,整個圖就渲染失敗了?別擔心,這幾乎是每個 Mermaid 新…

多技術融合提升環境生態水文、土地土壤、農業大氣等領域的數據分析與項目科研水平

一:空間數據獲取與制圖1.1 軟件安裝與應用1.2 空間數據介紹1.3海量空間數據下載1.4 ArcGIS軟件快速入門1.5 Geodatabase地理數據庫二:ArcGIS專題地圖制作2.1專題地圖制作規范2.2 空間數據的準備與處理2.3 空間數據可視化:地圖符號與注記2.4 研…

【音視頻】Android NDK 與.so庫適配

一、名詞解析 名詞全稱核心說明Android NDKNative Development Kit在SDK基礎上增加“原生”開發能力,支持使用C/C編寫代碼,用于開發需要調用底層能力的模塊(如音視頻、加密算法等).so庫Shared Object即共享庫,由NDK編…

SpringBoot 輕量級一站式日志可視化與JVM監控

一、項目初衷Java 應用開發的同學都知道,項目上線后,日志的可視化查詢與 JVM 的可視化監控是一件非常重要的事。 市面上成熟方案一般是采用 ELK/EFK 實現日志可視化,采用 Actuator Prometheus Grafana 實現 JVM 監控。 這兩套都是非常優秀的…

【Leetcode hot 100】101.對稱二叉樹

問題鏈接 101.對稱二叉樹 問題描述 給你一個二叉樹的根節點 root , 檢查它是否軸對稱。 示例 1: 輸入:root [1,2,2,3,4,4,3] 輸出:true 示例 2: 輸入:root [1,2,2,null,3,null,3] 輸出:…

Zynq開發實踐(FPGA之選擇開發板)

【 聲明:版權所有,歡迎轉載,請勿用于商業用途。 聯系信箱:feixiaoxing 163.com】我們之所以選用zynq開發板,就在于它支持arm軟件開發,也支持fpga開發,甚至可以運行linux,這是之前沒有…

Flutter Riverpod 3.0 發布,大規模重構下的全新狀態管理框架

在之前的 《注解模式下的 Riverpod 有什么特別之處》我們聊過 Riverpod 2.x 的設計和使用原理,同時當時我們就聊到作者已經在開始探索 3.0 的重構方式,而現在隨著 Riverpod 3.0 的發布,riverpod 帶來了許多細節性的變化。 當然,這…