PyTorch深度學習快速入門學習總結(三)

現有網絡模型的使用與調整

VGG — Torchvision 0.22 documentation

????????VGG 模型是由牛津大學牛津大學(Oxford University)的 Visual Geometry Group 于 2014 年提出的卷積神經網絡模型,在 ImageNet 圖像分類挑戰賽中表現優異,以其簡潔統一的網絡結構設計而聞名。

  • 優點:結構簡潔統一,易于理解和實現;小卷積核設計提升了特征提取能力,泛化性能較好。
  • 缺點:參數量巨大(主要來自全連接層),計算成本高,訓練和推理速度較慢,對硬件資源要求較高。

ImageNet — Torchvision 0.22 documentation

????????在 PyTorch 的torchvision庫中,ImageNet相關功能主要用于加載和預處理 ImageNet 數據集,方便用戶在該數據集上訓練或評估模型。

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader# train_data = torchvision.datasets.ImageNet('./data_image_net', split='train', download=True,
#                                        transform=torchvision.transforms.ToTensor())
# 指定要加載的數據集子集。這里設置為'train',表示加載的是 ImageNet 的訓練集(包含約 120 萬張圖像)。
# 若要加載驗證集,可將該參數改為'val'(驗證集包含約 5 萬張圖像)。vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)print(vgg16_true)# 利用現有網絡修改結構
train_data = torchvision.datasets.CIFAR10('./dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())vgg16_true.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)
# (add_linear): Linear(in_features=1000, out_features=10, bias=True)# 修改位置不同
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))# 直接修改某層
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

vgg16_false:適合用于訓練全新的任務,vgg16_true:常用于遷移學習場景

  • vgg16_false:pretrained=False表示不加載在大型數據集上(如 ImageNet)預訓練好的權重參數。此時,模型的權重參數會按照默認的隨機初始化方式進行初始化,比如卷積層和全連接層的權重會從一定范圍內的隨機值開始,偏置項通常初始化為 0 或一個較小的值。這種初始化方式下,模型需要從頭開始學習訓練數據中的特征表示。
  • vgg16_true:pretrained=True表示加載在 ImageNet 數據集上預訓練好的權重參數。VGG16 在 ImageNet 上經過大量圖像的訓練,已經學習到了通用的圖像特征,比如邊緣、紋理、形狀等。加載這些預訓練權重后,模型在新的任務上可以利用已學習到的特征,減少訓練所需的樣本數量和訓練時間,在很多情況下能更快地收斂到較好的性能。

模型的保存與加載

模型保存

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoadervgg16 = torchvision.models.vgg16(pretrained=False)# 方式一保存 模型結構+模型參數
torch.save(vgg16, 'vgg16_method1.pth')# 方式二保存 模型參數 (推薦)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')# 陷阱
class Chenxi(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self, x):x = self.conv1(x)return xchenxi = Chenxi()
torch.save(chenxi, 'chenxi_method1.pth')

相應模型加載方法

import torch
import torchvision
from module_save import * # 注意from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader# 方式一, 加載模型
model = torch.load('vgg16_method1.pth')
# print(model)# 方式二, 加載模型
vgg16 = torchvision.models.vgg16(pretrained=False) # 新建網絡模型結構
vgg16.load_state_dict(torch.load('vgg16_method2.pth'))
# print(vgg16)
# model = torch.load('vgg16_method2.pth')
# print(model)# 陷阱1
# 方法一必須要有模型# class Chenxi(nn.Module):
#     def __init__(self, *args, **kwargs) -> None:
#         super().__init__(*args, **kwargs)
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
#     def forward(self, x):
#         x = self.conv1(x)
#         return x# chenxi = Chenxi() 不需要model = torch.load('chenxi_method1.pth')
print(model)

完整的模型訓練套路

CIFAR10為例

model文件

import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear# 搭建神經網絡
class Chenxi(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10),)def forward(self, x):x = self.model1(x)return xif __name__ == '__main__':chenxi = Chenxi()input = torch.ones((64, 3, 32, 32))output = chenxi(input)print(output.shape)# torch.Size([64, 10])

train文件

import torch
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d, Sequential, Conv2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom model import *# 準備數據集
train_data = torchvision.datasets.CIFAR10('./dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10('./dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())train_data_size = len(train_data)
test_data_size = len(test_data)print("訓練數據集的長度為:{}".format(train_data_size))
print("測試數據集的長度為:{}".format(test_data_size))
# 訓練數據集的長度為:50000
# 測試數據集的長度為:10000train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 創建網絡模型
chenxi = Chenxi()# 損失函數
loss_fn = nn.CrossEntropyLoss()# 優化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(chenxi.parameters(),lr=learning_rate)# 設置訓練網絡的參數
# 記錄訓練次數
total_train_step = 0
# 記錄測試次數
total_test_step = 0
# 訓練的輪數
epoch = 10# 添加Tensorboard
writer = SummaryWriter('./logs_train')for i in range(epoch):print('---------第{}輪訓練開始-----------'.format(i + 1))# 訓練開始chenxi.train()for data in train_dataloader:imgs, target = dataoutputs = chenxi(imgs)loss = loss_fn(outputs, target)# 優化器優化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:print("訓練次數:{}, loss:{}".format(total_train_step, loss.item()))writer.add_scalar('train_loss', loss.item(), total_train_step)# 測試步驟開始chenxi.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():# 禁用梯度計算for data in test_dataloader:imgs, targets = dataoutputs = chenxi(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整體測試集上的Loss:{}".format(total_test_loss))print("整體測試集上的正確率:{}".format(total_accuracy/test_data_size))writer.add_scalar('test_loss', loss.item(), total_test_step)writer.add_scalar('test_accuracy', loss.item(), total_test_step)total_test_step += 1torch.save(chenxi, "chenxi_{}.pth".format(i))# torch.save(chenxi.state_dict(), "chenxi_{}.pth".format(i))print("模型已保存")writer.close()

train優化部分

import torch
outputs = torch.tensor([[0.1, 0.2],[0.05, 0.4]])
# print(outputs.argmax(0)) # 縱向
# print(outputs.argmax(1)) # 橫向preds = outputs.argmax(0)
targets = torch.tensor([0, 1])
print((preds == targets).sum())

使用GPU進行訓練

方法一:使用.cuda()

????????對 網絡模型、數據及標注、損失函數, 進行 .cuda()操作

????????如果電腦不支持GPU,可以使用谷歌瀏覽器:

https://colab.research.google.com/drive/1HKuF0FtulVXkHaiWV8VzT-VXZmbq4kK4#scrollTo=861yC3qEpi3F

方法二:使用.to(device)

device = torch.device('cpu')
# device = torch.device('cuda')
x = x.to(device)
# 代替.cuda()

import torch
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d, Sequential, Conv2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# from model import *# 準備數據集
train_data = torchvision.datasets.CIFAR10('./dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10('./dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())train_data_size = len(train_data)
test_data_size = len(test_data)print("訓練數據集的長度為:{}".format(train_data_size))
print("測試數據集的長度為:{}".format(test_data_size))
# 訓練數據集的長度為:50000
# 測試數據集的長度為:10000train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 創建網絡模型
class Chenxi(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10),)def forward(self, x):x = self.model1(x)return xchenxi = Chenxi()
if torch.cuda.is_available():chenxi = chenxi.cuda()
# 損失函數
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():loss_fn = loss_fn.cuda()# 優化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(chenxi.parameters(),lr=learning_rate)# 設置訓練網絡的參數
# 記錄訓練次數
total_train_step = 0
# 記錄測試次數
total_test_step = 0
# 訓練的輪數
epoch = 10# 添加Tensorboard
writer = SummaryWriter('./logs_train')for i in range(epoch):print('---------第{}輪訓練開始-----------'.format(i + 1))# 訓練開始chenxi.train()for data in train_dataloader:imgs, targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = chenxi(imgs)loss = loss_fn(outputs, targets)# 優化器優化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:print("訓練次數:{}, loss:{}".format(total_train_step, loss.item()))writer.add_scalar('train_loss', loss.item(), total_train_step)# 測試步驟開始chenxi.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():# 禁用梯度計算for data in test_dataloader:imgs, targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = chenxi(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整體測試集上的Loss:{}".format(total_test_loss))print("整體測試集上的正確率:{}".format(total_accuracy/test_data_size))writer.add_scalar('test_loss', loss.item(), total_test_step)writer.add_scalar('test_accuracy', loss.item(), total_test_step)total_test_step += 1torch.save(chenxi, "chenxi_{}.pth".format(i))# torch.save(chenxi.state_dict(), "chenxi_{}.pth".format(i))print("模型已保存")writer.close()

完整的模型驗證套路

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/91291.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/91291.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/91291.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

是否需要買一個fpga開發板?

糾結要不要買個 FPGA 開發板?真心建議搞一塊,尤其是想在數字電路、嵌入式領域扎根的同學,這玩意兒可不是可有可無的擺設。入門級的選擇不少,全新的像 Cyclone IV、Artix 7 系列,幾百塊就能拿下,要是去二手平…

【模型細節】MHSA:多頭自注意力 (Multi-head Self Attention) 詳細解釋,使用 PyTorch代碼示例說明

MHSA:使用 PyTorch 實現的多頭自注意力 (Multi-head Self Attention) 代碼示例,包含詳細注釋說明:線性投影 通過三個線性層分別生成查詢(Q)、鍵(K)、值(V)矩陣: QWq?x,KWk?x,VWv?xQ W_qx, \quad K W_kx, \quad V W_vxQWq??x,KWk??x…

PGSQL運維優化:提升vacuum執行時間觀測能力

本文是 IvorySQL 2025 生態大會暨 PostgreSQL 高峰論壇上的演講內容,作者:NKYoung。 6 月底濟南召開的 HOW2025 IvorySQL 生態大會上,我在內核論壇分享了“提升 vacuum 時間觀測能力”的主題,提出了新增統計信息的方法&#xff0c…

神奇的數據跳變

目的 上周遇上了一個非常奇怪的問題,就是軟件的數據在跳變,本來數據應該是158吧,數據一會變成10,一會又變成158,數據在不斷地跳變,那是怎么回事?? 這個問題非常非常的神奇,讓人感覺太不可思議了。 這是這段時間,我遇上的最神奇的事了,沒有之一,最神奇的事,下面…

【跨國數倉遷移最佳實踐3】資源消耗減少50%!解析跨國數倉遷移至MaxCompute背后的性能優化技術

本系列文章將圍繞東南亞頭部科技集團的真實遷移歷程展開,逐步拆解 BigQuery 遷移至 MaxCompute 過程中的關鍵挑戰與技術創新。本篇為第3篇,解析跨國數倉遷移背后的性能優化技術。注:客戶背景為東南亞頭部科技集團,文中用 GoTerra …

【MySQL集群架構與實踐3】使用Dcoker實現讀寫分離

目錄 一. 在Docker中安裝ShardingSphere 二 實踐:讀寫分離 2.1 應用場景 2.2 架構圖 2.3 服務器規劃 2.4 啟動數據庫服務器 2.5. 配置讀寫分離 2.6 日志配置 2.7 重啟ShardingSphere 2.8 測試 2.9. 負載均衡 2.9.1. 隨機負載均衡算法示例 2.9.2. 輪詢負…

maven的阿里云鏡像地址

在 Maven 中配置阿里云鏡像可以加速依賴包的下載,尤其是國內環境下效果明顯。以下是阿里云 Maven 鏡像的配置方式: 配置步驟:找到 Maven 的配置文件 settings.xml 全局配置:位于 Maven 安裝目錄的 conf/settings.xml用戶級配置&am…

大語言模型信息抽取系統解析

這段代碼實現了一個基于大語言模型的信息抽取系統,能夠從金融和新聞類文本中提取結構化信息。下面我將詳細解析整個代碼的結構和功能。1. 代碼整體結構代碼主要分為以下幾個部分:模式定義:定義不同領域(金融、新聞)需要抽取的實體類型示例數據…

Next實習項目總結串聯講解(一)

下面是一些 Next.js 前端面試中常見且具深度的問題,按照邏輯模塊整理,同時提供示范回答建議,便于你條理清晰地展示理解與實踐經驗。 ? 面試講述結構建議 先講 Next.js 是什么,它為什么比 React 更高級。(支持 SSR/SSG/ISR,提升S…

React開發依賴分析

1. React小案例: 在界面顯示一個文本:Hello World點擊按鈕后,文本改為為:Hello React 2. React開發依賴 2.1. 開發React必須依賴三個庫: 2.1.1. react: 包含react所必須的核心代碼2.1.2. react-dom: react渲染在不同平…

工具(一)Cursor

目錄 一、介紹 二、如何打開文件 1、從idea跳轉文件 2、單獨打開項目 三、常見使用 1、Chat 窗口 Ask 對話模式 1.1、使用技巧 1.2 發送和使用 codebase 發送區別 1.3、問題快速修復 2、Chat 窗口 Agent 對話模式 2.1、agent模式功能 2.2、Chat 窗口回滾&撤銷 2.3…

Prompt編寫規范指引

1、📖 引言 隨著人工智能生成內容(AIGC)技術的快速發展,越來越多的開發者開始利用AIGC工具來輔助代碼編寫。然而,如何編寫有效的提示詞(Prompt)以引導AIGC生成高質量的代碼,成為了許…

自我學習----繪制Mark點

在PCB的Layout過程中我們需在光板上放置Mark點以方便生產時的光學定位(三點定位);我個人Mark點繪制步驟如下: layer層:1.放置直徑1mm的焊盤(無網絡連接) 2.放置一個圓直徑2mm,圓心與…

2025年財稅行業拓客破局:小藍本財稅版AI拓客系統助力高效拓客

2025年,在"金稅四期"全面實施的背景下,中國財稅服務市場迎來爆發式增長,根據最新的市場研究報告,2025年中國財稅服務行業產值將達2725.7億元。然而,行業高速發展的背后,80%的財稅公司卻陷入獲客成…

雙向鏈表,對其實現頭插入,尾插入以及遍歷倒序輸出

1.創建一個節點,并將鏈表的首節點返回創建一個獨立節點,沒有和原鏈表產生任何關系#include "head.h"typedef struct Node { int num; struct Node*pNext; struct Node*pPer; }NODE;后續代碼:NODE*createNode(int value) {NODE*new …

2025年自動化工程與計算機網絡國際會議(ICAECN 2025)

2025年自動化工程與計算機網絡國際會議(ICAECN 2025) 2025 International Conference on Automation Engineering and Computer Networks一、大會信息會議簡稱:ICAECN 2025 大會地點:中國柳州 審稿通知:投稿后2-3日內通…

12.Origin2021如何繪制誤差帶圖?

12.Origin2021如何繪制誤差帶圖?選中Y3列→點擊統計→選擇描述統計→選擇行統計→選擇打開對話框輸入范圍選擇B列到D列點擊輸出量→勾選均值和標準差Control選擇下面三列點擊繪圖→選擇基礎2D圖→選擇誤差帶圖雙擊圖像→選擇符號和顏色點擊第二個Sheet1→點擊誤差棒→連接選擇…

如何使用API接口獲取淘寶店鋪訂單信息

要獲取淘寶店鋪的訂單信息,您需要通過淘寶開放平臺(Taobao Open Platform, TOP)提供的API接口來實現。以下是詳細步驟:1. 注冊淘寶開放平臺賬號訪問淘寶開放平臺注冊開發者賬號并完成實名認證創建應用獲取App Key和App Secret2. 申請API權限在"我的…

【Kiro Code 從入門到精通】重要的功能

一、Kiro 是什么? Kiro 是一款智能型集成開發環境(IDE),借助規格說明(specs)、向導(steer)、鉤子(hooks)幫助你高效完成工作。 二、Specs 規格說明 規范&…