一個基于 PyTorch 的完整模型訓練流程

一個基于 PyTorch 的完整模型訓練流程

flyfish

訓練步驟具體操作目的
1. 訓練前準備設置隨機種子、配置超參數(batch size、學習率等)、選擇計算設備(CPU/GPU)確保實驗可復現;統一控制訓練關鍵參數;利用硬件加速訓練
2. 數據預處理與加載對數據進行標準化/歸一化、轉換為張量;用DataLoader按batch加載數據統一輸入格式,適配模型要求;高效分批讀取數據,減少內存占用
3. 初始化組件定義模型結構并加載到計算設備;選擇損失函數(如交叉熵)和優化器(如Adam)搭建訓練核心框架:模型負責預測,損失函數量化誤差,優化器負責參數更新
4. 訓練循環(每個epoch)逐輪迭代優化模型參數
4.1 模型切換為訓練模式model.train()啟用dropout、批量歸一化的訓練模式,確保梯度計算有效
4.2 遍歷訓練數據(每個batch)逐批更新參數
4.2.1 清零梯度optimizer.zero_grad()消除歷史梯度累積,確保當前batch的梯度計算獨立
4.2.2 前向傳播output = model(data)用當前模型參數對輸入數據做預測,得到輸出結果
4.2.3 計算損失loss = criterion(output, target)量化預測結果與真實標簽的差距,作為優化目標
4.2.4 反向傳播loss.backward()從損失值反向推導,計算所有可訓練參數的梯度(參數對損失的影響程度)
4.2.5 參數更新optimizer.step()根據梯度,按優化器規則調整模型參數,減小損失
4.3 記錄訓練指標保存每個epoch的訓練損失、準確率跟蹤模型在訓練集上的學習效果
5. 驗證(每個epoch后)評估模型泛化能力
5.1 模型切換為評估模式model.eval()關閉dropout、固定批量歸一化參數,確保評估穩定
5.2 關閉梯度計算with torch.no_grad():減少內存占用,加速驗證過程(無需計算梯度)
5.3 計算驗證指標計算驗證損失、準確率評估模型在未見過的數據上的表現,判斷泛化能力
6. 模型保存保存表現最優的模型參數(如驗證準確率最高時)留存最佳模型,便于后續部署或繼續訓練
7. 訓練后分析繪制損失/準確率曲線,統計訓練時間直觀展示訓練過程,分析模型收斂狀態和效率

前向傳播→計算損失→反向傳播→參數優化

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
import time# 設置隨機種子,保證結果可復現
def set_seed(seed=42):torch.manual_seed(seed)torch.cuda.manual_seed(seed)np.random.seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False# 定義超參數
class Config:def __init__(self):self.batch_size = 64self.learning_rate = 0.001self.epochs = 10self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.save_path = './models'self.log_interval = 100# 定義簡單的卷積神經網絡模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)self.relu = nn.ReLU()self.dropout = nn.Dropout(0.5)def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 64 * 7 * 7)  # 展平x = self.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x# 準備數據
def prepare_data(config):# 定義數據變換transform = transforms.Compose([ToTensor(),transforms.Normalize((0.1307,), (0.3081,))  # MNIST數據集的均值和標準差])# 加載MNIST數據集train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform)test_dataset = datasets.MNIST(root='./data',train=False,download=True,transform=transform)# 創建數據加載器train_loader = DataLoader(train_dataset,batch_size=config.batch_size,shuffle=True,num_workers=2)test_loader = DataLoader(test_dataset,batch_size=config.batch_size,shuffle=False,num_workers=2)return train_loader, test_loader# 訓練函數
def train(model, train_loader, criterion, optimizer, config, epoch):model.train()  # 設置為訓練模式train_loss = 0.0correct = 0total = 0# 使用tqdm顯示進度條pbar = tqdm(train_loader, desc=f'Train Epoch {epoch}')for batch_idx, (data, target) in enumerate(pbar):data, target = data.to(config.device), target.to(config.device)# 清零梯度optimizer.zero_grad()# 前向傳播output = model(data)loss = criterion(output, target)# 反向傳播和優化loss.backward()optimizer.step()# 統計訓練信息train_loss += loss.item()_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()# 打印日志if batch_idx % config.log_interval == 0:pbar.set_postfix({'loss': f'{train_loss/(batch_idx+1):.6f}','accuracy': f'{100.*correct/total:.2f}%'})# 計算平均損失和準確率avg_loss = train_loss / len(train_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy# 驗證函數
def validate(model, test_loader, criterion, config):model.eval()  # 設置為評估模式test_loss = 0.0correct = 0total = 0# 不計算梯度with torch.no_grad():for data, target in test_loader:data, target = data.to(config.device), target.to(config.device)output = model(data)test_loss += criterion(output, target).item()# 統計準確率_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()# 計算平均損失和準確率avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalprint(f'\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)\n')return avg_loss, accuracy# 保存模型
def save_model(model, optimizer, epoch, loss, config):# 創建保存目錄if not os.path.exists(config.save_path):os.makedirs(config.save_path)# 保存模型狀態torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}, f"{config.save_path}/model_epoch_{epoch}.pth")print(f"Model saved to {config.save_path}/model_epoch_{epoch}.pth")# 主函數
def main():# 初始化設置set_seed()config = Config()print(f"Using device: {config.device}")# 準備數據train_loader, test_loader = prepare_data(config)# 初始化模型、損失函數和優化器model = SimpleCNN().to(config.device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)# 記錄訓練過程中的指標history = {'train_loss': [],'train_acc': [],'val_loss': [],'val_acc': []}# 開始訓練start_time = time.time()best_val_acc = 0.0for epoch in range(1, config.epochs + 1):print(f"\nEpoch {epoch}/{config.epochs}")print("-" * 50)# 訓練train_loss, train_acc = train(model, train_loader, criterion, optimizer, config, epoch)history['train_loss'].append(train_loss)history['train_acc'].append(train_acc)# 驗證val_loss, val_acc = validate(model, test_loader, criterion, config)history['val_loss'].append(val_loss)history['val_acc'].append(val_acc)# 保存最佳模型if val_acc > best_val_acc:best_val_acc = val_accsave_model(model, optimizer, epoch, val_loss, config)# 計算總訓練時間end_time = time.time()total_time = end_time - start_timeprint(f"Training complete in {total_time:.0f}s ({total_time/config.epochs:.2f}s per epoch)")print(f"Best validation accuracy: {best_val_acc:.2f}%")# 繪制訓練曲線plot_training_history(history)# 繪制訓練歷史
def plot_training_history(history):plt.figure(figsize=(12, 4))# 繪制損失曲線plt.subplot(1, 2, 1)plt.plot(history['train_loss'], label='Training Loss')plt.plot(history['val_loss'], label='Validation Loss')plt.title('Loss Curves')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()# 繪制準確率曲線plt.subplot(1, 2, 2)plt.plot(history['train_acc'], label='Training Accuracy')plt.plot(history['val_acc'], label='Validation Accuracy')plt.title('Accuracy Curves')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.tight_layout()plt.savefig('training_history.png')print("Training history plot saved as 'training_history.png'")plt.show()if __name__ == '__main__':main()
......
--------------------------------------------------
Train Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:07<00:00, 124.14it/s, loss=0.024222, accuracy=99.22%]Test set: Average loss: 0.0256, Accuracy: 9926/10000 (99.26%)Model saved to ./models/model_epoch_9.pthEpoch 10/10
--------------------------------------------------
Train Epoch 10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:07<00:00, 127.89it/s, loss=0.021473, accuracy=99.31%]Test set: Average loss: 0.0266, Accuracy: 9927/10000 (99.27%)Model saved to ./models/model_epoch_10.pth
Training complete in 85s (8.52s per epoch)
Best validation accuracy: 99.27%
Training history plot saved as 'training_history.png'

在這里插入圖片描述
一、左側:Loss Curves(損失曲線)
藍色:訓練損失(Training Loss)
橙色:驗證損失(Validation Loss)

二、右側:Accuracy Curves(準確率曲線)
藍色:訓練準確率(Training Accuracy)
橙色:驗證準確率(Validation Accuracy)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/93074.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/93074.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/93074.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

ffmpeg,ffplay, vlc,rtsp-simple-server,推拉流命令使用方法,及測試(二)

一、常用命令 ffmpeg 推流命令 : ffmpeg -re -i input.mp4 -c copy -f flv rtmp://39.105.129.233/myapp/ffmpeg -re -i input.mp4 -c copy -f flv rtsp://39.105.129.233/myapp/-re 讀取流 -i 輸入文件 -f # 指定推流formatffplay 拉流命令 : ffplay rtmp://39.105.129.233/m…

使用行為樹控制機器人(三) ——通用端口

文章目錄一、通用端口功能實現1. 功能實現1.1 頭文件定義1.2 源文件實現1.3 main文件實現1.4 tree.xml 實現2. 執行結果使用行為樹控制機器人(一) —— 節點使用行為樹控制機器人(二) —— 黑板使用行為樹控制機器人(三) —— 通用端口有了上述前兩節我們已經可以實現節點間的通…

DataDome反爬蟲驗證技術深度解析:無感、滑塊與設備驗證全攻略

DataDome反爬蟲驗證技術深度解析&#xff1a;無感、滑塊與設備驗證全攻略 隨著網絡安全威脅的不斷演進&#xff0c;企業對數據保護的需求日益增強。DataDome作為業界領先的反爬蟲解決方案&#xff0c;以其三層防護機制在眾多知名網站中得到廣泛應用。本文將深入解析DataDome的…

RabbitMQ 消息轉換器詳解

RabbitMQ 消息轉換器詳解 一、為什么需要消息轉換器&#xff1f; RabbitMQ 的消息傳輸協議只識別字節流&#xff1a; 發送對象時&#xff0c;需要序列化成字節數組接收消息時&#xff0c;需要將字節數組反序列化成對象 如果不使用消息轉換器&#xff1a; 需要手動序列化和反序列…

內網穿透的應用-告別“現場救火”!用 cpolar遠程調試讓內網故障排查進入“云時代”

文章目錄前言**常見困境與解決方案****實際應用價值**1. Remote JVM Debug2. 系統要求與環境準備2.1 服務器環境2.2 本地開發環境3. 內網服務器準備及開始3.1 安裝cpolar配置支持遠程ssh登錄3.1.1 什么是cpolar&#xff1f;3.1.2 安裝cpolar3.1.3 注冊及配置cpolar系統服務3.1.…

Cherryusb UAC例程對接STM32內置ADC和PWM播放音樂和錄音(下)=>UAC+STM32 ADC+PWM實現錄音和播放

1. 程序基本框架整個程序框架, 與之前的一篇文章《Cherryusb UAC例程對接STM32內置ADC和DAC播放音樂和錄音(中)>UACSTM32 ADCDAC實現錄音和播放》基本一致, 只是這次將DAC替換成了PWM。因此這里不再贅述了。 2. audio_v1_mic_speaker_multichan_template.c的修改說明(略) 參…

1 JQ6500語音播報模塊詳解(STM32)

系列文章目錄 文章目錄系列文章目錄前言1 JQ6500簡介2 基本參數說明2.1 硬件參數2.2 模塊管腳說明3 控制方式3.1 通信格式3.2 通信指令4 硬件設計5 軟件設計5.1 main.c5.2 board_config5.2.1board_config.h5.2.2 board_config.c5.3 module_config5.3.1 module_config.h5.3.2 mo…

常用數據分析工具

Tableau丨Power BI丨FineBI丨SQL丨影刀丨Excel丨Python丨 參考視頻&#xff1a;【戴師兄】數據分析有哪些必學工具&#xff1f;2023最新版&#xff01;Tableau丨Power BI丨FineBI丨SQL丨影刀丨Excel丨Python丨課程教程自學攻略_嗶哩嗶哩_bilibili 文檔資料&#xff1a; 【戴師兄…

OBOO鷗柏丨智能會議平板教學查詢一體機交互式觸摸終端招標投標核心標底參數要求

整機參數要求&#xff1a;55寸/65寸/75寸/85-86寸/98寸/100寸/110寸/115寸智能會議平板教學觸控一體機/智慧黑板觸摸屏招標投標核心標底參數要求1、整機屏幕采用≥采用超高清原廠原包原裝工業LCD液晶屏面板&#xff1b;具有高色域&#xff0c;顯示動態視頻、web及3D動畫時&…

無人機在環保監測中的應用:低空經濟發展的智能監測與高效治理

一、行業背景與技術革新 隨著全球環境問題日益嚴峻&#xff0c;傳統環保監測手段已難以滿足現代環境管理的需求。固定監測站點建設成本高、覆蓋范圍有限&#xff0c;地面巡查效率低下且存在安全風險。在此背景下&#xff0c;無人機技術憑借其獨特的空間優勢和技術特性&#xff…

PO、BO、VO、DTO、POJO、DAO、DO基本概念

一、圖解二、相關概念 1、PO&#xff08;Persistant Object - 持久化對象&#xff09; 核心定位&#xff1a; 直接與數據庫表結構一一映射的對象&#xff0c;通常用于 ORM&#xff08;對象關系映射&#xff09;框架&#xff08;如 MyBatis、Hibernate&#xff09;中。 特點&…

todoList清單(HTML+CSS+JavaScript)

&#x1f30f;個人博客主頁&#xff1a; 前言&#xff1a; 前段時間學習了JavaScript&#xff0c;然后寫了一個todoList小項目&#xff0c;現在和大家分享一下我的清單以及如何實現的&#xff0c;希望對大家有所幫助 &#x1f525;&#x1f525;&#x1f525;文章專題&#xff…

Mac M1探索AnythingLLM+Ollama+知識庫問答

AnythingLLM內置 RAG、AI Agent、可視化/無代碼的 Agent 編排&#xff0c;支持多家模型與本地/云端向量庫&#xff0c;并提供多用戶與可嵌入的聊天組件&#xff0c;用來快速驗證“知識 模型 工具”拼成的 AI 應用。 1 AnythingLLM、Ollama準備 1&#xff09;AnythingLLM 打…

【 Navicat Premium 17 完全圖形化新手指南(從零開始)】

Navicat Premium 17 完全圖形化新手指南&#xff08;從零開始&#xff09; 一、準備階段&#xff1a;清理現有環境 1. 刪除已創建的測試數據庫&#xff08;如需重新開始&#xff09;打開Navicat Premium 17 雙擊桌面圖標啟動程序在左側連接面板中找到你的MySQL連接&#xff08;…

Web學習筆記5

Javascript概述1、JS簡介JS是運行在瀏覽器的腳本編程語言&#xff0c;最初用于Web表單的校驗。現在的作用主要有三個&#xff1a;網頁特效、表單驗證、數據交互JS由三部分組成&#xff0c;分別是ECMAscript、DOM、BOM&#xff0c;其中ECMAscript規定了JS的基本語法和規則&#…

部署一個開源的證件照系統

以下數據來自官方網站,記錄下來,方便自己 項目簡介 &#x1f680; 謝謝你對我們的工作感興趣。您可能還想查看我們在圖像領域的其他成果&#xff0c;歡迎來信:zeyi.linswanhub.co. HivisionIDPhoto 旨在開發一種實用、系統性的證件照智能制作算法。 它利用一套完善的AI模型工作…

Linux客戶端利用MinIO對服務器數據進行同步

接上篇 Windows客戶端利用MinIO對服務器數據進行同步 本篇為Linux下 操作&#xff0c;先看下我本地的系統版本 所以我這里下載的話&#xff0c;是AMD64 文檔在這 因為我這里只是需要用到客戶端&#xff0c;獲取數據而已&#xff0c;所以我只需要下載個MC工具用來數據獲取就可以…

Docker 中部署 MySQL 5.7 并遠程連接 Navicat 的完整指南

個人名片 &#x1f393;作者簡介&#xff1a;java領域優質創作者 &#x1f310;個人主頁&#xff1a;碼農阿豪 &#x1f4de;工作室&#xff1a;新空間代碼工作室&#xff08;提供各種軟件服務&#xff09; &#x1f48c;個人郵箱&#xff1a;[2435024119qq.com] &#x1f4f1…

自己動手造個球平衡機器人

你是否曾對那些能夠精妙地保持平衡的機器設備感到好奇&#xff1f; 從無人機到獨輪平衡車&#xff0c;背后都蘊藏著復雜的控制系統。 今天&#xff0c;我們來介紹一個充滿挑戰與樂趣的項目——制作一個球平衡機器人。這不僅是一個酷炫的擺件&#xff0c;更是一次深入學習機器…

21.Linux HTTPS服務

Linux : HTTPS服務協議傳輸方式端口安全性HTTP明文傳輸80無加密&#xff0c;可被竊聽HTTPS加密傳輸443HTTP SSL/TLS 數據加密&#xff08;防竊聽&#xff09;身份認證&#xff08;防偽裝&#xff09;完整性校驗&#xff08;防篡改&#xff09;OpenSSL 證書操作核心命令命令選項…