pytorch學習筆記-模型訓練、利用GPU加速訓練(兩種方法)、使用模型完成任務

應該算是完結啦~再次感謝土堆老師!

模型訓練

模型訓練基本可以分為以下幾個步驟按序執行:
引入數據集-使用dataloader加載數據集-建立模型-設置損失函數-設置優化器-進行訓練-訓練中計算損失,并使用優化器更新參數-模型測試-模型存儲

習慣上會將model和train代碼分開寫,當然一開始混合寫也沒啥問題,直接給出一個例程:

# train.py
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn as nn
from model import *
import timefrom torch.utils.tensorboard import SummaryWriterdata_transforms = transforms.Compose([transforms.ToTensor()
])#引入數據集
train_data = datasets.CIFAR10("./dataset",train=True,transform=data_transforms,download=True)test_data = datasets.CIFAR10("./dataset",train=False,transform=data_transforms,download=True)#加載數據
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)start_time = time.time()#建立模型
my_module = MyModule()#設置損失函數
cross_loss = nn.CrossEntropyLoss()#設置優化器
#設置學習率
learning_rate = 1e-2
optimizer = torch.optim.SGD(my_module.parameters(),lr=learning_rate)#進行訓練
#設置迭代次數
epoch = 10total_train_steps = 0writer = SummaryWriter("train_logs")for i in range(epoch):print("第{}輪訓練".format(i+1))#訓練my_module.train() #只對某些層起作用for data in train_dataloader:imgs, targets = dataoutputs = my_module(imgs)#計算損失loss = cross_loss(outputs, targets)#優化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_steps +=1if total_train_steps % 100 ==0:print("訓練次數:{},Loss:{}".format(total_train_steps,loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_steps)#測試,不再梯度下降my_module.eval() #同樣只對某些層起作用 total_test_loss = 0# total_test_steps = 0total_accuracy = 0test_data_size = len(test_data)with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = my_module(imgs)loss = cross_loss(outputs,targets)total_test_loss += loss.item()##對于分類任務可以求一下準確的個數,非必須#argmax(1)按行取最大的下標 argmax(0)按列取最大的下標accuracy = (outputs.argmax(1)==targets).sum()total_accuracy += accuracyprint("第{}輪的測試集Loss:{}".format(i+1,total_test_loss))print("測試集準確率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,i)end_time = time.time()print("time:{}".format(end_time-start_time))#存儲模型if i % 5 == 0:torch.save(my_module,"my_module_{}.pth".format(i))print("模型存儲成功")writer.close()

使用GPU加速訓練(方式一)

上述寫法默認是采用cpu進行訓練的,會比較慢,為了加速訓練過程,我們需要用GPU進行加速訓練。對應有兩種方式,推薦方式二(大部分例程也都是采用方式二的)
需要用到GPU加速的主要有如圖中的三個部分:
在這里插入圖片描述
對應有.cuda()的參數,只要在原始位置后面加上.cuda()就可以,考慮到有些設備沒有GPU,建議加上.cuda.is_avaliable的判斷。同樣給出完整例程:

# train_gpu1.py
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn as nn
from model import *
import timefrom torch.utils.tensorboard import SummaryWriterdata_transforms = transforms.Compose([transforms.ToTensor()
])#引入數據集
train_data = datasets.CIFAR10("./dataset",train=True,transform=data_transforms,download=True)test_data = datasets.CIFAR10("./dataset",train=False,transform=data_transforms,download=True)#加載數據
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)start_time = time.time()#建立模型
my_module = MyModule()
if torch.cuda.is_available():my_module.cuda()#設置損失函數
cross_loss = nn.CrossEntropyLoss()
if torch.cuda.is_available():cross_loss.cuda()#設置優化器
#設置學習率
learning_rate = 1e-2
optimizer = torch.optim.SGD(my_module.parameters(),lr=learning_rate)#進行訓練
#設置迭代次數
epoch = 10total_train_steps = 0writer = SummaryWriter("train_logs")for i in range(epoch):print("第{}輪訓練".format(i+1))#訓練my_module.train() #只對某些層起作用for data in train_dataloader:imgs, targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = my_module(imgs)#計算損失loss = cross_loss(outputs, targets)#優化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_steps +=1if total_train_steps % 100 ==0:print("訓練次數:{},Loss:{}".format(total_train_steps,loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_steps)#測試,不再梯度下降my_module.eval() #同樣只對某些層起作用 total_test_loss = 0# total_test_steps = 0total_accuracy = 0test_data_size = len(test_data)with torch.no_grad():for data in test_dataloader:imgs, targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = my_module(imgs)loss = cross_loss(outputs,targets)total_test_loss += loss.item()##對于分類任務可以求一下準確的個數,非必須#argmax(1)按行取最大的下標 argmax(0)按列取最大的下標accuracy = (outputs.argmax(1)==targets).sum()total_accuracy += accuracyprint("第{}輪的測試集Loss:{}".format(i+1,total_test_loss))print("測試集準確率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,i)end_time = time.time()print("time:{}".format(end_time-start_time))#存儲模型if i % 5 == 0:torch.save(my_module.state_dict(),"my_module_{}.pth".format(i))print("模型存儲成功")writer.close()

使用GPU加速訓練(方式二)

現在更常見的寫法是用device+.to(device)的搭配,需要引入的位置和方式一提到的沒有任何差異,主要就是使用上的語法會有一點點不一樣,所以直接給出一個例程,大家看完就知道怎么用了:

# train_gpu2.py
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn as nn
from model import *from torch.utils.tensorboard import SummaryWriterdata_transforms = transforms.Compose([transforms.ToTensor()
])#引入數據集
train_data = datasets.CIFAR10("./dataset",train=True,transform=data_transforms,download=True)test_data = datasets.CIFAR10("./dataset",train=False,transform=data_transforms,download=True)#加載數據
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)#確定設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)#建立模型
my_module = MyModule()
my_module.to(device)#設置損失函數
cross_loss = nn.CrossEntropyLoss()
cross_loss.to(device)#設置優化器
#設置學習率
learning_rate = 1e-2
optimizer = torch.optim.SGD(my_module.parameters(),lr=learning_rate)#進行訓練
#設置迭代次數
epoch = 10total_train_steps = 0writer = SummaryWriter("train_logs")for i in range(epoch):print("第{}輪訓練".format(i+1))#訓練my_module.train() #只對某些層起作用for data in train_dataloader:imgs, targets = dataimgs, targets = imgs.to(device), targets.to(device)outputs = my_module(imgs)#計算損失loss = cross_loss(outputs, targets)#優化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_steps +=1if total_train_steps % 100 ==0:print("訓練次數:{},Loss:{}".format(total_train_steps,loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_steps)#測試,不再梯度下降my_module.eval() #同樣只對某些層起作用 total_test_loss = 0# total_test_steps = 0total_accuracy = 0test_data_size = len(test_data)with torch.no_grad():for data in test_dataloader:imgs, targets = dataimgs, targets = imgs.to(device), targets.to(device)outputs = my_module(imgs)loss = cross_loss(outputs,targets)total_test_loss += loss.item()##對于分類任務可以求一下準確的個數,非必須#argmax(1)按行取最大的下標 argmax(0)按列取最大的下標accuracy = (outputs.argmax(1)==targets).sum()total_accuracy += accuracyprint("第{}輪的測試集Loss:{}".format(i+1,total_test_loss))print("測試集準確率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,i)#存儲模型if i % 5 == 0:torch.save(my_module.state_dict(),"my_module_{}.pth".format(i))print("模型存儲成功")writer.close()

使用模型完成任務

這個標題我想了很久應該叫什么才能和內容對應上…土堆老師原來名字叫模型驗證我沒看之前一直以為是evaluate的過程啊啊啊結果是test的過程

實際上我們訓練完模型得到一堆準確率啊或者什么的時候并不代表我們完成了整個事情,說人話就是沒啥用,所以這部分其實就是教我們怎么用得到的模型在其他數據上使用,這一部分還蠻簡單的,和訓練中的evaluate部分使用差不多,注意點就是別忘了給圖片reshape成含有batch_size的形狀(特別是單張的情況下),當然,如果有報錯也可以先考慮是不是形狀不太對的原因…

# test.py
from PIL import Image
import torchvision
from torchvision import transforms
from model import *image_path = "./test_imgs/cat.jpg"
image = Image.open(image_path)
# image = image.convert('RGB') # 對于png圖片要加上這一句data_transforms = torchvision.transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()])image = data_transforms(image)
print(image.shape)model = MyModule()
model.load_state_dict(torch.load("my_module_5.pth"))image = torch.reshape(image,(1,3,32,32))
model.eval()
with torch.no_grad():output = model(image)print(output)
print(output.argmax(1))# torch.Size([3, 32, 32])
# tensor([[-1.5852, -1.3985,  1.0891,  2.5762,  0.1534,  2.0844,  0.6164,  1.7049,#  -4.7464, -1.4447]])
# tensor([3])

其實我的模型準確率沒有很高,但是對于這張圖片竟然驚人的分對了(第3類-cat)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/919615.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/919615.shtml
英文地址,請注明出處:http://en.pswp.cn/news/919615.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

深度卷積神經網絡AlexNet

在提出LeNet后卷積神經網絡在計算機視覺和機器學習領域中報有名氣,但是卷積神經網絡并沒有主導這些領域,因為LeNet在小數據集上取得了很好的效果,在更大,更真實的數據集上訓練卷積神經網絡的性能 和可行性有待研究,20世…

數據結構-HashSet

在 Java 編程的世界里,集合框架是極為重要的一部分,而 HashSet 作為 Set 接口的典型實現類,在處理不允許重復元素的場景中頻繁亮相。今天,我們就一同深入探究 HashSet,梳理它的特點、常用方法,以及和其他相…

心意行藥號 · 慈心方的八種用法

心意行藥號 慈心方的八種用法慈心方是心意行藥號589個珍貴秘方中的一個養生茶方,配伍比例科學嚴謹,君臣佐使堪稱經典,自古就有“小小慈心方,轉動大乾坤”之說。自清代光緒年間傳承至今,慈心方受益者逾百萬計&#xff…

Spring面試寶典:Spring IOC的執行流程解析

在準備Spring框架的面試時,“Spring IOC的工作流程是什么?” 是一個非常經典的問題。雖然網上有很多詳細的教程,但它們往往過于復雜,對于沒有深入研究過源碼的人來說理解起來確實有些困難。今天我們就來簡化這個概念,從…

學習日志39 python

1 fromkeys()函數是什么在 Python 中,fromkeys() 是字典(dict)的一個類方法,用于創建一個新字典。它的作用是:根據指定的可迭代對象(如列表、元組等)中的元素作為鍵(key)…

SpringBoot + MyBatis-Plus 使用 listObjs 報 ClassCastException 的原因與解決辦法

在項目中我們經常會遇到這種需求: 根據一組 ID 查詢數據庫,并返回指定字段列表。 我在寫代碼的時候,遇到了一個典型的坑,分享出來給大家。一、問題背景我的代碼是這樣寫的(查詢項目表的負責人信息)&#xf…

WT2606B 驅屏語音芯片新增藍牙功能:功能集成一體化,產品升級自動化,語音交互無線化,場景應用普適化!

小伙伴們,歡迎來到我們的 #唯創芯片小講堂!今天我們要為大家介紹一位多才多藝的"芯片全能手"——WT2606B驅屏語音芯片。這顆芯片將在今年8月的I0TE物聯網展及ELEXCON 2025深圳國際電子展上大放異彩。在智能設備滿天飛的今天&#x…

ORA-16331: container is not open ORA-06512: at “SYS.DBMS_LOGMNR“

使用Flink CDC、Debezium等CDC工具對Oracle進行基于log的實時數據同步時遇到異常ORA-16331: container is not open的解決方案。 1. 異常信息 異常信息通常如下: at oracle.jdbc.driver.OracleStatement.executeInternal(OracleStatement.java:1823) at oracle.jdbc…

「三維共振」:重構實體零售的破局模式

在電商沖擊與消費升級的雙重浪潮下,傳統零售模式正面臨前所未有的挑戰。wo店首創的 “三維共振” 運營模式,以場景體驗為根基、數據驅動為引擎、社群共生為紐帶,構建起線上線下深度融合的新型零售生態,至今已實現連續 18 個月客流…

將集合拆分成若干個batch,并將batch存于新的集合

在使用saveAll()等方法時,為了防止集合元素過大,使用splitList將原集合,分割成若干個小集合 import java.util.ArrayList; import java.util.List;public class ListUtils {/*** 將集合拆分成若干個batch,并將batch存于新的集合** param list…

Java主流框架全解析:從企業級開發到云原生

Java作為一門歷史悠久且應用廣泛的編程語言,其強大的生態系統離不開各種優秀的框架支持。無論是傳統的企業級應用開發,還是現代的微服務、云原生架構,Java都提供了豐富的框架選擇。本文將全面解析當前主流的Java框架,涵蓋Web開發、…

機器學習——網格搜索(GridSearchCV)超參數優化

網格搜索(Grid Search)詳細教學1. 什么是網格搜索?在機器學習模型中,算法的**超參數(Hyperparameters)**對模型的表現起著決定性作用。比如:KNN 的鄰居數量 n_neighborsSVM 的懲罰系數 C 和核函…

【LeetCode】18. 四數之和

文章目錄18. 四數之和題目描述示例 1:示例 2:提示:解題思路算法一:排序 雙指針(推薦)算法二:通用 kSum(含 2Sum 雙指針)復雜度關鍵細節代碼實現要點完整題解代碼18. 四數…

Go語言入門(10)-數組

訪問數組元素:數組中的每個元素都可以通過“[]”和一個從0開始的索引進行訪問數組的長度可由內置函數len來確定。在聲明數組時,未被賦值元素的值是對應類型的零值。下面看一個例子package mainfunc main(){var planets [8]stringplanets[0] "Mercu…

為什么經過IPSec隧道后HTTPS會訪問不通?一次隧道環境下的實戰分析

在運維圈子里,大家可能都遇到過這種奇怪的問題:瀏覽器能打開 HTTP 網站,但一換成 HTTPS,頁面就死活打不開。前段時間,我們就碰到這么一個典型案例。故障現象某公司系統在 VPN 隧道里訪問 HTTPS 服務,結果就…

【Linux系統】進程信號:信號的產生和保存

上篇文章我們介紹了Syetem V IPC的消息隊列和信號量,那么信號量和我們下面要介紹的信號有什么關系嗎?其實沒有關系,就相當于我們日常生活中常說的老婆和老婆餅,二者并沒有關系1. 認識信號1.1 生活角度的信號解釋(快遞比…

WEB服務器(靜態/動態網站搭建)

簡介 名詞:HTML(超文本標記語言),網站(多個網頁組成一臺網站),主頁,網頁,URL(統一資源定位符) 網站架構:LAMP(linux(系統)+apache(服務器程序)+mysql(數據庫管理軟件)+php(中間軟件)) 靜態站點 Apache基礎 Apache官網:www.apache.org 軟件包名稱:…

開發避坑指南(29):微信昵稱特殊字符存儲異常修復方案

異常信息 Cause: java.sql.SQLException: Incorrect string value: \xF0\x9F\x8D\x8B\xE5\xBB... for column nick_name at row 1異常背景 抽獎大轉盤,抽獎后需要保存用戶抽獎記錄,用戶再次進入游戲時根據抽獎記錄判斷剩余抽獎機會。保存抽獎記錄時需要…

leetcode-python-242有效的字母異位詞

題目&#xff1a; 給定兩個字符串 s 和 t &#xff0c;編寫一個函數來判斷 t 是否是 s 的 字母異位詞。 示例 1: 輸入: s “anagram”, t “nagaram” 輸出: true 示例 2: 輸入: s “rat”, t “car” 輸出: false 提示: 1 < s.length, t.length < 5 * 104 s 和 t 僅…

【ARM】Keil MDK如何指定單文件的優化等級

1、 文檔目標解決在MDK中如何對于單個源文件去設置優化等級。2、 問題場景在正常的項目開發中&#xff0c;我們通常都是針對整個工程去做優化&#xff0c;相當于整個工程都是使用一個編譯器優化等級去進行的工程構建。那么在一些特定的情況下&#xff0c;工程師需要保證我的部分…