Pytorch訓練LeNet模型MNIST數據集

如何用torch框架訓練深度學習模型(詳解)

0. 需要的包

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

1. 數據加載和導入

以MNIST數據集為例

# 1.1 需要設置數據歸一化
train_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
# 1.2 用dataset.MNIST函數下載和加載訓練集與測試集 
train_dataset = datasets.MNIST(dataset_path, train=True, download=False, transform=train_transform)
test_dataset = datasets.MNIST(dataset_path, train=False, download=False, transform=test_transform)
# 1.3 加載進dataload用于后續數據按batch取用
batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

補充:這里的transform根據不同的數據集選擇不同的值
datasets加載數據集時path的路徑為:'.\data\' 該目錄下包括\MNIST文件夾

2. 加載模型和設置超參數

# 2.1 這里需要提前定義model的class,包括層結構和forward函數
model = LeNet_Mnist().to(device)
# 2.2 設置優化器、損失函數、訓練輪次
learning_rate = 1e-2
# 傳入模型參數,用于優化更新
sgd = SGD(model.parameters(), lr=learning_rate)  
loss_fn = CrossEntropyLoss()
all_epoch = 20

3. 訓練

# 3.1 首先設置訓練模式
model.train()
# 3.2 按照batch從train_loader中批量選擇數據
for idx, (train_x, train_label) in enumerate(train_loader):train_x = train_x.to(device)train_label = train_label.to(device)sgd.zero_grad()predict_y = model(train_x.float())loss = loss_fn(predict_y, train_label.long())loss.backward()sgd.step()

補充:可以在外面再套一層迭代次數

for current_epoch in range(all_epoch):  # local training

4. 測試

# 4.1 記錄測試結果
all_correct_num = 0
all_sample_num = 0
# 4.2 進入模型驗證模式,該模式下不會修改梯度
model.eval()
# 4.3 按批次測試
for idx, (test_x, test_label) in enumerate(test_loader):test_x = test_x.to(device)test_label = test_label.to(device)predict_y = model(test_x.float()).detach()predict_y = torch.argmax(predict_y, dim=-1)current_correct_num = predict_y == test_labelall_correct_num += np.sum(current_correct_num.to('cpu').numpy(), axis=-1)all_sample_num += current_correct_num.shape[0]
# 4.4 記錄結果并輸出
acc = all_correct_num / all_sample_num
print('accuracy: {:.3f}'.format(acc), flush=True)

5. 保存結果

# 5.1 保存參數
print("Save the model state dict")
torch.save(model.state_dict(), "./lenet_mnist.pt")
# 5.2 或者也可以選擇保存checkpoint,每輪都保存一次,萬一中斷能繼續
checkpoint = {"model": model.state_dict(),"optim": sgd.state_dict(),}
print("Save the checkpoint")
torch.save(checkpoint, "./checkpoint{}.pt".format(current_epoch))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/18774.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/18774.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/18774.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Python圖形界面(GUI)Tkinter筆記(九):用【Button()】功能按鈕實現人機交互

在Tkinter庫中,功能按鈕(Button)是實現人機交互的一個非常重要的組件: 【一】主要可實現功能及意義: (1)響應用戶交互: Button組件允許用戶通過點擊來觸發某個事件或動作。當用戶點擊按鈕時,可以執行一個指定的函數或方法。 (2)提供用戶輸入: Button組件是圖形用戶界面(G…

持續總結中!2024年面試必問 20 道 Rocket MQ面試題(三)

上一篇地址:持續總結中!2024年面試必問 20 道 Rocket MQ面試題(二)-CSDN博客 五、什么是生產者(Producer)和消費者(Consumer)在RocketMQ中? RocketMQ是一個高性能、高吞…

Linux完整版命令大全(二十五)

pine 功能說明&#xff1a;收發電子郵件&#xff0c;瀏覽新聞組。語  法&#xff1a;pine [-ahikorz][-attach<附件>][-attach_and_delete<附件>][-attachlist<附件清單>][-c<郵件編號>][-conf][-create_lu<地址薄><排序法>][-f<收件…

劇本殺小程序開發,探索市場發展新的商業機遇

劇本殺游戲作為一個新興行業&#xff0c;經歷了爆發式的增長&#xff0c;劇本殺游戲在市場中的熱度不斷升高。 不過&#xff0c;在市場的火熱下&#xff0c;競爭也在逐漸加大。因此&#xff0c;在市場競爭下&#xff0c;成本低、主題多樣、有趣的線上劇本殺小程序成為了創業者…

竹云董事長在第二屆ICT技術發展與企業數字化轉型高峰論壇作主題演講

5月25日&#xff0c;由中國服務貿易協會指導&#xff0c;中國服務貿易協會信息技術服務委員會主辦的 “第二屆ICT技術發展與企業數字化轉型高峰論壇” 在北京隆重召開。 本次論壇以 “數據驅動&#xff0c;AI引領&#xff0c;打造新質生產力” 為主題&#xff0c;特邀業內200余…

WebGL實現醫學教學軟件

使用WebGL實現醫學教學軟件是一個復雜但非常有益的項目&#xff0c;可以顯著提升醫學教育的互動性和效果。以下是詳細的實現步驟&#xff0c;包括需求分析、技術選型、開發流程和注意事項。北京木奇移動技術有限公司&#xff0c;專業的軟件外包開發公司&#xff0c;歡迎交流合作…

redis-cli help使用

1. redis-cli命令使用—先連接上服務器 連接到 Redis 服務器&#xff1a; 使用 redis-cli 命令即可連接到本地運行的 Redis 服務器&#xff0c;默認連接到本地的 6379 端口。 redis-cli如果 Redis 服務器不在本地或者端口不同&#xff0c;可以使用 -h 和 -p 參數指定主機和端…

華為校招機試 - LRU模擬(20240515)

題目描述 LRU(Least Recently Used)緩存算法是一種常用于管理緩存的策略,其目標是保留最近使用過的數據,而淘汰最久未被使用的數據。 實現簡單的LRU緩存算法,支持查詢、插入、刪除操作。 最久未被使用定義:查詢、插入和刪除操作均為一次訪問操作,每個元素均有一個最后…

探索Django 5: 從零開始,打造你的第一個Web應用

今天我們將一起探索 Django 5&#xff0c;一個備受開發者喜愛的 Python Web 框架。我們會了解 Django 5 的簡介&#xff0c;新特性&#xff0c;如何安裝 Django&#xff0c;以及用 Django 編寫一個簡單的 “Hello, World” 網站。最后&#xff0c;我會推薦一本與 Django 5 相關…

蘇洵,大器晚成的家風塑造者

&#x1f4a1; 如果想閱讀最新的文章&#xff0c;或者有技術問題需要交流和溝通&#xff0c;可搜索并關注微信公眾號“希望睿智”。 蘇洵&#xff0c;字明允&#xff0c;號老泉&#xff0c;生于宋真宗大中祥符二年&#xff08;公元1009年&#xff09;&#xff0c;卒于宋英宗治平…

量產導入 | 產品可靠性測試標準完整大集合(JEDEC/IEC/SAE…)

產品可靠性測試標準完整大集合(JEDEC/IEC/SAE…) 產品可靠性測試是產品質量保證中的重要一環, 包含有Pre-con, aging(壽命)和ESD(靜電)等, 下面就收集了權威標準JEDEC全系列, 請參照如下 同時也附上其它的可靠性標準供大家參考及交叉理解, 可能側重點不同, 大家可以參…

go語言同一包中的同一變量實現不同平臺設置不同的默認值 //go:build 編譯語法使用示例

在使用go來開發跨平臺應用的時候&#xff0c;比如配置文件的路徑&#xff0c;我們希望設置一個默認值&#xff0c;windows下的路徑是類似 d:\myapp\app.conf 這樣的&#xff0c; unix系統中的路徑是 /opt/myapp/app.conf 這樣的&#xff0c; 而我們在使用的時候需要使用的是同…

PPT忘記保存?教你如何輕松恢復

在日常辦公中PPT文件作為主流文檔格式&#xff0c;承載著我們大量的工作成果。然而當不小心誤點了“不保存”按鈕&#xff0c;或是遭遇軟件崩潰等意外情況導致文檔丟失時&#xff0c;文件內容是否還能夠能恢復&#xff0c;往往成為我們最關心的問題。本文將為您提供五大免費且實…

NetCore PetaPoco 事務處理分享

PetaPoco是一個輕量級的.NET和Mono數據庫訪問庫&#xff0c;它以單個C#文件的形式存在&#xff0c;便于集成到任何項目中。PetaPoco的主要特點包括無依賴性、快速的性能和對簡單事務的支持。它適用于嚴格的沒有裝飾的Poco類以及幾乎全部加了特性的Poco類&#xff0c;并提供了多…

現在版本的ultralytics沒有setup.py以后,本地代碼中修改了ultralytics源碼,怎么安裝到python環境中。

問題&#xff0c;在使用ultralytics訓練yolov8-obb模型時&#xff0c;修改了ultralytics源碼的網絡結構&#xff0c;發現調用的還是pip install安裝的ultralytics庫&#xff0c;新版本源碼中還沒有setup.py&#xff0c;該怎么把源碼中的ultralytics安裝到環境中。 解決方法&am…

《探索網絡七層模型:構建高效通信架構的關鍵》

在當今數字化時代&#xff0c;網絡通信已經成為人們生活和工作中不可或缺的一部分。而網絡七層模型作為計算機網絡體系結構的重要基礎&#xff0c;其技術架構對于構建高效、穩定的通信系統具有重要意義。本文將深入探討網絡七層模型的技術架構設計&#xff0c;以及其在構建現代…

輕松掌握圖片批量處理,趕緊學習這些小技巧!

在現今數字化的社會中&#xff0c;我們每天都會接觸到大量的圖片&#xff0c;無論是在工作中還是日常生活中。要想高效處理這些圖片&#xff0c;掌握圖片批量處理的技巧就顯得尤為重要。幸運的是&#xff0c;有許多小技巧和工具可以讓這一過程變得輕松愉快。 在本文中&#xf…

長安鏈使用Golang編寫智能合約教程(三)

本篇主要介紹長安鏈Go SDK寫智能合約的一些常見方法的使用方法或介紹 資料來源&#xff1a; 官方文檔官方示例合約庫官方SDK接口文檔 一、獲取參數、獲取狀態、獲取歷史記錄的方法解析 注意&#xff01; 這些查詢鏈上數據的方法&#xff1a;只能是查詢本合約之前上鏈的數據&a…

信息學一周賽事安排

本周比賽提醒 本周有以下幾場比賽即將開始&#xff1a; 1.ABC-356 比賽時間&#xff1a;6月1日&#xff08;周六&#xff09;晚20:00 比賽鏈接&#xff1a;https://atcoder.jp/contests/abc356 2.ARC-179 比賽時間&#xff1a;6月2日&#xff08;周日&#xff09;晚20:00 …

【Go】十、路由配置以及ZAP 高性能日志庫的使用

Project 目錄創建 mxshop-api user-web api ---- 服務接口 config ---- 配置信息 forms ---- 表單驗證信息 global ---- 全局信息 initialize ---- 初始化信息 middlewares ---- 中間件信息 proto ---- 數據信息 router ---- 路由信息 utils ---- 公用工具信息 validator ----…