pytorch--模型訓練的一般流程

文章目錄

    • 前言
    • 0、數據集準備
    • 1、數據集
    • 2、dataset
    • 3、model
    • 4、訓練模型

前言

在pytorch中模型訓練一般分為以下幾個步驟:
0、數據集準備
1、數據集讀取(dataset模塊)
2、數據集轉換為tensor(dataloader模塊)
3、定義模型model(編寫模型代碼,主要是前向傳播)
4、定義損失函數loss
5、定義優化器optimizer
6、最后一步是模型訓練階段train:這一步會,利用循環把dataset->dataloader->model->loss->optimizer合并起來。
相比于普通的函數神經網絡并沒有特別神奇的地方,我們不妨訓練過程看成普通函數參數求解的過程,也就是最優化求解參數。以Alex模型為例,進行分類訓練。

0、數據集準備

分類數據不需要進行標注,只需要給出類別就可以了,對應分割,檢測需要借助labelme或者labelimg進行標注。將數據分為訓練集,驗證集,測試集。訓練集用于模型訓練,驗證集用于訓練過程中檢驗模型訓練參數的表現,測試集是模型訓練完成之后驗證模型的表現。

1、數據集

從這里下載數據集The TU Darmstadt Database (formerly the ETHZ Database)一個三種類型115 motorbikes + 50 x 2 cars + 112 cows = 327張照片,把數據分為訓練train和驗證集val

在這里插入圖片描述

并對train和val文件夾形成對應的標簽文件,每一行為照片的名稱和對應的類別編號(從0開始):
在這里插入圖片描述

2、dataset

現在寫一個名為dataset.py文件,寫一個VOCDataset的類,來讀取訓練集和驗證集,VOCDataset繼承了torch.utils.data.Dataset,并重寫父類的兩個函數__getitem__:返回每個圖像及其對應的標簽,def __len__返回數據集的數量:


import torch  
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from PIL import Image
import osclass VOCDataset(Dataset):def __init__(self, img_dir, label_root, transform=None):self.img_root = img_dirself.label_root = label_rootself.transform = transform# 獲取所有圖像路徑self.img_paths= [os.path.join(self.img_root, f) for f in os.listdir(self.img_root) if f.endswith('.png')]# 讀取txt中class標簽,txt文件每行格式為: img_name class_idself.label_classes = {}with open(label_root, 'r') as f:for line in f:img_name, class_id = line.strip().split()self.label_classes[img_name] = int(class_id)def __len__(self):return len(self.img_paths)def __getitem__(self, idx):img_path = self.img_paths[idx]img = Image.open(img_path).convert('RGB')# 獲取對應的標簽img_name = os.path.basename(img_path)target = self.label_classes.get(img_name, -1)if target == -1:raise ValueError(f"Image {img_name} not found in label file.")if self.transform:img = self.transform(img)else:img = transforms.ToTensor()(img)return img, target

3、model

新建一個model.py的文件,寫一個Alex的類(參考動手學深度學習7.1),繼承torch.nn.Module,重寫forword函數:

from torch import nn
from torchvision import modelsclass AlexNet(nn.Module):def __init__(self,num_class=3):super(AlexNet, self).__init__()self.conv2d1=nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=1)self.pool1=nn.MaxPool2d(kernel_size=3,stride=2,padding=0)self.conv2d2=nn.Conv2d(in_channels=96,out_channels=256,kernel_size=5,stride=1,padding=2)self.pool2=nn.MaxPool2d(kernel_size=3,stride=2,padding=0)self.conv2d3=nn.Conv2d(in_channels=256,out_channels=384,kernel_size=3,stride=1,padding=1)self.conv2d4=nn.Conv2d(in_channels=384,out_channels=384,kernel_size=3,stride=1,padding=1)self.conv2d5=nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1)self.pool3=nn.MaxPool2d(kernel_size=3,stride=2,padding=0)# 全連接層4096self.fc1=nn.Linear(256*5*5,4096)self.fc2=nn.Linear(4096,4096)self.fc3=nn.Linear(4096,num_class)self.sequential = nn.Sequential(self.conv2d1,nn.ReLU(),self.pool1,self.conv2d2,nn.ReLU(),self.pool2,self.conv2d3,nn.ReLU(),self.conv2d4,nn.ReLU(),self.conv2d5,nn.ReLU(),self.pool3,nn.Flatten(),self.fc1,nn.ReLU(),nn.Dropout(0.5),self.fc2,nn.ReLU(),nn.Dropout(0.5),self.fc3)# 初始化權重for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def forward(self,x):x = self.sequential(x)return x

4、訓練模型

首先定義損失函數和優化器:

  criterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)

新建一個train.py的文件:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset import VOCDataset
from model import AlexNet, ResnetPretrained
from torchvision import models
from torchvision.datasets import CIFAR10
from dataset import VOCDataset
import tensorboarddef train(model, train_dataset, val_dataset, num_epochs=20, batch_size=32, learning_rate=0.001):# 1. 創建數據加載器train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)# 2. 定義損失函數和優化器criterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 3. 修正學習率調度器(放在循環外)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)# 4. 訓練模型best_acc = 0.0for epoch in range(num_epochs):model.train()running_loss = 0.0total = 0for i, (inputs, labels) in enumerate(train_loader):inputs, labels = inputs.cuda(), labels.cuda()optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)total += inputs.size(0)if i % 100 == 0:avg_loss = running_loss / totalprint(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {avg_loss:.4f}')# 每個epoch結束后驗證model.eval()correct = 0total_val = 0val_loss = 0.0with torch.no_grad():for inputs, labels in val_loader:inputs, labels = inputs.cuda(), labels.cuda()outputs = model(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, 1)total_val += labels.size(0)correct += (predicted == labels).sum().item()val_loss += loss.item() * inputs.size(0)epoch_acc = 100 * correct / total_valavg_val_loss = val_loss / total_valprint(f'Epoch {epoch+1}/{num_epochs} | 'f'Train Loss: {running_loss/total:.4f} | 'f'Val Loss: {avg_val_loss:.4f} | 'f'Val Acc: {epoch_acc:.2f}%')# 更新學習率(基于驗證集準確率)#scheduler.step(epoch_acc)# 保存最佳模型if epoch_acc > best_acc:best_acc = epoch_acctorch.save(model.state_dict(), 'best_alexnet_cifar10.pth')print(f'Best Validation Accuracy: {best_acc:.2f}%')if __name__ == "__main__":# 1. 定義數據集路徑train_img_dir = r'F:\dataset\tud\TUDarmstadt\PNGImages\train'val_img_dir = r'F:\dataset\tud\TUDarmstadt\PNGImages\val'train_label_file = r'F:\dataset\tud\TUDarmstadt\PNGImages/train_set.txt'val_label_file = r'F:\dataset\tud\TUDarmstadt\PNGImages/val_set.txt'# 2. 創建數據集實例# 增強數據增強transform_train = transforms.Compose([transforms.Resize((256, 256)),  # 先放大transforms.RandomCrop(224),  # 隨機裁剪transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 驗證集不需要數據增強,但需要同樣的預處理transform_val = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 創建訓練和驗證數據集train_dataset = VOCDataset(train_img_dir, train_label_file, transform=transform_train)val_dataset = VOCDataset(val_img_dir, val_label_file, transform=transform_val)print(f'Train dataset size: {len(train_dataset)}')print(f'Validation dataset size: {len(val_dataset)}')# 2. 下載并利用CIFAR-10數據集進行分類# # # 定義數據增強和預處理# transform_train = transforms.Compose([#     transforms.Resize((224, 224)),#     transforms.RandomHorizontalFlip(),#     transforms.RandomCrop(224, padding=4),#     transforms.ToTensor(),#     transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], #                          std=[0.2470, 0.2435, 0.2616])# ])# transform_val = transforms.Compose([#     transforms.Resize((224, 224)),#     transforms.ToTensor(),#     transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], #                          std=[0.2470, 0.2435, 0.2616])# ])# # 下載CIFAR-10訓練集和驗證集# train_dataset = CIFAR10(root='data', train=True, download=True, transform=transform_train)# val_dataset = CIFAR10(root='data', train=False, download=True, transform=transform_val)# print(f'Train dataset size: {len(train_dataset)}')# print(f'Validation dataset size: {len(val_dataset)}')# 3. 創建模型實例model = AlexNet(num_class=10)  # CIFAR-10有10個類別  # 檢查是否有可用的GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)  # 將模型移動到GPU或CPU# 打印模型結構#print(model)# 4. 開始訓練train(model, train_dataset, val_dataset, num_epochs=20, batch_size=32, learning_rate=0.001)print('Finished Training')# 5. 保存模型torch.save(model.state_dict(), 'output/alexnet.pth')print('Model saved as alexnet.pth')

運行main函數就可以進行訓練了,后面會講一些如何改進這個模型和一些訓練技巧。

參考:
1
2
3

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/912651.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/912651.shtml
英文地址,請注明出處:http://en.pswp.cn/news/912651.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

智能合同管理實戰:基于區塊鏈的電子簽約技術實現

在數字經濟時代,傳統紙質合同簽署方式已難以滿足企業高效、安全、合規的業務需求。智能合同管理(Smart Contract Management)結合區塊鏈技術,正在重塑電子簽約流程,實現合同全生命周期的自動化、可追溯和防篡改。本文將深入探討基于區塊鏈的電子簽約技術實現,涵蓋核心架構…

設計模式精講 Day 22:模板方法模式(Template Method Pattern)

【設計模式精講 Day 22】模板方法模式(Template Method Pattern) 文章標簽 設計模式, 模板方法模式, Java開發, 面向對象設計, 軟件架構, 設計模式實戰, Java應用開發 文章簡述 模板方法模式是一種行為型設計模式,它通過定義一個算法的骨架…

如何在pytorch中使用tqdm:優雅實現訓練進度監控

文章目錄 為什么需要進度條?tqdm 簡介基礎用法示例深度學習中的實戰應用1. 數據加載進度監控2. 訓練循環增強版3. 驗證階段集成 高級技巧與最佳實踐1. 自定義進度條樣式2. 嵌套進度條(多任務)3. 分布式訓練支持4. 與日志系統集成 性能優化建議…

Linux中的xxd命令詳解

xxd 是一個 十六進制轉儲(hex dump)工具,通常用于將二進制文件轉換為十六進制格式,或者反向轉換(十六進制→二進制)。它是 vim 的一部分,但在大多數 Linux 系統(如 Ubuntu&#xff0…

磐維數據庫panweidb3.1.0單節點多實例安裝

0 說明 業務科室提單需要在某臺主機上部署多個單機磐維數據庫,用于業務測試。以下內容展示如何在單節點安裝多個磐維數據庫實例。 1 部署環境準備 1.1 IP 地址及端口 instipport實例1192.168.131.1717700實例2192.168.131.1727700 在131.17上分別安裝兩個實例&…

轉錄組分析流程(三):功能富集分析

我們的教程主要是以一個具體的例子作為線索,通過對公共數據庫數據bulk-RNA-seq的挖掘,利用生物信息學分析來探索目標基因集作為某種疾病數據預后基因的潛能及其潛在分子機制,同時在單細胞水平分析(對scRNA-seq進行挖掘)預后基因的表達,了解細胞之間的通訊網絡,以期為該疾病…

全面掌握 tkinter:Python GUI 編程的入門與實戰指南

在自動化、工具開發、數據可視化等領域,圖形用戶界面(GUI)往往是提升用戶體驗的重要方式。作為 Python 官方內置的 GUI 庫,tkinter 以其輕量、跨平臺、易于學習的特性成為初學者和輕量級應用開發者首選。 本文將以深入淺出的方式…

TDH社區開發版安裝教程

(注:本文章來源于星環官網安裝手冊) 后面放置了視頻和安裝手冊連接 1、硬件及環境要求 Docker17及以上版本,支持Centos,Ubuntu等系統(注:這里我使用CentOS-7版本,最佳版本推薦為7.…

Linux基本命令篇 —— grep命令

grep是Linux/Unix系統中一個非常強大的文本搜索工具,它的名字來源于"Global Regular Expression Print"(全局正則表達式打印)。grep命令用于在文件中搜索包含特定模式的行,并將匹配的行打印出來。 目錄 一、基本語法 二…

蒼穹外賣問題系列之 蒼穹外賣訂單詳情前端界面和網課給的不一樣

問題 如圖,我的前端界面和網課里面給的不一樣,沒有“申請退款”和一些其他的該有的東西。 原因分析 “合計”這一欄顯示undefined說明我們的總金額沒有輸入進去。可以看看訂單提交那塊的代碼,是否可以正確輸出。還有就是訂單詳細界面展示這…

CppCon 2018 學習:EMULATING THE NINTENDO 3DS

我們來逐個分析一下這個 組件交互模型 和 仿真 & 序列化 的關系,特別是主線程(Main Thread)與其他系統組件之間的交互。 1. Main Thread — simple (basically memcpy) --> GPU Main Thread(主線程)負責游戲的…

[Python 基礎課程]數字

數字 數字數據類型用于存儲數值,比如整數、小數等。數據類型是不允許改變的,這就意味著如果改變數字數據類型的值,將重新分配內存空間。 創建數字類型的變量: var1 1 var2 10創建完變量后,如果想廢棄掉這個變量&a…

Linux CentOS環境下Java連接MySQL數據庫指南

文章目錄 前言一、環境準備1.1 系統更新1.2 Java環境安裝1.3 MySQL數據庫安裝1.4 下載JDBC驅動 二、編寫Java程序2.1 代碼如下2.2 編譯和運行2.3 驗證創建結果 三、代碼上傳至Gitee3.1 安裝配置Git3.2 克隆倉庫到本地3.3 添加Java項目文件3.4 提交代碼到本地倉庫3.5 推送到Gite…

LLM面試12

訊飛算法工程師面試題 SVM核函數能否映射到無窮維 可以的,多項式核函數將低維數據映射到高維(維度是有限的),而高斯核函數可以映射到無窮維。由 描述下xgb原理,損失函數 首先需要說一說GBDT,它是一種基于boosting增強…

類加載生命周期與內存區域詳解

類加載生命周期與內存區域詳解 Java 類加載的生命周期包括加載、驗證、準備、解析、初始化五個階段,每個階段在內存中的存儲區域和賦值機制各有不同。以下是詳細解析: 一、類加載生命周期階段 1. 加載(Loading) 內存區域&…

正交視圖三維重建2 筆記 2d線到3d線2 先生成3d線然后判斷3d線在不在

應該先連線再判斷線在不在 if(fx1tx1&&tx1tx2){ const A[fx1, fy1, ty1];const Ahat[fx1, fy1, ty2];drawlines(A[0], A[1], A[2], Ahat[0], Ahat[1], Ahat[2], lineId, type,2);}if(fx2tx1&&tx1tx2){ const B[fx2, fy2, ty1];const Bhat[fx2, fy2, ty2];drawl…

Hibernate對象生命周期全解析

Hibernate對象生命周期詳解 Hibernate作為Java領域主流的ORM框架,其核心機制之一就是對持久化對象生命周期的管理。理解Hibernate對象生命周期對于正確使用Hibernate進行數據持久化操作至關重要。Hibernate將對象分為三種主要狀態:瞬時態(Transient)、持久態(Persistent)和游…

MCP 協議使用核心講解

📄 MCP 協議使用核心講解 ? MCP 協議的核心在于以下幾個方面 一、MCP 請求結構(MCPRequest) {"messages": [{"role": "user","content": "幫我查詢一下上海的天氣"}],"tools"…

云計算中的幾何方法:曲面變形的可視化與動畫-AI云計算數值分析和代碼驗證

著重強調微分方程底層的幾何和代數結構,以進行更深入的分析和求解方法。開發結構保持的數值方法,以在計算中保持定性特征。統一符號和數值方法,實現有效的數學建模。利用幾何解釋(如雙曲幾何)求解經典微分方程。利用計…

OpenCV篇——項目(一)OCR識別讀取銀行卡號碼

目錄 信用卡數字識別系統:前言與代碼解析 前言 項目代碼 ??????結果演示 代碼模塊解析 1. 參數解析模塊 2. 輪廓排序函數 3. 圖像預處理模塊 4. 輸入圖像處理流程 5. 卡號區域定位 6. 數字識別與輸出 系統優勢 信用卡數字識別系統:前言…