前文中,只是給了基礎模型:?
PyTorch 實現 CIFAR-10 圖像分類:從數據預處理到模型訓練與評估-CSDN博客
今天我們增加交叉驗證和超參數調優,
先看運行結果:
===== 在測試集上評估最終模型 =====
最終模型在測試集上的準確率:60.14%
最優模型已保存為 'cifar10_best_model.pth'(超參數:{'batch_size': 32, 'epochs': 5, 'lr': 0.01, 'momentum': 0.85})
Process finished with exit code 0
比基礎模型準確率高了一點,
?完整代碼如下:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from sklearn.model_selection import KFold, ParameterGrid # 用于交叉驗證和超參數網格搜索# --------------------------
# 1. 數據準備(與原代碼一致,但后續會在訓練集內部做交叉驗證)
# --------------------------
# 數據預處理:標準化(與原代碼相同)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 數據集路徑(請替換為你的實際路徑)
data_path = r'D:\workspace_py\deeplean\data'# 加載完整訓練集和測試集(測試集始終不變,用于最終評估)
full_trainset = datasets.CIFAR10(root=data_path, train=True, download=False, transform=transform)
testset = datasets.CIFAR10(root=data_path, train=False, download=False, transform=transform)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# --------------------------
# 2. 定義CNN模型(與原代碼一致)
# --------------------------
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x# --------------------------
# 3. 交叉驗證函數(核心新增)
# --------------------------
def cross_validate(model, train_dataset, k_folds=5, epochs=5, lr=0.001, batch_size=32, momentum=0.9):"""5折交叉驗證:將訓練集分成5份,每次用4份訓練,1份驗證,返回平均準確率"""kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42) # 固定隨機種子,結果可復現fold_results = [] # 存儲每折的驗證準確率for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):print(f'\n===== 第 {fold + 1}/{k_folds} 折交叉驗證 =====')# 1. 劃分當前折的訓練集和驗證集train_subset = Subset(train_dataset, train_ids) # 本次訓練用的數據val_subset = Subset(train_dataset, val_ids) # 本次驗證用的數據# 2. 創建數據加載器train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)# 3. 初始化模型和優化器(每折都重新訓練新模型,避免干擾)model_instance = Net() # 重新實例化模型criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model_instance.parameters(), lr=lr, momentum=momentum)# 4. 訓練當前折的模型for epoch in range(epochs):model_instance.train() # 訓練模式running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model_instance(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# 每200步打印一次損失(簡化輸出)if i % 200 == 199:print(f'折 {fold + 1},輪次 {epoch + 1},第 {i + 1} 步:平均損失 {running_loss / 200:.3f}')running_loss = 0.0# 5. 在驗證集上評估當前折的模型model_instance.eval() # 驗證模式correct = 0total = 0with torch.no_grad():for data in val_loader:images, labels = dataoutputs = model_instance(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()val_acc = 100 * correct / totalprint(f'第 {fold + 1} 折驗證準確率:{val_acc:.2f}%')fold_results.append(val_acc)# 計算所有折的平均準確率(該超參數組合的最終得分)avg_acc = sum(fold_results) / len(fold_results)print(f'\n===== 該超參數組合的平均驗證準確率:{avg_acc:.2f}% =====')return avg_acc# --------------------------
# 4. 超參數調優(核心新增)
# --------------------------
def hyperparameter_tuning(train_dataset):"""超參數網格搜索:嘗試不同的超參數組合,用交叉驗證選最優"""# 定義要測試的超參數組合(可根據需要增減)param_grid = {'lr': [0.001, 0.01], # 學習率:嘗試兩個值'batch_size': [32, 64], # 批大小:嘗試兩個值'momentum': [0.9, 0.85], # 動量:嘗試兩個值'epochs': [5] # 訓練輪次(固定為5,減少計算量)}best_acc = 0.0best_params = None # 存儲最優超參數# 遍歷所有超參數組合(共 2×2×2=8 種組合)for params in ParameterGrid(param_grid):print(f'\n---------- 測試超參數組合:{params} ----------')# 用交叉驗證評估當前組合的性能current_acc = cross_validate(model=Net(),train_dataset=train_dataset,k_folds=5,epochs=params['epochs'],lr=params['lr'],batch_size=params['batch_size'],momentum=params['momentum'])# 記錄最優組合if current_acc > best_acc:best_acc = current_accbest_params = paramsprint(f'★ 發現更優組合!當前最優準確率:{best_acc:.2f}%')print(f'\n===== 超參數調優完成 =====')print(f'最優超參數:{best_params}')print(f'最優平均驗證準確率:{best_acc:.2f}%')return best_params# --------------------------
# 5. 主函數:執行超參數調優 + 最終訓練 + 測試集評估
# --------------------------
if __name__ == '__main__':# 步驟1:超參數調優(用交叉驗證選最優參數)print('===== 開始超參數調優(這一步比較慢,需要耐心等待)=====')best_params = hyperparameter_tuning(full_trainset)# 步驟2:用最優超參數在完整訓練集上訓練最終模型print('\n===== 用最優超參數訓練最終模型 =====')final_model = Net()criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(final_model.parameters(),lr=best_params['lr'],momentum=best_params['momentum'])train_loader = DataLoader(full_trainset,batch_size=best_params['batch_size'],shuffle=True)# 訓練最終模型(輪次與調優時一致)for epoch in range(best_params['epochs']):final_model.train()running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = final_model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 200 == 199:print(f'最終模型訓練 - 輪次 {epoch + 1},第 {i + 1} 步:平均損失 {running_loss / 200:.3f}')running_loss = 0.0# 步驟3:在測試集上評估最終模型(用從未見過的測試數據)print('\n===== 在測試集上評估最終模型 =====')final_model.eval()test_loader = DataLoader(testset, batch_size=32, shuffle=False)correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = final_model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()test_acc = 100 * correct / totalprint(f'最終模型在測試集上的準確率:{test_acc:.2f}%')# 步驟4:保存最優模型torch.save(final_model.state_dict(), 'cifar10_best_model.pth')print(f"最優模型已保存為 'cifar10_best_model.pth'(超參數:{best_params})")
新增加的功能 :
(1)5 折交叉驗證(cross_validate
函數)
- 作用:把訓練集分成 5 份,每次用 4 份訓練、1 份驗證,重復 5 次,取平均準確率作為 “該參數組合的得分”。
- 白話舉例:相當于學生做 5 套模擬題,每次用 4 套復習、1 套測試,最后算平均分,比只做 1 套題更能反映真實水平。
- 關鍵細節:每折都重新訓練新模型,避免前一折的 “記憶” 影響結果。
(2)超參數調優(hyperparameter_tuning
函數)
- 作用:嘗試不同的超參數組合(如學習率 0.001 vs 0.01,批大小 32 vs 64),用交叉驗證選平均分最高的組合。
- 白話舉例:相當于學生嘗試不同的復習方法(每天學 1 小時 vs 2 小時,刷題 vs 看筆記),通過模擬題平均分找到最適合自己的方法。
- 參數網格:代碼中測試了 8 種組合(2 學習率 ×2 批大小 ×2 動量),可根據需要增減(組合越多,計算時間越長)。
(3)最終模型訓練
- 用調優得到的 “最優超參數” 在完整訓練集上重新訓練模型(之前交叉驗證只用了部分數據)。
- 最后在獨立的測試集上評估(測試集從未參與訓練和調優,相當于 “高考”)。
3. 運行說明
- 計算時間:超參數調優 + 交叉驗證會比原代碼慢很多(8 種組合 ×5 折 ×5 輪訓練),建議在有 GPU 的環境運行。
- 結果解讀:最終會輸出 “最優超參數” 和 “測試集準確率”,這個準確率比原代碼更可信(排除了偶然因素)。
- 可調整項:
param_grid
中的參數可以修改(如增加學習率選項[0.0001, 0.001, 0.01]
),但組合數會增加,計算時間變長。
通過這兩個步驟,模型的性能和可靠性會顯著提升,尤其適合數據量不大的場景(如醫學影像、小數據集)。
交叉驗證
一、什么是交叉驗證?為什么需要它?
1. 核心問題:如何判斷模型好壞?
假設你用一份訓練集訓練模型,然后用同一批數據測試,準確率 90%—— 這能說明模型好嗎?不能!因為模型可能 “死記硬背” 了訓練數據(過擬合),換一批新數據就不行了。
所以需要用 “沒見過的數據” 來驗證模型 —— 但我們只有一份訓練集,怎么辦?
2. 交叉驗證的解決思路
交叉驗證(以代碼中的5 折交叉驗證為例)就像 “多次模擬考試”:
- 把訓練集分成 5 等份(比如 5 個小數據集 A、B、C、D、E)。
- 第一次:用 A、B、C、D 訓練模型,用 E 驗證(看模型在 E 上的準確率)。
- 第二次:用 A、B、C、E 訓練,用 D 驗證。
- 重復 5 次(每次換一份做驗證集),最后取 5 次驗證準確率的平均值。
這樣做的好處:
- 避免 “一次驗證” 的偶然性(比如剛好抽到簡單的驗證集)。
- 更全面地評估模型在不同數據分布上的表現,結果更可靠。
3. 代碼中的交叉驗證實現(cross_validate 函數)
代碼里的cross_validate函數就是干這個的:
- 用KFold(n_splits=5)把訓練集分成 5 份。
- 循環 5 次(每折):
- 每次從 5 份中選 4 份做 “臨時訓練集”,1 份做 “臨時驗證集”。
- 用臨時訓練集訓練模型,用臨時驗證集算準確率。
- 最后返回 5 次準確率的平均值,作為這個模型 / 超參數組合的 “評分”。
二、什么是超參數調優?為什么需要它?
1. 超參數是什么?
超參數是訓練前手動設定的參數,不是模型自己學出來的。比如代碼中的:
- lr(學習率):模型更新參數的 “步長”,太大可能跑過頭,太小可能學太慢。
- batch_size(批大小):每次訓練用多少數據,影響訓練速度和穩定性。
- momentum(動量):優化器的參數,幫助模型更快收斂。
這些參數直接影響模型的訓練效果,但沒有 “標準答案”,需要試出來。
2. 超參數調優的目的
找到一組最好的超參數組合,讓模型的性能(比如準確率)達到最高。
比如:學習率 0.01 + 批大小 32 + 動量 0.9 可能比 學習率 0.001 + 批大小 64 + 動量 0.85 效果更好,我們需要找到這個 “更好” 的組合。
3. 代碼中的超參數調優實現(網格搜索)
代碼用了 “網格搜索” 的方法,原理很簡單:
- 列清單:先定義每個超參數的可能取值(比如lr選 [0.001, 0.01],batch_size選 [32, 64])。
- 組合所有可能:把這些取值的所有搭配列出來(比如 2×2×2=8 種組合)。
- 逐個測試:對每種組合,用交叉驗證算它的 “評分”(平均驗證準確率)。
- 選最優:最后挑出評分最高的組合,作為 “最佳超參數”。
對應代碼中的hyperparameter_tuning函數:
- param_grid定義了要測試的超參數和可能值。
- ParameterGrid自動生成所有組合。
- 循環每個組合,用cross_validate算分,保存最高分的組合。
三、交叉驗證和超參數調優的關系
簡單說:超參數調優是 “找最好的配方”,交叉驗證是 “判斷配方好不好的工具”。
- 沒有交叉驗證,直接用一組數據測試超參數,可能因為 “運氣好” 選錯(比如剛好驗證集簡單)。
- 用交叉驗證評估每個超參數組合,結果更可靠,能真正找到 “穩定好” 的組合。
總結
- 交叉驗證:通過多次 “訓練 - 驗證” 劃分,更可靠地評估模型性能,避免偶然性。
- 超參數調優:通過嘗試不同的超參數組合(網格搜索),結合交叉驗證的評分,找到讓模型表現最好的 “參數配方”。
代碼中,先通過超參數調優找到最好的參數,再用這些參數訓練最終模型,最后在測試集上驗證 —— 這樣得到的模型更可能在新數據上表現良好。