一個基于 PyTorch 的完整模型訓練流程
flyfish
訓練步驟 | 具體操作 | 目的 |
---|---|---|
1. 訓練前準備 | 設置隨機種子、配置超參數(batch size、學習率等)、選擇計算設備(CPU/GPU) | 確保實驗可復現;統一控制訓練關鍵參數;利用硬件加速訓練 |
2. 數據預處理與加載 | 對數據進行標準化/歸一化、轉換為張量;用DataLoader按batch加載數據 | 統一輸入格式,適配模型要求;高效分批讀取數據,減少內存占用 |
3. 初始化組件 | 定義模型結構并加載到計算設備;選擇損失函數(如交叉熵)和優化器(如Adam) | 搭建訓練核心框架:模型負責預測,損失函數量化誤差,優化器負責參數更新 |
4. 訓練循環(每個epoch) | 逐輪迭代優化模型參數 | |
4.1 模型切換為訓練模式 | model.train() | 啟用dropout、批量歸一化的訓練模式,確保梯度計算有效 |
4.2 遍歷訓練數據(每個batch) | 逐批更新參數 | |
4.2.1 清零梯度 | optimizer.zero_grad() | 消除歷史梯度累積,確保當前batch的梯度計算獨立 |
4.2.2 前向傳播 | output = model(data) | 用當前模型參數對輸入數據做預測,得到輸出結果 |
4.2.3 計算損失 | loss = criterion(output, target) | 量化預測結果與真實標簽的差距,作為優化目標 |
4.2.4 反向傳播 | loss.backward() | 從損失值反向推導,計算所有可訓練參數的梯度(參數對損失的影響程度) |
4.2.5 參數更新 | optimizer.step() | 根據梯度,按優化器規則調整模型參數,減小損失 |
4.3 記錄訓練指標 | 保存每個epoch的訓練損失、準確率 | 跟蹤模型在訓練集上的學習效果 |
5. 驗證(每個epoch后) | 評估模型泛化能力 | |
5.1 模型切換為評估模式 | model.eval() | 關閉dropout、固定批量歸一化參數,確保評估穩定 |
5.2 關閉梯度計算 | with torch.no_grad(): | 減少內存占用,加速驗證過程(無需計算梯度) |
5.3 計算驗證指標 | 計算驗證損失、準確率 | 評估模型在未見過的數據上的表現,判斷泛化能力 |
6. 模型保存 | 保存表現最優的模型參數(如驗證準確率最高時) | 留存最佳模型,便于后續部署或繼續訓練 |
7. 訓練后分析 | 繪制損失/準確率曲線,統計訓練時間 | 直觀展示訓練過程,分析模型收斂狀態和效率 |
前向傳播→計算損失→反向傳播→參數優化
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
import time

# Set random seeds so results are reproducible
def set_seed(seed=42):torch.manual_seed(seed)torch.cuda.manual_seed(seed)np.random.seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False# 定義超參數
class Config:def __init__(self):self.batch_size = 64self.learning_rate = 0.001self.epochs = 10self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.save_path = './models'self.log_interval = 100# 定義簡單的卷積神經網絡模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)self.relu = nn.ReLU()self.dropout = nn.Dropout(0.5)def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 64 * 7 * 7) # 展平x = self.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x# 準備數據
def prepare_data(config):# 定義數據變換transform = transforms.Compose([ToTensor(),transforms.Normalize((0.1307,), (0.3081,)) # MNIST數據集的均值和標準差])# 加載MNIST數據集train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform)test_dataset = datasets.MNIST(root='./data',train=False,download=True,transform=transform)# 創建數據加載器train_loader = DataLoader(train_dataset,batch_size=config.batch_size,shuffle=True,num_workers=2)test_loader = DataLoader(test_dataset,batch_size=config.batch_size,shuffle=False,num_workers=2)return train_loader, test_loader# 訓練函數
def train(model, train_loader, criterion, optimizer, config, epoch):model.train() # 設置為訓練模式train_loss = 0.0correct = 0total = 0# 使用tqdm顯示進度條pbar = tqdm(train_loader, desc=f'Train Epoch {epoch}')for batch_idx, (data, target) in enumerate(pbar):data, target = data.to(config.device), target.to(config.device)# 清零梯度optimizer.zero_grad()# 前向傳播output = model(data)loss = criterion(output, target)# 反向傳播和優化loss.backward()optimizer.step()# 統計訓練信息train_loss += loss.item()_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()# 打印日志if batch_idx % config.log_interval == 0:pbar.set_postfix({'loss': f'{train_loss/(batch_idx+1):.6f}','accuracy': f'{100.*correct/total:.2f}%'})# 計算平均損失和準確率avg_loss = train_loss / len(train_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy# 驗證函數
def validate(model, test_loader, criterion, config):model.eval() # 設置為評估模式test_loss = 0.0correct = 0total = 0# 不計算梯度with torch.no_grad():for data, target in test_loader:data, target = data.to(config.device), target.to(config.device)output = model(data)test_loss += criterion(output, target).item()# 統計準確率_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()# 計算平均損失和準確率avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalprint(f'\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)\n')return avg_loss, accuracy# 保存模型
def save_model(model, optimizer, epoch, loss, config):# 創建保存目錄if not os.path.exists(config.save_path):os.makedirs(config.save_path)# 保存模型狀態torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}, f"{config.save_path}/model_epoch_{epoch}.pth")print(f"Model saved to {config.save_path}/model_epoch_{epoch}.pth")# 主函數
def main():# 初始化設置set_seed()config = Config()print(f"Using device: {config.device}")# 準備數據train_loader, test_loader = prepare_data(config)# 初始化模型、損失函數和優化器model = SimpleCNN().to(config.device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)# 記錄訓練過程中的指標history = {'train_loss': [],'train_acc': [],'val_loss': [],'val_acc': []}# 開始訓練start_time = time.time()best_val_acc = 0.0for epoch in range(1, config.epochs + 1):print(f"\nEpoch {epoch}/{config.epochs}")print("-" * 50)# 訓練train_loss, train_acc = train(model, train_loader, criterion, optimizer, config, epoch)history['train_loss'].append(train_loss)history['train_acc'].append(train_acc)# 驗證val_loss, val_acc = validate(model, test_loader, criterion, config)history['val_loss'].append(val_loss)history['val_acc'].append(val_acc)# 保存最佳模型if val_acc > best_val_acc:best_val_acc = val_accsave_model(model, optimizer, epoch, val_loss, config)# 計算總訓練時間end_time = time.time()total_time = end_time - start_timeprint(f"Training complete in {total_time:.0f}s ({total_time/config.epochs:.2f}s per epoch)")print(f"Best validation accuracy: {best_val_acc:.2f}%")# 繪制訓練曲線plot_training_history(history)# 繪制訓練歷史
def plot_training_history(history):plt.figure(figsize=(12, 4))# 繪制損失曲線plt.subplot(1, 2, 1)plt.plot(history['train_loss'], label='Training Loss')plt.plot(history['val_loss'], label='Validation Loss')plt.title('Loss Curves')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()# 繪制準確率曲線plt.subplot(1, 2, 2)plt.plot(history['train_acc'], label='Training Accuracy')plt.plot(history['val_acc'], label='Validation Accuracy')plt.title('Accuracy Curves')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.tight_layout()plt.savefig('training_history.png')print("Training history plot saved as 'training_history.png'")plt.show()if __name__ == '__main__':main()
......
--------------------------------------------------
Train Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:07<00:00, 124.14it/s, loss=0.024222, accuracy=99.22%]Test set: Average loss: 0.0256, Accuracy: 9926/10000 (99.26%)Model saved to ./models/model_epoch_9.pthEpoch 10/10
--------------------------------------------------
Train Epoch 10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:07<00:00, 127.89it/s, loss=0.021473, accuracy=99.31%]Test set: Average loss: 0.0266, Accuracy: 9927/10000 (99.27%)Model saved to ./models/model_epoch_10.pth
Training complete in 85s (8.52s per epoch)
Best validation accuracy: 99.27%
Training history plot saved as 'training_history.png'
一、左側:Loss Curves(損失曲線)
藍色:訓練損失(Training Loss)
橙色:驗證損失(Validation Loss)
二、右側:Accuracy Curves(準確率曲線)
藍色:訓練準確率(Training Accuracy)
橙色:驗證準確率(Validation Accuracy)