增加交叉驗證和超參數調優

前文中,只是給了基礎模型:?

PyTorch 實現 CIFAR-10 圖像分類:從數據預處理到模型訓練與評估-CSDN博客

今天我們增加交叉驗證和超參數調優,

先看運行結果:
===== 在測試集上評估最終模型 =====
最終模型在測試集上的準確率:60.14%
最優模型已保存為 'cifar10_best_model.pth'(超參數:{'batch_size': 32, 'epochs': 5, 'lr': 0.01, 'momentum': 0.85})

Process finished with exit code 0
比基礎模型準確率高了一點,

?完整代碼如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from sklearn.model_selection import KFold, ParameterGrid  # 用于交叉驗證和超參數網格搜索# --------------------------
# 1. 數據準備(與原代碼一致,但后續會在訓練集內部做交叉驗證)
# --------------------------
# 數據預處理:標準化(與原代碼相同)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 數據集路徑(請替換為你的實際路徑)
data_path = r'D:\workspace_py\deeplean\data'# 加載完整訓練集和測試集(測試集始終不變,用于最終評估)
full_trainset = datasets.CIFAR10(root=data_path, train=True, download=False, transform=transform)
testset = datasets.CIFAR10(root=data_path, train=False, download=False, transform=transform)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# --------------------------
# 2. 定義CNN模型(與原代碼一致)
# --------------------------
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x# --------------------------
# 3. 交叉驗證函數(核心新增)
# --------------------------
def cross_validate(model, train_dataset, k_folds=5, epochs=5, lr=0.001, batch_size=32, momentum=0.9):"""5折交叉驗證:將訓練集分成5份,每次用4份訓練,1份驗證,返回平均準確率"""kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)  # 固定隨機種子,結果可復現fold_results = []  # 存儲每折的驗證準確率for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):print(f'\n===== 第 {fold + 1}/{k_folds} 折交叉驗證 =====')# 1. 劃分當前折的訓練集和驗證集train_subset = Subset(train_dataset, train_ids)  # 本次訓練用的數據val_subset = Subset(train_dataset, val_ids)  # 本次驗證用的數據# 2. 創建數據加載器train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)# 3. 初始化模型和優化器(每折都重新訓練新模型,避免干擾)model_instance = Net()  # 重新實例化模型criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model_instance.parameters(), lr=lr, momentum=momentum)# 4. 訓練當前折的模型for epoch in range(epochs):model_instance.train()  # 訓練模式running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model_instance(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# 每200步打印一次損失(簡化輸出)if i % 200 == 199:print(f'折 {fold + 1},輪次 {epoch + 1},第 {i + 1} 步:平均損失 {running_loss / 200:.3f}')running_loss = 0.0# 5. 在驗證集上評估當前折的模型model_instance.eval()  # 驗證模式correct = 0total = 0with torch.no_grad():for data in val_loader:images, labels = dataoutputs = model_instance(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()val_acc = 100 * correct / totalprint(f'第 {fold + 1} 折驗證準確率:{val_acc:.2f}%')fold_results.append(val_acc)# 計算所有折的平均準確率(該超參數組合的最終得分)avg_acc = sum(fold_results) / len(fold_results)print(f'\n===== 該超參數組合的平均驗證準確率:{avg_acc:.2f}% =====')return avg_acc# --------------------------
# 4. 超參數調優(核心新增)
# --------------------------
def hyperparameter_tuning(train_dataset):"""超參數網格搜索:嘗試不同的超參數組合,用交叉驗證選最優"""# 定義要測試的超參數組合(可根據需要增減)param_grid = {'lr': [0.001, 0.01],  # 學習率:嘗試兩個值'batch_size': [32, 64],  # 批大小:嘗試兩個值'momentum': [0.9, 0.85],  # 動量:嘗試兩個值'epochs': [5]  # 訓練輪次(固定為5,減少計算量)}best_acc = 0.0best_params = None  # 存儲最優超參數# 遍歷所有超參數組合(共 2×2×2=8 種組合)for params in ParameterGrid(param_grid):print(f'\n---------- 測試超參數組合:{params} ----------')# 用交叉驗證評估當前組合的性能current_acc = cross_validate(model=Net(),train_dataset=train_dataset,k_folds=5,epochs=params['epochs'],lr=params['lr'],batch_size=params['batch_size'],momentum=params['momentum'])# 記錄最優組合if current_acc > best_acc:best_acc = current_accbest_params = paramsprint(f'★ 發現更優組合!當前最優準確率:{best_acc:.2f}%')print(f'\n===== 超參數調優完成 =====')print(f'最優超參數:{best_params}')print(f'最優平均驗證準確率:{best_acc:.2f}%')return best_params# --------------------------
# 5. 主函數:執行超參數調優 + 最終訓練 + 測試集評估
# --------------------------
if __name__ == '__main__':# 步驟1:超參數調優(用交叉驗證選最優參數)print('===== 開始超參數調優(這一步比較慢,需要耐心等待)=====')best_params = hyperparameter_tuning(full_trainset)# 步驟2:用最優超參數在完整訓練集上訓練最終模型print('\n===== 用最優超參數訓練最終模型 =====')final_model = Net()criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(final_model.parameters(),lr=best_params['lr'],momentum=best_params['momentum'])train_loader = DataLoader(full_trainset,batch_size=best_params['batch_size'],shuffle=True)# 訓練最終模型(輪次與調優時一致)for epoch in range(best_params['epochs']):final_model.train()running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = final_model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 200 == 199:print(f'最終模型訓練 - 輪次 {epoch + 1},第 {i + 1} 步:平均損失 {running_loss / 200:.3f}')running_loss = 0.0# 步驟3:在測試集上評估最終模型(用從未見過的測試數據)print('\n===== 在測試集上評估最終模型 =====')final_model.eval()test_loader = DataLoader(testset, batch_size=32, shuffle=False)correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = final_model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()test_acc = 100 * correct / totalprint(f'最終模型在測試集上的準確率:{test_acc:.2f}%')# 步驟4:保存最優模型torch.save(final_model.state_dict(), 'cifar10_best_model.pth')print(f"最優模型已保存為 'cifar10_best_model.pth'(超參數:{best_params})")

新增加的功能 :

(1)5 折交叉驗證(cross_validate函數)
  • 作用:把訓練集分成 5 份,每次用 4 份訓練、1 份驗證,重復 5 次,取平均準確率作為 “該參數組合的得分”。
  • 白話舉例:相當于學生做 5 套模擬題,每次用 4 套復習、1 套測試,最后算平均分,比只做 1 套題更能反映真實水平。
  • 關鍵細節:每折都重新訓練新模型,避免前一折的 “記憶” 影響結果。
(2)超參數調優(hyperparameter_tuning函數)
  • 作用:嘗試不同的超參數組合(如學習率 0.001 vs 0.01,批大小 32 vs 64),用交叉驗證選平均分最高的組合。
  • 白話舉例:相當于學生嘗試不同的復習方法(每天學 1 小時 vs 2 小時,刷題 vs 看筆記),通過模擬題平均分找到最適合自己的方法。
  • 參數網格:代碼中測試了 8 種組合(2 學習率 ×2 批大小 ×2 動量),可根據需要增減(組合越多,計算時間越長)。
(3)最終模型訓練
  • 用調優得到的 “最優超參數” 在完整訓練集上重新訓練模型(之前交叉驗證只用了部分數據)。
  • 最后在獨立的測試集上評估(測試集從未參與訓練和調優,相當于 “高考”)。
3. 運行說明
  • 計算時間:超參數調優 + 交叉驗證會比原代碼慢很多(8 種組合 ×5 折 ×5 輪訓練),建議在有 GPU 的環境運行。
  • 結果解讀:最終會輸出 “最優超參數” 和 “測試集準確率”,這個準確率比原代碼更可信(排除了偶然因素)。
  • 可調整項param_grid中的參數可以修改(如增加學習率選項[0.0001, 0.001, 0.01]),但組合數會增加,計算時間變長。

通過這兩個步驟,模型的性能和可靠性會顯著提升,尤其適合數據量不大的場景(如醫學影像、小數據集)。

交叉驗證

一、什么是交叉驗證?為什么需要它?

1. 核心問題:如何判斷模型好壞?

假設你用一份訓練集訓練模型,然后用同一批數據測試,準確率 90%—— 這能說明模型好嗎?不能!因為模型可能 “死記硬背” 了訓練數據(過擬合),換一批新數據就不行了。

所以需要用 “沒見過的數據” 來驗證模型 —— 但我們只有一份訓練集,怎么辦?

2. 交叉驗證的解決思路

交叉驗證(以代碼中的5 折交叉驗證為例)就像 “多次模擬考試”:

  1. 把訓練集分成 5 等份(比如 5 個小數據集 A、B、C、D、E)。
  2. 第一次:用 A、B、C、D 訓練模型,用 E 驗證(看模型在 E 上的準確率)。
  3. 第二次:用 A、B、C、E 訓練,用 D 驗證。
  4. 重復 5 次(每次換一份做驗證集),最后取 5 次驗證準確率的平均值。

這樣做的好處:

  1. 避免 “一次驗證” 的偶然性(比如剛好抽到簡單的驗證集)。
  2. 更全面地評估模型在不同數據分布上的表現,結果更可靠。
3. 代碼中的交叉驗證實現(cross_validate 函數)

代碼里的cross_validate函數就是干這個的:

  1. KFold(n_splits=5)把訓練集分成 5 份。
  2. 循環 5 次(每折):
    1. 每次從 5 份中選 4 份做 “臨時訓練集”,1 份做 “臨時驗證集”。
    2. 用臨時訓練集訓練模型,用臨時驗證集算準確率。
  3. 最后返回 5 次準確率的平均值,作為這個模型 / 超參數組合的 “評分”。

二、什么是超參數調優?為什么需要它?

1. 超參數是什么?

超參數是訓練前手動設定的參數,不是模型自己學出來的。比如代碼中的:

  1. lr(學習率):模型更新參數的 “步長”,太大可能跑過頭,太小可能學太慢。
  2. batch_size(批大小):每次訓練用多少數據,影響訓練速度和穩定性。
  3. momentum(動量):優化器的參數,幫助模型更快收斂。

這些參數直接影響模型的訓練效果,但沒有 “標準答案”,需要試出來。

2. 超參數調優的目的

找到一組最好的超參數組合,讓模型的性能(比如準確率)達到最高。

比如:學習率 0.01 + 批大小 32 + 動量 0.9 可能比 學習率 0.001 + 批大小 64 + 動量 0.85 效果更好,我們需要找到這個 “更好” 的組合。

3. 代碼中的超參數調優實現(網格搜索)

代碼用了 “網格搜索” 的方法,原理很簡單:

  1. 列清單:先定義每個超參數的可能取值(比如lr選 [0.001, 0.01],batch_size選 [32, 64])。
  2. 組合所有可能:把這些取值的所有搭配列出來(比如 2×2×2=8 種組合)。
  3. 逐個測試:對每種組合,用交叉驗證算它的 “評分”(平均驗證準確率)。
  4. 選最優:最后挑出評分最高的組合,作為 “最佳超參數”。

對應代碼中的hyperparameter_tuning函數:

  1. param_grid定義了要測試的超參數和可能值。
  2. ParameterGrid自動生成所有組合。
  3. 循環每個組合,用cross_validate算分,保存最高分的組合。

三、交叉驗證和超參數調優的關系

簡單說:超參數調優是找最好的配方”,交叉驗證是 “判斷配方好不好的工具”

  1. 沒有交叉驗證,直接用一組數據測試超參數,可能因為 “運氣好” 選錯(比如剛好驗證集簡單)。
  2. 用交叉驗證評估每個超參數組合,結果更可靠,能真正找到 “穩定好” 的組合。

總結

  1. 交叉驗證:通過多次 “訓練 - 驗證” 劃分,更可靠地評估模型性能,避免偶然性。
  2. 超參數調優:通過嘗試不同的超參數組合(網格搜索),結合交叉驗證的評分,找到讓模型表現最好的 “參數配方”。

代碼中,先通過超參數調優找到最好的參數,再用這些參數訓練最終模型,最后在測試集上驗證 —— 這樣得到的模型更可能在新數據上表現良好。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/89909.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/89909.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/89909.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

解決pip指令超時問題

用pip指令,在安裝Django3.2時報錯,詢問ChatGpt后得到的解決方案pip 下載超時 —— 是 當前網絡連接到 PyPI 官方源太慢或不穩定,甚至可能連不上了,而 pip 默認的超時時間又太短,就導致了中途失敗:ReadTimeo…

Oracle定時清理歸檔日志

線上歸檔日志滿了,系統直接崩了,為解決這個問題,創建每月定時清理歸檔日志。 創建文件名 delete_archivelog.rman CONFIGURE ARCHIVELOG DELETION POLICY CLEAR; RUN {ALLOCATE CHANNEL c1 TYPE DISK;DELETE ARCHIVELOG ALL COMPLETED BEFORE…

ELF 文件操作手冊

目錄 一、ELF 文件結構概述 二、查看 ELF 文件頭信息 1、命令選項 2、示例輸出 3、內核數據結構 三、ELF 程序頭表 1、命令選項 2、示例輸出 3、關鍵說明 4、內核數據結構 四、ELF 節頭表詳解 查看節頭表信息 1、命令選項 2、示例輸出 3、標志說明 4、重要節說…

深入淺出Python函數:參數傳遞、作用域與案例詳解

🙋?♀? 博主介紹:顏顏yan_ ? 本期精彩:深入淺出Python函數:參數傳遞、作用域與案例詳解 🏆 熱門專欄:零基礎玩轉Python爬蟲:手把手教你成為數據獵人 🚀 專欄亮點:零基…

ps aux 和 ps -ef

在 Linux/Unix 系統中,ps aux 和 ps -ef 都是用于查看進程信息的命令,結合 grep node 可以篩選出與 Node.js 相關的進程。它們的核心功能相似,但在輸出格式和選項含義上有區別:1. 命令對比命令含義主要區別ps auxBSD 風格語法列更…

Spark ML 之 LSH

src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala test("approxSimilarityJoin for self join") {val data = {for (i <- 0 until 24) yield Vectors

關鍵成功因素法(CSF)深度解析:從戰略目標到數據字典

關鍵成功因素法由John Rockart提出&#xff0c;用于信息系統規劃&#xff0c;幫助企業識別影響系統成功的關鍵因素&#xff0c;從而確定信息需求&#xff0c;指導信息技術管理。該方法通過識別關鍵成功因素&#xff0c;找出關鍵信息集合&#xff0c;確定系統開發優先級&#xf…

Django母嬰商城項目實踐(六)- Models模型之ORM操作

6、Models模型操作 1 ORM概述 介紹 Django對數據進行增刪改操作是借助內置的ORM框架(Object Relational Mapping,對象關系映射)所提供的API方法實現的,允許你使用類和對象對數據庫進行操作,從而避免通過SQL語句操作數據庫。 簡單來說,ORM框架的數據操作API是在 QuerySet…

【PTA數據結構 | C語言版】哥尼斯堡的“七橋問題”

本專欄持續輸出數據結構題目集&#xff0c;歡迎訂閱。 文章目錄題目代碼題目 哥尼斯堡是位于普累格河上的一座城市&#xff0c;它包含兩個島嶼及連接它們的七座橋&#xff0c;如下圖所示。 可否走過這樣的七座橋&#xff0c;而且每橋只走過一次&#xff1f;瑞士數學家歐拉(Leo…

Redis 詳解:從入門到進階

文章目錄前言一、什么是 Redis&#xff1f;二、Redis 使用場景1. 緩存熱點數據2. 消息隊列3. 分布式鎖4. 限流與防刷5. 計數器、排行榜三、緩存三大問題&#xff1a;雪崩 / 穿透 / 擊穿1. ?? 緩存雪崩&#xff08;Cache Avalanche&#xff09;2. &#x1f50d; 緩存穿透&…

QCustomPlot 使用教程

下載網址&#xff1a;官方網站&#xff1a;http://www.qcustomplot.com/我的環境是 window10 qt5.9.9 下載后&#xff0c;官網提供了很多例子。可以作為參考直接運行自己如何使用&#xff1a;第一步&#xff1a;使用QCustomPlot非常簡單&#xff0c;只需要把qcustomplot.cpp和…

基于springboot+mysql的作業管理系統(源碼+論文)

一、開發環境 1 Spring Boot框架簡介 描述&#xff1a; 簡化開發&#xff1a;Spring Boot旨在簡化新Spring應用的初始搭建和開發過程。配置方式&#xff1a;采用特定的配置方式&#xff0c;減少樣板化配置&#xff0c;使開發人員無需定義繁瑣的配置。開發工具&#xff1a;可…

LVS 集群技術基礎

LVS(linux virual server)LVS集群技術---NAT模式一.準備四臺虛擬機1.client(eth0ip:172.254.100)2.lvs(eth0ip:172.254.200;eth1ip:192.168.0.200)3.rs1(eht0ip:192.168.0.10)4.rs2(eth0ip:192.168.0.20)二&#xff1a;在rs1和rs2安裝httpd功能dnf/yum install htppd -y三&…

Oracle RU19.28補丁發布,一鍵升級穩

&#x1f4e2;&#x1f4e2;&#x1f4e2;&#x1f4e3;&#x1f4e3;&#x1f4e3; 作者&#xff1a;IT邦德 中國DBA聯盟(ACDU)成員&#xff0c;15年DBA工作經驗 Oracle、PostgreSQL ACE CSDN博客專家及B站知名UP主&#xff0c;全網粉絲15萬 擅長主流Oracle、MySQL、PG、高斯及…

lvs 集群技術

LVS概念LVS&#xff1a;Linux Virtual Server&#xff0c;負載調度器&#xff0c;是一種基于Linux操作系統內核的高性能、高可用網絡服務負載均衡解決方案。LVS工作原理基于網絡層&#xff08;四層&#xff0c;傳輸層&#xff09;的負載均衡技術&#xff0c;它通過內核級別的IP…

AR巡檢和傳統巡檢的區別

隨著工業4.0時代的到來&#xff0c;數字化轉型逐漸成為各行各業提升效率、保障安全和降低成本的關鍵。而在這一轉型過程中&#xff0c;巡檢工作作為確保設備穩定運行的重要環節&#xff0c;逐步從傳統方式走向智能化、數字化。尤其是增強現實&#xff08;AR&#xff09;技術的引…

Axure設計設備外殼 - AxureMost 落葵網

在UI設計中&#xff0c;設備外殼&#xff08;硬件外殼與界面中的“虛擬外殼”&#xff09;和背景是構成視覺體驗的核心元素&#xff0c;它們不僅影響美觀&#xff0c;更直接關聯用戶對功能的理解和操作效率。以下從設計角度詳細解析其作用與使用邏輯&#xff1a; 一、設備外殼&…

基于深度學習的電信號分類識別與混淆矩陣分析

基于深度學習的電信號分類識別與混淆矩陣分析 1. 引言 1.1 研究背景與意義 電信號分類識別是信號處理領域的重要研究方向,在醫療診斷、工業檢測、通信系統等多個領域有著廣泛的應用。傳統的電信號分類方法主要依賴于手工提取特征和淺層機器學習模型,但這些方法往往難以捕捉…

Git 和Gitee遠程連接 上傳和克隆

第一步創建遠程庫第二步初始化本地庫創建鏈接刪掉.idea 和target(這兩個沒用運行就自動生成了)右鍵空白處選擇Git Bash Here 初始化本地庫git init建立遠程連接建立連接這里是我的地址&#xff0c;后面拼接你的地址git remote add origin https://gitee.com/liu-qing_liang/git…

零基礎100天CNN實戰計劃:用Python從入門到圖像識別高手

一、為什么你需要這份100天CNN學習計劃&#xff1f; 在人工智能領域&#xff0c;卷積神經網絡&#xff08;CNN&#xff09; 是計算機視覺的基石技術。無論是人臉識別、醫學影像分析還是自動駕駛&#xff0c;CNN都扮演著核心角色。但對于初學者來說&#xff0c;面對復雜的數學公…