如何保存訓練的最優模型和使用最優模型文件

一 保存最優模型

主要就是我們在for循環中加上一個test測試,并且我還在test函數后面加上了返回值,可以返回準確率,然后每次進行一次對比,然后取大的。然后這里有兩種保存方式,一種是保存了整個模型,另一個是保存了模型參數。

1 僅保存模型參數

torch.save(model.state_dict(),'best_model.pth')

然后后面我們使用的時候

model =torch.load('best1.pth')#
model.to(device)
model.load_state_dict(torch.load("best.pth", map_location=device))
model.eval()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
test(test_dataloader,model,loss_fn)

注意這里要設置eval模式,因為我們要保證我們的模型參數不再變化了。

2 保存整個模型

torch.save(model,'best1.pth')

在調用的時候

model = torch.load('best1.pth', map_location=torch.device('cuda'))
model.eval()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
test(test_dataloader,model,loss_fn)

直接調用就好。

注意這兩種必須要有定義好的網絡,不然無法運行(保存整個網絡也要定于一個完全相同的網絡)。

完整代碼

epochs=20
for i in range(epochs):print(f"Epoch {i+1}")train(train_dataloader,model,loss_fn,optimizer)corrects = test(test_dataloader,model,loss_fn)accuracy_list.append(corrects)if corrects>best_acc:print(f"Best Accuracy: {corrects}%")best_acc=corrects#第一種# torch.save(model.state_dict(),'best_model.pth')#第二種torch.save(model,'best1.pth')

完整代碼含網絡

import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transformsclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(  # 2d一般用于圖像,3d用于視頻數據(多一個時間維度),1d一般用于結構化的序列數據in_channels=3,  # 圖像通道個數,1表示灰度圖(確定了卷積核 組中的個數),out_channels=16,  # 要得到多少個特征圖,卷積核的個數kernel_size=5,  # 卷積核人小,5*5stride=1,  # 步長padding=2  # 填充值),nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 進行池化操作(2x2 區域))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),)self.out = nn.Linear(64 * 64 * 64, 20)  # 全連接層得到的結果def forward(self, x):  # 前向傳播,你得告訴它 數據的流向 是神經網絡層連接起來,函數名稱不能改x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)  # view和reshape是一樣的作用,但此處是tensor形式output = self.out(x)return outputdata_transform={# 'train': transforms.Compose([#     # 調整圖像大小為300x300像素#     transforms.Resize([256, 256]),##     # # 隨機旋轉:-45到45度之間隨機選擇角度#     # transforms.RandomRotation(45),#     # ##     # # # 從中心裁剪出256x256的區域#     # transforms.CenterCrop([256, 256]),#     ##     # # 隨機水平翻轉:以50%的概率進行水平鏡像#     # transforms.RandomHorizontalFlip(p=0.5),#     ##     # # 隨機垂直翻轉:以50%的概率進行垂直鏡像#     # transforms.RandomVerticalFlip(p=0.5),#     ##     # # # 顏色抖動:隨機調整亮度、對比度、飽和度和色調#     # # transforms.ColorJitter(#     # #     brightness=0.2,    # 亮度變化幅度為20%#     # #     contrast=0.1,      # 對比度變化幅度為10%#     # #     saturation=0.1,    # 飽和度變化幅度為10%#     # #     hue=0.1            # 色調變化幅度為10%#     # # ),#     # ##     # # # 隨機灰度化:以10%的概率將圖像轉換為灰度圖#     # transforms.RandomGrayscale(p=0.1),##     # 將PIL圖像轉換為PyTorch張量,并自動歸一化到[0,1]范圍#     transforms.ToTensor(),##     # 標準化:使用ImageNet數據集的均值和標準差進行標準化#     transforms.Normalize(#         [0.485, 0.456, 0.406],  # 均值(R, G, B通道)#         [0.229, 0.224, 0.225]   # 標準差(R, G, B通道)#     )# ]),# 驗證/測試數據的預處理(通常不需要數據增強)'test': transforms.Compose([transforms.Resize([256, 256]),# transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}class food_dataset(Dataset):def __init__(self, root, transform=None):super().__init__()self.root = rootself.transform = transformself.images = []self.labels = []with open(root,encoding='utf-8') as f:samples = [i.strip().split() for i in f.readlines()]for img_path,label in samples:self.images.append(img_path)self.labels.append(label)def __len__(self):return len(self.images)def __getitem__(self, index):image=Image.open(self.images[index]).convert('RGB')if self.transform:image=self.transform(image)label = self.labels[index]# print(label)label = torch.from_numpy(np.array(label,dtype=np.int64))# print(label)return image, labeldef test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()batch_size_num=1loss,correct=0,0with torch.no_grad():for X, y in test_dataloader:X,y=X.to(device),y.to(device)pred = model(X)loss = loss_fn(pred,y)+losscorrect += (pred.argmax(1) == y).type(torch.float).sum().item()loss/=num_batchescorrect/=sizeprint(f'Test result: \n Accuracy: {(100*correct)}%,Avg loss: {loss}')device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")test_data=food_dataset('test_data',transform=(data_transform['test']))
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)# model =CNN()
# model.to(device)
# model.load_state_dict(torch.load("best.pth"))
model=torch.load('best.pt')
model.eval()
loss_fn = nn.CrossEntropyLoss()
test(test_dataloader,model,loss_fn)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/95704.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/95704.shtml
英文地址,請注明出處:http://en.pswp.cn/web/95704.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

vue3+ts+echarts多Y軸折線圖

因為放在了子組件才監聽&#xff0c;加載渲染調用&#xff0c;有暗黑模式才調用&#xff0c;<!-- 溫濕度傳感器 --><el-row v-if"deviceTypeId 2"><el-col :xs"24" :sm"24" :md"24" :lg"24" :xl"24&qu…

基于Taro4打造的一款最新版微信小程序、H5的多端開發簡單模板

基于Taro4、Vue3、TypeScript、Webpack5打造的一款最新版微信小程序、H5的多端開發簡單模板 特色 &#x1f6e0;? Taro4, Vue 3, Webpack5, pnpm10 &#x1f4aa; TypeScript 全新類型系統支持 &#x1f34d; 使用 Pinia 的狀態管理 &#x1f3a8; Tailwindcss4 - 目前最流…

ITU-R P.372 無線電噪聲預測庫調用方法

代碼功能概述&#xff08;ITURNoise.c&#xff09;該代碼是一個 ITU-R P.372 無線電噪聲預測 的計算程序&#xff0c;能夠基于 月份、時間、頻率、地理位置、人為噪聲水平 計算特定地點的 大氣噪聲、銀河噪聲、人為噪聲及其總和&#xff0c;并以 CSV 或標準輸出 方式提供結果。…

《從報錯到運行:STM32G4 工程在 Keil 中的頭文件配置與調試實戰》

《從報錯到運行&#xff1a;STM32G4 工程在 Keil 中的頭文件配置與調試實戰》文章提綱一、引言? 闡述 STM32G4 在嵌入式領域的應用價值&#xff0c;說明 Keil 是開發 STM32G4 工程的常用工具? 指出頭文件配置是 STM32G4 工程在 Keil 中開發的關鍵基礎環節&#xff0c;且…

Spring 事務提交成功后執行額外邏輯

1. 場景與要解決的問題在業務代碼里&#xff0c;常見訴求是&#xff1a;只有當數據庫事務真正提交成功后&#xff0c;才去執行某些“后置動作”&#xff0c;例如&#xff1a;發送 MQ、推送消息、寫審計/埋點日志、刷新緩存、通知外部系統等。如果這些動作在事務提交前就執行&am…

Clickhouse MCP@Mac+Cherry Studio部署與調試

一、需求背景 已經部署測試了Mysql、Drois的MCP Server,想進一步測試Clickhouse MCP的表現。 二、環境 1)操作系統 MacOS+Apple芯片 2)Clickhouse v25.7.6.21-stable、Clickhouse MCP 0.1.11 3)工具Cherry Studio 1.5.7、Docker Desktop 4.43.2(199162) 4)Python 3.1…

Java Serializable 接口:明明就一個空的接口嘛

對于 Java 的序列化,我之前一直停留在最淺層次的認知上——把那個要序列化的類實現 Serializbale 接口就可以了嘛。 我似乎不愿意做更深入的研究,因為會用就行了嘛。 但隨著時間的推移,見到 Serializbale 的次數越來越多,我便對它產生了濃厚的興趣。是時候花點時間研究研…

野火STM32Modbus主機讀取寄存器/線圈失敗(三)-嘗試將存貯事件的地方改成數組(非必要解決方案)(附源碼)

背景 盡管crc校驗正確了&#xff0c;也成功發送了EV_MASTER_EXECUTE事件&#xff0c;但是eMBMasterPoll( void )中總是接收的事件是EV_MASTER_FRAME_RECEIVED或者EV_MASTER_FRAME_SENT&#xff0c;一次都沒有執行EV_MASTER_EXECUTE。EV_MASTER_EXECUTE事件被別的事件給覆蓋了&…

微信小程序校園助手程序(源碼+文檔)

源碼題目&#xff1a;微信小程序校園助手程序&#xff08;源碼文檔&#xff09;?? 文末聯系獲取&#xff08;含源碼、技術文檔&#xff09;博主簡介&#xff1a;10年高級軟件工程師、JAVA技術指導員、Python講師、文章撰寫修改專家、Springboot高級&#xff0c;歡迎高校老師、…

59-python中的類和對象、構造方法

1. 認識一下對象 世間萬物皆是"對象" student_1{ "姓名":"小樸", "愛好":"唱、跳、主持" ......... }白紙填寫太落伍了 設計表格填寫先進一些些 終極目標是程序使用對象去組織數據程序中設計表格&#xff0c;我們稱為 設計類…

向成電子驚艷亮相2025物聯網展,攜工控主板等系列產品引領智造新風向

2025年8月27-29日&#xff0c;IOTE 2025 第二十四屆國際物聯網展深圳站在深圳國際會展中心&#xff08;寶安&#xff09;盛大啟幕&#xff01;作為全球規模領先的物聯網盛會之一&#xff0c;本屆展會以“生態智能&#xff0c;物聯全球”為核心&#xff0c;匯聚超1000家全球頭部…

陣列信號處理之均勻面陣波束合成方向圖的繪制與特點解讀

陣列信號處理之均勻面陣波束合成方向圖的繪制與特點解讀 文章目錄前言一、方向圖函數二、方向圖繪制三、副瓣電平四、陣元個數對主瓣寬度的影響五、陣元間距對主瓣寬度的影響六、MATLAB源代碼總結前言 \;\;\;\;\;均勻面陣&#xff08;Uniform Planar Array&#xff0c;UPA&…

算法在前端框架中的集成

引言 算法是前端開發中提升性能和用戶體驗的重要工具。隨著 Web 應用復雜性的增加&#xff0c;現代前端框架如 React、Vue 和 Angular 提供了強大的工具集&#xff0c;使得將算法與框架特性&#xff08;如狀態管理、虛擬 DOM 和組件化&#xff09;無縫集成成為可能。從排序算法…

網絡爬蟲是自動從互聯網上采集數據的程序

網絡爬蟲是自動從互聯網上采集數據的程序網絡爬蟲是自動從互聯網上采集數據的程序&#xff0c;Python憑借其豐富的庫生態系統和簡潔語法&#xff0c;成為了爬蟲開發的首選語言。本文將全面介紹如何使用Python構建高效、合規的網絡爬蟲。一、爬蟲基礎與工作原理 網絡爬蟲本質上是…

Qt Model/View/Delegate 架構詳解

Qt Model/View/Delegate 架構詳解 Qt的Model/View/Delegate架構是Qt框架中一個重要的設計模式&#xff0c;它實現了數據存儲、數據顯示和數據編輯的分離。這種架構不僅提高了代碼的可維護性和可重用性&#xff0c;還提供了極大的靈活性。 1. 架構概述 Model/View/Delegate架構將…

光譜相機在手機行業的應用

在手機行業&#xff0c;光譜相機技術通過提升拍照色彩表現和擴展健康監測等功能&#xff0c;正推動攝像頭產業鏈升級&#xff0c;并有望在AR/VR、生物醫療等領域實現更廣泛應用。以下為具體應用場景及技術突破的詳細說明&#xff1a;?一、光譜相機在手機行業的應用場景??拍照…

FASTMCP中的Resources和Templates

Resources 給 MCP 客戶端/LLM 讀取的數據端點&#xff08;只讀、按 URI 索引、像“虛擬文件系統”或“HTTP GET”&#xff09;&#xff1b; Templates 可帶參數的資源路由&#xff08;URI 里占位符 → 運行函數動態生成內容&#xff09;。 快速要點 ? 用途&#xff1a;把文件…

OpenBMC之編譯加速篇

加快 OpenBMC 的編譯速度是一個非常重要的話題,因為完整的構建通常非常耗時(在高性能機器上也需要數十分鐘,普通電腦上可能長達數小時)。以下是從不同層面優化編譯速度的詳細策略,您可以根據自身情況組合使用。 一、核心方法:利用 BitBake 的緩存和共享機制(效果最顯著…

Kafka面試精講 Day 8:日志清理與數據保留策略

【Kafka面試精講 Day 8】日志清理與數據保留策略 在Kafka的高吞吐、持久化消息系統中&#xff0c;日志清理與數據保留策略是決定系統資源利用效率、數據可用性與合規性的關鍵機制。作為“Kafka面試精講”系列的第8天&#xff0c;本文聚焦于日志清理機制&#xff08;Log Cleani…

基于Hadoop的網約車公司數據分析系統設計(代碼+數據庫+LW)

摘 要 本系統基于Hadoop平臺&#xff0c;旨在為網約車公司提供一個高效的數據分析解決方案。隨著網約車行業的快速發展&#xff0c;平臺上產生的數據量日益增加&#xff0c;傳統的數據處理方式已無法滿足需求。因此&#xff0c;設計了一種基于Hadoop的大規模數據處理和分析方…