Pytorch 實戰四 VGG 網絡訓練

系列文章目錄


文章目錄

  • 系列文章目錄
  • 前言
  • 一、源碼
    • 1. 解決線程沖突
    • 2.代碼框架
  • 二、代碼詳細介紹
    • 1.基礎定義
    • 2. epoch 的定義
    • 3. 每組圖片的訓練和模型保存


前言

??前面我們已經完成了數據集的制作,VGG 網絡的搭建,現在進行網絡模型的訓練。


一、源碼


import torch.nn as nn
import torchvision
from VggNet import VGGNet
from load_cifa10 import train_data_loader, test_data_loader
import torch.multiprocessing as mp
import torch
import multiprocessing
from torch.utils.data import DataLoaderfrom model.ClassModel import netdef main():# 訓練模型到底放在 CPU 還是GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 如果cuda有效,就在GPU訓練,否則CPU訓練# 我們會對樣本遍歷20次epoch_num = 20# 學習率lr = 0.01# 正確率計算相關batch_num=0correct0 =0# 網絡定義# print("need初始化")net = VGGNet().to(device)# 定義損失函數loss,多分類問題,采用交叉熵loss_func = nn.CrossEntropyLoss()# 定義優化器optimizeroptimizer = torch.optim.Adam(net.parameters(),lr=lr)# 動態調整學習率,第一個參數是優化器,起二個參數每5個epoch后調整學習率,第三個參數調整為原來的0.9倍lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5,gamma=0.9)# 定義循環for epoch in range(epoch_num):# print("epoch is:",epoch+1)# 定義網絡訓練的過程net.train()   # BatchNorm 和 dropout 會選擇相應的參數# 對數據進行遍歷for i,data in enumerate(train_data_loader):batch_num = len(train_data_loader)# 獲取輸入和標簽inputs, labels = datainputs, labels = inputs.to(device), labels.to(device) # 放到GPU上去# 拿到輸出# print("need output")outputs = net(inputs)  # 這句就會調用前向傳播,在PyTorch中,當執行outputs = net(inputs)時會自動觸發前向傳播,# 這是通過nn.Module的__call__方法實現的特殊機制6。具體原理可分為三個關鍵環節:# 計算損失loss = loss_func(outputs, labels)# 定義優化器,梯度要歸零,loss反向傳播,更新參數optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(outputs.data, dim=1)  # 得到一個batch的預測類別# 在cpu上面運行,當labels是普通張量時,.data屬性返回?剝離計算圖的純數值張量?(與原始張量共享內存但無梯度追蹤)correct = predicted.eq(labels.data).cpu().sum()correct=100.0*correct/len(inputs)correct0+=correct#自動更新學習率lr_scheduler.step()lr = optimizer.state_dict()['param_groups'][0]['lr']print("loss:{},acc:{}",lr,correct0/batch_num)if __name__ == '__main__':# Windows必須設置spawn,Linux/Mac自動選擇最佳方式mp.set_start_method('spawn' if torch.cuda.is_available() else 'fork')torch.multiprocessing.freeze_support()try:main()except RuntimeError as e:print(f"多進程錯誤: {str(e)}")print("降級到單進程模式...")train_data_loader = DataLoader(..., num_workers=0)main()

1. 解決線程沖突

??windows 跑代碼需要解決線程沖突的問題:需要自行定義main函數,然后把主題加在里面。當我們運行時自動調用main,就會執行下面的 if 語句,然后運行我們的代碼

if __name__ == '__main__':# Windows必須設置spawn,Linux/Mac自動選擇最佳方式mp.set_start_method('spawn' if torch.cuda.is_available() else 'fork')torch.multiprocessing.freeze_support()try:main()except RuntimeError as e:print(f"多進程錯誤: {str(e)}")print("降級到單進程模式...")train_data_loader = DataLoader(..., num_workers=0)main()

2.代碼框架

??代碼分成四個部分,第一個部分是基礎變量定義,第二個部分是循環 epoch ,第三部分是每個 batch 的處理,第四個保存模型,其中最重要的便是第三個。

圖 1 代碼框架

二、代碼詳細介紹

1.基礎定義

    # 訓練模型到底放在 CPU 還是GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 如果cuda有效,就在GPU訓練,否則CPU訓練# 我們會對樣本遍歷200次epoch_num = 200# 學習率lr = 0.01# 正確率計算相關batch_num=0correct0 =0# 網絡定義# print("need初始化")net = VGGNet().to(device)# 定義損失函數loss,多分類問題,采用交叉熵loss_func = nn.CrossEntropyLoss()# 定義優化器optimizeroptimizer = torch.optim.Adam(net.parameters(),lr=lr)# 動態調整學習率,第一個參數是優化器,起二個參數每5個epoch后調整學習率,第三個參數調整為原來的0.9倍lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5,gamma=0.9)

??基礎定義最開始需要定義跑數據的設備,CPU還是GPU.這個是死的。然后定義epoch次數、學習率,至于準確率看自己的使用情況,如果每次跑完一遍數據集打印準確率也不著急定義。我在最終跑完數據才打印準確率,所以需要定義一個全局的變量。接下來便是網絡初始化,初始化的網絡加載到設備上面。在網絡搭建的時候,我們只定義了網絡的層次和前向傳播。后面的損失函數和優化器需要在訓練中進行。那么基礎定義里面需要損失函數的選擇,優化器的選擇和動態調整學習率。epoch 改成200,我的電腦跑了1h還沒出結果,現在還在等,建議別弄大了。

圖 2 基礎定義

當然順序可以變,最好自己能記住需要的內容。

2. epoch 的定義

??epoch里面開始調用網絡,net.train() 會把網絡的參數進行初始化,BatchNorm 會自動啟用訓練模式,dropout層會全部激活,而這在測試集上不需要dropout的。后面便是每組圖片 batch的訓練行為。最后每一次處理整個數據集需要動態改變學習率,以及打印學習率的方法如下:

        #自動更新學習率lr_scheduler.step()# 打印學習率lr = optimizer.state_dict()['param_groups'][0]['lr']print("學習率", lr)

3. 每組圖片的訓練和模型保存

      for i,data in enumerate(train_data_loader):batch_num = len(train_data_loader)# 獲取輸入和標簽inputs, labels = datainputs, labels = inputs.to(device), labels.to(device) # 放到GPU上去# 拿到輸出# print("need output")outputs = net(inputs)  # 這句就會調用前向傳播,在PyTorch中,當執行outputs = net(inputs)時會自動觸發前向傳播,# 這是通過nn.Module的__call__方法實現的特殊機制6。具體原理可分為三個關鍵環節:# 計算損失loss = loss_func(outputs, labels)# 定義優化器,梯度要歸零,loss反向傳播,更新參數optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(outputs.data, dim=1)  # 得到一個batch的預測類別

??這里進行數據加載,分批次加載,此處的batch 大小是128,數量是391(用于準確率計算)。加載了數據,獲取數據的輸入和真實標簽。outputs = net(inputs) 這句對網絡傳入數據,自行前向傳播計算獲得輸出。拿到輸出后,進行損失函數計算。損失函數計算,是需要預測值和真實值的,看看偏差多少,因此傳入這兩個參數。優化器優化,開始梯度歸零,然后后向傳播,這個過程是自帶的,我們只定義了前向傳播,后向傳播優化參數后固定參數 optimizer.step(),最后使用torch.max() 輸出最相似的標簽。過程如圖:

圖 3 batch 循環

模型的保存就一句話:

torch.save(net.state_dict(),"./model/VGGNet.pth")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/85883.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/85883.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/85883.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

課程專注度分析系統文檔

一、項目概述 本項目基于 Flask 框架開發,結合計算機視覺技術(利用 YOLOv10 等模型 ),實現對課堂視頻的智能分析。可檢測視頻中學生手機使用情況、面部表情(專注、分心等 ),統計專注度、手機使…

中國設計 全球審美 | 安貝斯新產品發布會:以東方美學開辟控制臺仿生智造新紀元

6月17日,安貝斯(武漢)控制技術有限公司(以下簡稱“安貝斯”)在武漢隆重舉行“新產品發布暨協會聯合創新峰會”。近百位來自政府機構、行業協會、行業用戶及戰略合作伙伴的嘉賓齊聚現場,共同見證以“中國設計…

在微信小程序wxml文件調用函數實現時間轉換---使用wxs模塊實現

1. 創建 WXS 模塊文件(推薦單獨存放) 在項目目錄下新建 utils.wxs 文件,編寫時間轉換邏輯: // utils.wxs module.exports {// 將毫秒轉換為分鐘(保留1位小數)convertToMinutes: function(ms) {if (typeo…

ByteMD 插件系統詳解

ByteMD 插件系統詳解 ByteMD 的插件系統是其強大擴展性的核心。它允許開發者在 Markdown 解析、AST 轉換、HTML 渲染、以及編輯器 UI 交互的各個階段注入自定義邏輯。這得益于 ByteMD 深度集成了 unified 處理器和其豐富的生態系統(remark 用于 Markdown&#xff0c…

每日一練之 Lua 表

Lua 的 table 是什么數據結構?如何創建和訪問? 數據結構:Lua的table是一種哈希表,使用鍵值對存儲數據,支持動態擴容 創建方式: local t1 {} local t2 {10,20,30} local t3 {name"Alice",age25}訪問方式&#xff1a…

實現自動胡批量抓取唯品會商品詳情數據的途徑分享(官方API、網頁爬蟲)

在電商領域,數據就是企業的核心資產。無論是市場分析、競品研究,還是精準營銷,都離不開對大量商品詳情數據的深入挖掘。唯品會作為知名的電商平臺,其豐富的商品信息對于眾多從業者而言極具價值。本文將詳細探討實現自動批量抓取唯…

Zephyr 高階實踐:徹底講透 west 構建系統、模塊管理與跨平臺 CI/CD 配置

本文是 Zephyr 項目管理體系的高階解構與實戰指南,全面覆蓋 west 構建系統原理、模塊解耦與 west.yml 多模塊維護機制,結合企業級多平臺 CI/CD 落地流程,深入講解如何構建可靠、可維護、跨芯片架構的一體化 Zephyr 工程。 一、為什么 Zephyr …

我開源了一套springboot3快速開發模板

我開源了一套springboot3快速開發模板 開箱即用、按需組合、可快速二次開發的后端通用模板。 ? 主要特性 Spring Boot 3.x Java 17:跟隨 Spring 最新生態,利用現代語法特性。多模塊分層:common 抽象通用能力、starter 負責啟動、modules…

OpenCV CUDA模塊設備層-----在GPU上計算兩個uchar1類型像素值的反正切(arctangent)比值函數atan2()

操作系統:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 編程語言:C11 算法描述 對輸入的兩個 uchar1 像素值 a 和 b,先分別歸一化到 [0.0, 1.0] 浮點區間,然后計算它們的 四象限反正切函數。 函數原型…

從C++編程入手設計模式——觀察者模式

從C編程入手設計模式——觀察者模式 ? 觀察者模式簡直就是字如其名,觀察觀察,觀察到了告訴別人。觀察手的作用如此,觀察者模式的工作機制也是如此。這個模式的核心思路是:一個對象的狀態發生變化時,自動通知依賴它的…

MITM 中間人攻擊

?據Akamai 2023網絡安全報告顯示,MITM攻擊在數據泄露事件中占比達32.7%,平均每次事件造成企業損失$380,000? ?NIST研究指出:2022-2023年高級MITM攻擊增長41%,近70%針對金融和醫療行業? 一、MITM攻擊核心原理與技術演進 1. 中…

llama_index chromadb實現RAG的簡單應用

此demo是自己提的一個需求:用modelscope下載的本地大模型實現RAG應用。畢竟大模型本地化有利于微調,RAG使內容更有依據。 為什么要用RAG? 由于大模型存在一定的局限性:知識時效性不足、專業領域覆蓋有限以及生成結果易出現“幻覺…

TDMQ CKafka 版事務:分布式環境下的消息一致性保障

解鎖 CKafka 事務能力的神秘面紗 在當今數字化浪潮下,分布式系統已成為支撐海量數據處理和高并發業務的中流砥柱。但在這看似堅不可摧的架構背后,數據一致性問題卻如影隨形,時刻考驗著系統的穩定性與可靠性。 CKafka 作為分布式流處理平臺的…

常見的負載均衡算法

常見的負載均衡算法 在實現水平擴展過程中,負載均衡算法是決定請求如何在多個服務實例間分配的核心邏輯。一個合理的負載均衡策略能夠有效分散系統壓力,提升系統吞吐能力與穩定性。 負載均衡算法可部署在多種層級中,如七層HTTP反向代理&…

數據結構轉換與離散點生成

在 C 開發中&#xff0c;我們常常需要在不同的數據結構之間進行轉換&#xff0c;以滿足特定庫或框架的要求。本文將探討如何將 std::vector<gp_Pnt> 轉換為 QVector<QPointF>&#xff0c;并生成特定范圍內的二維離散點。 生成二維離散點 我們首先需要生成一系列…

零基礎學習Redis(12) -- Java連接redis服務器

在我們之前的內容中&#xff0c;我們會發現通過命令行操作redis是十分不科學的&#xff0c;所以redis官方提供了redis的應用層協議RESP&#xff0c;更具這個協議可以實現一個和redis服務器通信的客戶端程序&#xff0c;來簡化和完善redis的使用。現階段有很多封裝了RESP協議的庫…

clangd LSP 不能找到項目中的文件

clangd LSP 不能找到項目中的文件 clangd LSP 不能找到項目中的文件 clangd LSP 不能找到項目中的文件 Normally you need to create compile_commands.json。 如果你使用 cmake 作為構建工具&#xff0c;請執行下面的命令&#xff1a; cmake -DCMAKE_EXPORT_COMPILE_COMMAN…

【內存】Linux 內核優化實戰 - vm.overcommit_memory

目錄 vm.overcommit_memory 解釋一、概念與作用二、參數取值與含義三、相關參數與配置方式四、實際應用場景建議五、注意事項 vm.overcommit_memory 解釋 一、概念與作用 vm.overcommit_memory 是 Linux 內核中的一個參數&#xff0c;用于控制內存分配的“過度承諾”&#xf…

Python:.py文件轉換為雙擊可執行的Windows程序(版本2)

流程步驟&#xff1a; 這個流程圖展示了將 Python .py 文件轉換為 Windows 可執行程序的完整過程&#xff0c;主要包括以下步驟&#xff1a; 1、準備 Python文件&#xff0c;確保代碼可獨立運行 2、安裝打包工具&#xff08;如 PyInstaller&#xff09; 3、打開命令提示符并定位…

【請關注】mysql一些經常用到的高級SQL

經常去重復數據&#xff0c;數據需要轉等操作&#xff0c;匯總高級SQL MySQL操作 一、數據去重&#xff08;Data Deduplication&#xff09; 去重常用于清除重復記錄&#xff0c;保留唯一數據。 1. 使用DISTINCT關鍵字去重單列 -- 從用戶表中獲取唯一的郵箱地址 SELECT DISTIN…