學習pytorch18 pytorch完整的模型訓練流程

pytorch完整的模型訓練流程

  • 1. 流程
    • 1. 整理訓練數據 使用CIFAR10數據集
    • 2. 搭建網絡結構
    • 3. 構建損失函數
    • 4. 使用優化器
    • 5. 訓練模型
    • 6. 測試數據 計算模型預測正確率
    • 7. 保存模型
  • 2. 代碼
    • 1. model.py
    • 2. train.py
  • 3. 結果
    • tensorboard結果
      • 以下圖片 顏色較淺的線是真實計算的值,顏色較深的線是做了平滑處理的值
      • 訓練loss
      • 測試loss
      • 測試集正確率
  • 4. 需要注意的細節

1. 流程

1. 整理訓練數據 使用CIFAR10數據集

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)

2. 搭建網絡結構

在這里插入圖片描述
model.py

3. 構建損失函數

loss_fn = nn.CrossEntropyLoss()

4. 使用優化器

learing_rate = 1e-2 # 0.01
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)

5. 訓練模型

output = net(imgs)    # 數據輸入模型
loss = loss_fn(output, targets)  # 損失函數計算損失 看計算的輸出和真實的標簽誤差是多少
# 優化器開始優化模型  1.梯度清零  2.反向傳播  3.參數優化
optimizer.zero_grad()  # 利用優化器把梯度清零 全部設置為0
loss.backward()        # 設置計算的損失值的鉤子,調用損失的反向傳播,計算每個參數結點的參數
optimizer.step()       # 調用優化器的step()方法 對其中的參數進行優化  

6. 測試數據 計算模型預測正確率

output = net(imags)
# 計算測試集的正確率
preds = (output.argmax(1)==targets).sum()
accuracy += preds 
rate = accuracy/len(test_data)

調用模型輸出tensor 數據類型的 argmax方法, argmax或獲取一行或者一列數值中最大數值的下標位置,argmax(0) 是從列的維度取一列數值的最大值的下標,argmax(1) 是從行的維度取一行數值的最大值的下標
output.argmax(1)==targets 會輸出如下圖最后一行 [false, ture], 對應位置相同則為true,對應位置不同則為false;
調用sum()方法,計算求和,false值為0,true值為1.
最后計算得出測試集整體正確率: rate = accuracy/len(test_data)
在這里插入圖片描述

7. 保存模型

torch.save(net, './net_epoch{}.pth'.format(i))

2. 代碼

1. model.py

import torch
from torch import nn# 2. 搭建模型網絡結構--神經網絡
class Cifar10Net(nn.Module):def __init__(self):super(Cifar10Net, self).__init__()self.net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.net(x)return xif __name__ == '__main__':net = Cifar10Net()input = torch.ones((64, 3, 32, 32))output = net(input)print(output.shape)

2. train.py

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriterfrom p24_model import *# 1. 準備數據集
# 訓練數據
from torch.utils.data import DataLoadertrain_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)
# 測試數據
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)# 查看數據大小--size
print("訓練數據集大小:", len(train_data))
print("測試數據集大小:", len(test_data))
# 利用DataLoader來加載數據集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)# 2. 導入模型結構 創建模型
net = Cifar10Net()# 3. 創建損失函數  分類問題--交叉熵
loss_fn = nn.CrossEntropyLoss()# 4. 創建優化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)# 設置訓練網絡的一些參數
epoch = 10   # 記錄訓練的輪數
total_train_step = 0  # 記錄訓練的次數
total_test_step = 0   # 記錄測試的次數# 利用tensorboard顯示訓練loss趨勢
writer = SummaryWriter('./train_logs')for i in range(epoch):# 訓練步驟開始net.train()  # 可以加可以不加  只有當模型結構有 Dropout BatchNorml層才會起作用for data in train_loader:imgs, targets = data  # 獲取數據output = net(imgs)    # 數據輸入模型loss = loss_fn(output, targets)  # 損失函數計算損失 看計算的輸出和真實的標簽誤差是多少# 優化器開始優化模型  1.梯度清零  2.反向傳播  3.參數優化optimizer.zero_grad()  # 利用優化器把梯度清零 全部設置為0loss.backward()        # 設置計算的損失值,調用損失的反向傳播,計算每個參數結點的參數optimizer.step()       # 調用優化器的step()方法 對其中的參數進行優化# 優化一次 認為訓練了一次total_train_step += 1if total_train_step % 100 == 0:print('訓練次數: {}   loss: {}'.format(total_train_step, loss))# 直接打印loss是tensor數據類型,打印loss.item()是打印的int或float真實數值, 真實數值方便做數據可視化【損失可視化】# print('訓練次數: {}   loss: {}'.format(total_train_step, loss.item()))writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)# 利用現有模型做模型測試# 測試步驟開始total_test_loss = 0accuracy = 0net.eval()  # 可以加可以不加  只有當模型結構有 Dropout BatchNorml層才會起作用with torch.no_grad():for data in test_loader:imags, targets = dataoutput = net(imags)loss = loss_fn(output, targets)total_test_loss += loss.item()# 計算測試集的正確率preds = (output.argmax(1)==targets).sum()accuracy += preds# writer.add_scalar('test-loss', total_test_loss, global_step=i+1)writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)total_test_step += 1print("---------test loss: {}--------------".format(total_test_loss))print("---------test accuracy: {}--------------".format(accuracy))# 保存每一個epoch訓練得到的模型torch.save(net, './net_epoch{}.pth'.format(i))writer.close()

3. 結果

訓練數據集大小: 50000
測試數據集大小: 10000
0.01
訓練次數: 100   loss: 2.2905373573303223
訓練次數: 200   loss: 2.2878968715667725
訓練次數: 300   loss: 2.258394718170166
訓練次數: 400   loss: 2.1968581676483154
訓練次數: 500   loss: 2.0476632118225098
訓練次數: 600   loss: 2.002145767211914
訓練次數: 700   loss: 2.016021728515625
---------test loss: 316.382279753685--------------
訓練次數: 800   loss: 1.8957302570343018
訓練次數: 900   loss: 1.8659226894378662
訓練次數: 1000   loss: 1.9004186391830444
訓練次數: 1100   loss: 1.9708642959594727
......

tensorboard結果

安裝tensorboard運行環境

pip install tensorboard
pip install opencv-python
pip install six
tensorboard --logdir=train_logs

以下圖片 顏色較淺的線是真實計算的值,顏色較深的線是做了平滑處理的值

訓練loss

在這里插入圖片描述

測試loss

在這里插入圖片描述

測試集正確率

在這里插入圖片描述

4. 需要注意的細節

https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module

所有網絡層繼承于torch.nn.Module, net.train() net.eval() 在模型訓練或測試之初 可以加可以不加 只有當模型結構有 Dropout BatchNorml層才會起作用,當模型有這兩個網絡層的時候,兩個代碼需要加上。
在這里插入圖片描述

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/208738.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/208738.shtml
英文地址,請注明出處:http://en.pswp.cn/news/208738.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

國產化軟件突圍!懌星科技eStation產品榮獲2023鈴軒獎“前瞻優秀獎”

11月11日,2023中國汽車供應鏈峰會暨第八屆鈴軒獎頒獎典禮在江蘇省昆山市舉行。懌星科技憑借eStation產品,榮獲2023鈴軒獎“前瞻智能座艙類優秀獎”,懌星CEO潘凱受邀出席鈴軒獎晚會并代表領獎。 2023鈴軒獎“前瞻智能座艙類優秀獎” 鈴軒獎&a…

el-table 跨頁多選

步驟一 在<el-table>中:row-key"getRowKeys"和selection-change"handleSelectionChange" 在<el-table-column>中type"selection"那列&#xff0c;添加:reserve-selection"true" <el-table:data"tableData"r…

隊列排序:給定序列a,每次操作將a[1]移動到 從右往左第一個嚴格小于a[1]的元素的下一個位置,求能否使序列有序,若可以,求最少操作次數

題目 思路&#xff1a; 賽時代碼&#xff08;先求右起最長有序區間長度&#xff0c;再求左邊最小值是否小于等于右邊有序區間左端點的數&#xff09; #include<bits/stdc.h> using namespace std; #define int long long const int maxn 1e6 5; int a[maxn]; int n; …

阿里云磁盤在線擴容

我們從阿里云的控制面板中給硬盤擴容后結果發現我們的磁盤空間并沒有改變 注意&#xff1a;本次操作是針對CentOS 7的 &#xfeff;#使用df -h并沒有發現我們的磁盤空間增加 #使用fdisk -l發現確實還有部分空間 運行df -h命令查看云盤分區大小。 以下示例返回分區&#xf…

python3安裝redis

#!/usr/bin/python3import os import platform import argparse import shutil# 自定義變量 default_system "ubuntu" default_redis_version "6.2.6" default_install_path "/usr/local/redis" default_local_package_dir os.path.dirname(…

eve-ng鏡像模擬設備-信息安全管理與評估-2023國賽

eve-ng鏡像模擬設備-信息安全管理與評估-2023國賽 author&#xff1a;leadlife data&#xff1a;2023/12/4 mains&#xff1a;EVE-ng 模擬器 - 信息安全管理與評估模擬環境部署 references&#xff1a; EVE-ng 官網&#xff1a;https://www.eve-ng.net/EVE-ng 中文網&#xff1…

嵌入版python作為便攜計算器(安裝及配置ipython)

今天用別的電腦調試C&#xff0c;需要計算反三角函數時發現沒有趁手工具&#xff0c;忽然想用python作為便攜計算器放在U盤&#xff0c;遂想到嵌入版python 懶得自己配可以直接下載&#xff0c;使用方法見第4節 1&#xff0c;下載embeddable python&#xff08;嵌入版python&…

React中傳入props.children后, 為什么會導致組件的重新渲染?

傳入props.children后, 為什么會導致組件的重新渲染&#xff1f; 問題描述 在 react 中, 我想要對組件的渲染進行優化, 遇到了一個非常意思的問題, 當我向一個組件中傳入了 props.children 之后, 每次父組件重新渲染都會導致這個組件的重新渲染; 它看起來的表現就像是被memo包…

【1day】?萬戶協同辦公平臺 convertFile 任意文件讀取漏洞學習

注:該文章來自作者日常學習筆記,請勿利用文章內的相關技術從事非法測試,如因此產生的一切不良后果與作者無關。 目錄 一、漏洞描述 二、影響版本 三、資產測繪 四、漏洞復現

圖的鄰接鏈表儲存

噴了一節課 。。。。。。。、。 #include<stdio.h> #include<stdlib.h> #define MAXNUM 20 //每一個頂點的節點結構&#xff08;單鏈表&#xff09; typedef struct ANode{ int adjvex;//頂點指向的位置 struct ArcNode *next;//指向下一個頂點 …

C++ 內存分區模型

目錄 程序運行前 代碼區 全局區 程序運行后 new 在堆區開辟數據 delete釋放堆區數據 堆區開辟數組 內存分區模型 棧&#xff08;Stack&#xff09; 堆&#xff08;Heap&#xff09; 全局/靜態存儲區&#xff08;Global/Static Storage&#xff09; 常量存儲區&am…

力扣230. 二叉搜索樹中第K小的元素

深度優先搜索 思路&#xff1a; 二叉搜索樹的特性&#xff0c;通過中序遍歷得到有序序列&#xff0c;則遍歷到第K個節點的時候即為結果&#xff1b;使用棧通過深度優先遍歷進行中序遍歷&#xff1a; 先將節點和左子節點壓棧&#xff1b;然后棧頂上就是“最左”葉子節點&#x…

Linux DAC權限的簡單應用

Linux的DAC&#xff08;Discretionary Access Control&#xff09;權限模型是一種常見的訪問控制機制&#xff0c;它用于管理文件和目錄的訪問權限。作為一名經驗豐富的Linux系統安全工程師&#xff0c;我會盡可能以簡單明了的方式向計算機小白介紹Linux DAC權限模型。 在Linu…

jenkins中“Jenkins Plot Plugin”的使用方法,比較兩個excel的數據差異

Jenkins Plot Plugin是Jenkins的一個插件&#xff0c;它可以用于生成圖表和報表&#xff0c;以便更好地理解和分析構建和測試數據。下面是使用Jenkins Plot Plugin比較兩個Excel數據差異的步驟&#xff1a; 1.安裝Jenkins Plot Plugin&#xff1a;在Jenkins的插件管理頁面搜索…

使用 Axios 進行網絡請求的全面指南

使用 Axios 進行網絡請求的全面指南 本文將向您介紹如何使用 Axios 進行網絡請求。通過分步指南和示例代碼&#xff0c;您將學習如何使用 Axios 庫在前端應用程序中發送 GET、POST、PUT 和 DELETE 請求&#xff0c;并處理響應數據和錯誤。 準備工作 在開始之前&#xff0c;請…

電子學會C/C++編程等級考試2021年09月(五級)真題解析

C/C++等級考試(1~8級)全部真題?點這里 第1題:抓牛 農夫知道一頭牛的位置,想要抓住它。農夫和牛都位于數軸上,農夫起始位于點N(0<=N<=100000),牛位于點K(0<=K<=100000)。農夫有兩種移動方式: 1、從X移動到X-1或X+1,每次移動花費一分鐘 2、從X移動到2*X,每…

ubuntu18.04安裝opencv-4.5.5+opencv_contrib-4.5.5

一、安裝opencv依賴 sudo apt-get install build-essential sudo apt-get install cmake git libgtk2.0-dev pkg-config libavcodec-dev libavformat-dev libswscale-dev sudo apt-get install python-dev python-numpy libtbb2 libtbb-dev libjpeg-dev libpng-dev libtiff-d…

Navicat 技術指引 | 適用于 GaussDB 分布式的自動運行功能

Navicat Premium&#xff08;16.3.3 Windows 版或以上&#xff09;正式支持 GaussDB 分布式數據庫。GaussDB 分布式模式更適合對系統可用性和數據處理能力要求較高的場景。Navicat 工具不僅提供可視化數據查看和編輯功能&#xff0c;還提供強大的高階功能&#xff08;如模型、結…

「Python編程基礎」第7章:字符串操作

文章目錄 一、回顧二、新手容易踩坑的引號三、轉義字符四、多行字符串寫法五、注釋六、字符串索引和切片七、字符串的in 和 not in八、字符串拼接九、轉換大小寫十、合并字符串join()十一、分割字符串split()十二、字符串替換 replace()十三、字符串內容判斷方法十四、字符串內…

讀文章摘錄

20%的時間可以做點業余項目。有個叫克萊舍基的人&#xff0c;寫了一本書&#xff0c;書名叫《認知盈余-網絡時代的創造與繁榮》&#xff0c;他有個觀點&#xff0c;閑暇時間給人機會創造有價值的東西。 很重要的一點是選合適的人&#xff0c;把他們引入團隊。何謂合適的人&…