深度學習實戰 04:卷積神經網絡之 VGG16 復現三(訓練)

在后續的系列文章中,我們將逐步深入探討 VGG16 相關的核心內容,具體涵蓋以下幾個方面:

  1. 卷積原理篇:詳細剖析 VGG 的 “堆疊小卷積核” 設計理念,深入解讀為何 3×3×2 卷積操作等效于 5×5 卷積,以及 3×3×3 卷積操作等效于 7×7 卷積。

  2. 架構設計篇:運用 PyTorch 精確定義 VGG16 類,深入解析 “Conv - BN - ReLU - Pooling” 這一標準模塊的構建原理與實現方式。

3. 訓練實戰篇:在小規模醫學影像數據集上對 VGG16 模型進行嚴格驗證,并精心調優如 batch_size、學習率等關鍵超參數,以實現模型性能的最優化。

若您希望免費獲取本系列文章的完整代碼,可通過添加 V 信:18983376561 來獲取。

一、VGG16 架構

VGG16 作為卷積神經網絡中的經典架構,其結構清晰且具有強大的特征提取能力。下面是 VGG16 的架構圖:

二、訓練流程與代碼解析

1. 數據預處理:讓圖像適應模型輸入

CIFAR-10是一個更接近普適物體的彩色圖像數據集。CIFAR-10 是由Hinton 的學生Alex Krizhevsky 和Ilya Sutskever 整理的一個用于識別普適物體的小型數據集。一共包含10 個類別的RGB 彩色圖片:飛機( airplane )、汽車( automobile )、鳥類( bird )、貓( cat )、鹿( deer )、狗( dog )、蛙類( frog )、馬( horse )、船( ship )和卡車( truck )。每個圖片的尺寸為32 × 32 ,每個類別有6000個圖像,數據集中一共有50000 張訓練圖片和10000 張測試圖片。

然而,VGG16 模型原設計是針對 224x224 的圖像輸入。為了使 CIFAR10 數據集能夠適配 VGG16 模型,我們需要對圖像進行預處理。具體而言,通過transforms.Resize((224, 224))將圖像縮放至 224x224 的尺寸,再利用Normalize進行標準化處理,將均值和標準差均設為 0.5,從而使像素值歸一化到 [-1, 1] 區間。以下是關鍵代碼片段:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summaryfrom VGG16 import VGG16
device = torch.device('cuda')transform_train = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

2. 數據加載:高效讀取與批量處理

為了實現數據的高效讀取與批量處理,我們使用DataLoader來加載數據。設置batch_size = 128,以平衡內存使用和訓練效率;同時,設置num_workers = 12,利用多線程技術加速數據讀取過程。對于訓練集,我們將shuffle參數設置為True,打亂數據順序,避免模型記憶數據順序而導致過擬合;對于測試集,將shuffle參數設置為False,保持數據順序,便于結果的復現和評估。以下是具體代碼:

train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True, num_workers=12)test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(test, batch_size=128, shuffle=False, num_workers=12)

3. 模型構建:調用自定義 VGG16 網絡

在代碼中,我們假設VGG16類已經被正確定義,該類應包含 16 層卷積層和全連接層結構。通過model.to(device)將模型部署到 GPU 上進行訓練,以加速訓練過程。由于 CIFAR10 是一個 10 分類任務,因此模型的最終全連接層輸出維度應為 10。如果沒有可用的 GPU,需要將device設置為cpu,但訓練速度會顯著降低。

4. 訓練配置:損失函數與優化策略

在訓練過程中,我們需要選擇合適的損失函數和優化策略來指導模型的學習。具體配置如下:

  • 損失函數:選用CrossEntropyLoss來處理多分類問題,該損失函數會自動整合 Softmax 計算,簡化了代碼實現。
  • 優化器:選擇隨機梯度下降(SGD)作為優化器,設置學習率lr = 0.1,動量momentum = 0.9以加速收斂過程,同時設置權重衰減weight_decay = 0.0001,采用 L2 正則化防止模型過擬合。
  • 學習率調度器:使用ReduceLROnPlateau根據驗證損失自動調整學習率。當驗證損失連續 5 個 epoch 未下降時,學習率將乘以 0.1(factor = 0.1),這樣可以避免模型陷入局部最優解。以下是相關代碼:
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
model = VGG16().to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.1, patience=5)EPOCHS = 200
for epoch in range(EPOCHS):losses = []running_loss = 0for i, inp in enumerate(trainloader):inputs, labels = inpinputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)losses.append(loss.item())loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 0 and i > 0:print(f'Loss [{epoch + 1}, {i}](epoch, minibatch): ', running_loss / 100)running_loss = 0.0avg_loss = sum(losses) / len(losses)scheduler.step(avg_loss)

5. 訓練循環:迭代優化與監控

在 200 個 epoch 的訓練過程中,我們每 100 個批次打印一次平均損失,以便實時監控模型的訓練進度。從輸出日志可以看出,模型初始損失較高(第 1 個 epoch 約為 2.3),隨著訓練的不斷進行,損失逐漸下降,最終損失趨近于 0.001 左右,這表明模型對訓練數據的擬合效果良好。

print('Training Done')
# Loss [1, 100](epoch, minibatch):  3.8564858746528627
# Loss [1, 200](epoch, minibatch):  2.307221896648407
# Loss [1, 300](epoch, minibatch):  2.304955897331238
# Loss [2, 100](epoch, minibatch):  2.3278213500976563
# Loss [2, 200](epoch, minibatch):  2.3041475653648376
# Loss [2, 300](epoch, minibatch):  2.3039899492263793
# ...
# Loss [199, 100](epoch, minibatch):  0.001291145777431666
# Loss [199, 200](epoch, minibatch):  0.0017596399529429619
# Loss [199, 300](epoch, minibatch):  0.0013808918403083225
# Loss [200, 100](epoch, minibatch):  0.0013940928343799896
# Loss [200, 200](epoch, minibatch):  0.0011531753832969116
# Loss [200, 300](epoch, minibatch):  0.001689423452335177

三、訓練結果與問題分析

在訓練完成后,我們可以對模型進行保存和加載操作,以便后續的使用和評估。以下是保存和加載模型的代碼示例:

# 保存整個模型
torch.save(model, 'VGG16.pth')# 或者只保存模型的參數
torch.save(model.state_dict(), 'VGG16_params.pth')# 加載整個模型
loaded_model = torch.load('VGG16.pth')# 或者加載模型的參數
loaded_params = torch.load('VGG16_params.pth')# 如果只加載了模型的參數,需要先將參數加載到模型對象中
# 假設我們有一個新的模型實例
new_model = VGG16(num_classes=10)
new_model.load_state_dict(loaded_params)correct = 0
total = 0with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
print('Accuracy on 10,000 test images: ', 100 * (correct / total), '%')

通過測試集計算模型的準確率,我們得到約 86.5% 的結果。然而,需要注意以下兩個問題:

  • CIFAR10 的挑戰:CIFAR10 數據集中的圖像分辨率較低(32x32),圖像細節較少,并且部分類別之間存在一定的相似性(如狗與貓),這對模型的特征提取能力提出了較高的要求。
  • 過擬合風險:訓練損失極低,但測試準確率未能達到 90% 以上,這可能表明模型存在過擬合現象,即模型在訓練集上的表現遠好于在測試集上的表現。

四、優化方向:如何讓模型更上一層樓

1. 數據增強:對抗過擬合的 “核武器”

原代碼未使用數據增強技術,為了提高模型的泛化能力,我們可以添加以下數據增強操作:

  • 隨機裁剪與翻轉:使用transforms.RandomCrop(32, padding = 4)transforms.RandomHorizontalFlip(),增加數據的多樣性,使模型能夠學習到更多不同視角和位置的特征。
  • 顏色擾動:通過transforms.ColorJitter(brightness = 0.1, contrast = 0.1, saturation = 0.1),增強模型對色彩變化的魯棒性,使其能夠適應不同光照和色彩條件下的圖像。
  • Cutout/MixUp:采用隨機遮擋圖像區域(Cutout)或混合樣本(MixUp)的方法,進一步提升模型的泛化能力。

2. 模型調整:更適配小數據集的設計

  • 使用預訓練模型:可以將在 ImageNet 上預訓練的 VGG16 模型權重遷移到 CIFAR10 任務中。但需要注意輸入尺寸的差異(從 224 調整為 32),可以嘗試凍結部分卷積層,只對后續層進行微調。
  • 輕量化改進:VGG16 模型的參數量較大(約 1.38 億),對于 CIFAR10 這樣的小數據集可能會導致過擬合。可以考慮改用更小的網絡,如 VGG11、ResNet18,或者減少通道數(如將起始通道從 64 減少到 32)。
  • 添加 Dropout:在全連接層前插入nn.Dropout(0.5),抑制神經元之間的共適應現象,降低模型過擬合的風險。

3. 優化策略升級

  • 學習率策略:可以改用余弦退火(Cosine Annealing)或周期性學習率(CLR)策略,動態調整學習率,幫助模型逃離鞍點,提高收斂速度和性能。
  • 優化器選擇:嘗試使用 AdamW(結合權重衰減的 Adam)或 RMSprop 等優化器,這些優化器在處理稀疏梯度場景時可能更有效。
  • 混合精度訓練:使用 PyTorch 的torch.cuda.amp模塊進行混合精度訓練,減少顯存占用并加速訓練過程,尤其適用于長周期的訓練任務。

4. 訓練技巧與調參

  • 早停(Early Stopping):監控驗證集損失,若連續多個 epoch 驗證集損失未提升,則提前終止訓練,避免無效的訓練過程。
  • 標簽平滑(Label Smoothing):在損失函數中引入標簽平滑技術,防止模型對單一類別過度自信,提高模型的泛化能力。
  • 調整批量大小:嘗試使用更小的batch_size(如 64)以增加梯度更新的頻率,或者使用更大的批量(如 256)以充分利用 GPU 的并行計算能力。

5. 測試階段優化

  • 測試時增強(TTA):在測試階段,對測試圖像進行多尺度裁剪、翻轉等操作,然后取預測結果的平均值,提升預測的魯棒性。
  • 集成學習:訓練多個不同初始化的 VGG 模型,通過投票或平均法融合這些模型的預測結果,降低模型的隨機性影響,提高整體性能。

五、總結與實踐建議

本次實戰通過在 CIFAR10 數據集上訓練 VGG16 模型,全面展示了深度學習從數據預處理到模型部署的完整流程。86.5% 的準確率僅僅是一個起點,通過采用數據增強、模型輕量化、優化策略調整等一系列優化手段,完全有能力將模型的準確率提升至 90% 以上(CIFAR10 的當前最優模型準確率可達 95% 以上)。

深度學習的學習過程需要理論與實踐緊密結合,希望大家能夠動手實踐,親自體驗模型優化的過程。如果您需要完整代碼或希望進行進一步的討論,歡迎在評論區留言。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/81639.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/81639.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/81639.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Ubuntu 20.04之Docker安裝ES7.17.14和Kibana7.17.14

你需要已經安裝如下運行環境: Ubuntu 20.04 docker 28 docker-compose 1.25 一、手動拉取鏡像 docker pull docker.elastic.co/kibana/kibana:7.17.14docker pull docker.elastic.co/elasticsearch/elasticsearch:7.17.14 或者手動導入鏡像 docker load -i es7.17.14.ta…

實時技術方案對比:SSE vs WebSocket vs Long Polling

早期網站僅展示靜態內容,而如今我們更期望:實時更新、即時聊天、通知推送和動態儀表盤。 那么要如何實現實時的用戶體驗呢?三大經典技術各顯神通: SSE(Server-Sent Events):輕量級單向數據流WebSocket:雙向全雙工通信Long Polling(長輪詢):傳統過渡方案假設目前有三…

測試開發面試題:Python高級特性通俗講解與實戰解析

前言:為什么測試工程師必須掌握Python高級特性? 通俗比喻: 基礎語法就像“錘子”,能敲釘子;高級特性就像“瑞士軍刀”,能應對復雜場景(如自動化框架、高并發測試)。面試官考察點&a…

C語言-9.指針

9.1指針 9.1-1取地址運算:&運算符取得變量的地址 運算符& scanf(“%d”,&i);里的&獲取變量的地址,它們操作數必須是變量int i;printf(“%x”,&i);地址的大小是否與int相同取決于編譯器int i;printf(“%p”,&i); &不能取的地址不能對沒有地址的…

【C++】Vcpkg 介紹及其常見命令

Vcpkg 簡介 Vcpkg 是微軟開發的一個跨平臺的 C/C 依賴管理工具,用于簡化第三方庫的獲取、構建和管理過程。 主要特點 跨平臺支持:支持 Windows、Linux 和 macOS開源免費:MIT 許可證大型庫集合:包含超過 2000 個開源庫簡化集成&…

Unity3D 動畫文件優化總結

前言 在Unity3D中,動畫文件的壓縮和優化是提升性能的重要環節,尤其在移動端或復雜場景中。以下是針對Animation Clip和Animator Controller的優化方法總結: 對惹,這里有一個游戲開發交流小組,希望大家可以點擊進來一…

前端工程的相關管理 git、branch、build

環境配置 標準環境打包 測試版:npm run build-test 預生產:npm run build-preview 正式版:npm run build 建議本地建里一個 .env.development.local 方便和后端聯調時修改配置相關信息。 和 src 同級有一下區分環境的文件: .env.d…

VAPO:視覺-語言對齊預訓練(對象級語義)詳解

簡介 多模態預訓練模型(Vision-Language Pre-training, VLP)近年來取得了飛躍發展。在視覺-語言模型中,模型需要同時理解圖像和文本,這要求模型學習二者之間的語義對應關系。早期方法如 VisualBERT、LXMERT 等往往使用預先提取的圖像區域特征和文本詞嵌入拼接輸入,通過 T…

docker運行Redis

創建目錄 mkdir -p /home/jie/docker/redis/{conf,data,logs}添加權限 chmod -R 777 /home/jie/docker/redis創建配置文件 cat > /home/jie/docker/redis/conf/redis.conf << EOF # 基本配置 bind 0.0.0.0 protected-mode yes port 6379# 安全配置 密碼是root require…

初識 java

目錄 前言 一、jdk&#xff0c;JRE和JVM之間的關系 二、JVM的內存劃分 前言 初步了解 jdk&#xff0c;JRE&#xff0c;JVM 之間的關系&#xff0c;JVM 的內存劃分。 一、jdk&#xff0c;JRE和JVM之間的關系 jdk 是 java 開發工具集&#xff0c;包含JRE&#xff1b; JRE 是…

關于百度地圖JSAPI自定義標注的圖標顯示不完整的問題(其實只是因為圖片尺寸問題)

下載了幾個阿里矢量圖標庫里的圖標作為百度地圖的自定義圖標&#xff0c;結果百度地圖顯示的圖標一直不完整。下載的PNG圖標已經被正常引入到前端代碼&#xff0c;anchor也設置為了圖標底部中心&#xff0c;結果還是顯示不完整。 if (iconUrl) {const icon new mapClass.Icon(…

系統安全及應用深度筆記

系統安全及應用深度筆記 一、賬號安全控制體系構建 &#xff08;一&#xff09;賬戶全生命周期管理 1. 冗余賬戶精細化治理 非登錄賬戶基線核查 Linux 系統默認創建的非登錄賬戶&#xff08;如bin、daemon、mail&#xff09;承擔系統服務支撐功能&#xff0c;其登錄 Shell 必…

02-前端Web開發(JS+Vue+Ajax)

介紹 在前面的課程中&#xff0c;我們已經學習了HTML、CSS的基礎內容&#xff0c;我們知道HTML負責網頁的結構&#xff0c;而CSS負責的是網頁的表現。 而要想讓網頁具備一定的交互效果&#xff0c;具有一定的動作行為&#xff0c;還得通過JavaScript來實現。那今天,我們就來講…

AXXI4總線協議 ------ AXI_FULL協議

https://download.csdn.net/download/mvpkuku/90855619 一、AXI_FULL協議的前提知識 1. 各端口的功能 2. 4K邊界問題 3. outstanding 4.時序仿真體驗 可通過VIVADO自帶ADMA工程觀察仿真波形圖 二、FPGA實現 &#xff08;主要用于讀寫DDR&#xff09; 1.功能模塊及框架 將…

React系列——nvm、node、npm、yarn(MAC)

nvm&#xff0c;node&#xff0c;npm之間的區別 1、nvm&#xff1a;nodejs版本管理工具。nvm 可以管理很多 node 版本和 npm 版本。 2、nodejs&#xff1a;在項目開發時的所需要的代碼庫 3、npm&#xff1a;nodejs包管理工具。nvm、nodejs、npm的關系 nvm 管理 nodejs 和 npm…

2025年AI與網絡安全的終極博弈:沖擊、重構與生存法則

引言 2025年&#xff0c;生成式AI的推理速度突破每秒千萬次&#xff0c;網絡安全行業正經歷前所未有的范式革命。攻擊者用AI批量生成惡意代碼&#xff0c;防御者用AI構建智能護盾&#xff0c;這場技術軍備競賽正重塑行業規則——60%的傳統安全崗位面臨轉型&#xff0c;70%的防…

【Android】Android 實現一個依賴注入的注解

Android 實現一個依賴注入的注解 &#x1f3af; 目標功能 自定義注解 Inject創建一個 Injector 類&#xff0c;用來掃描并注入對象支持 Activity 或其他類中的字段注入 &#x1f9e9; 步驟一&#xff1a;定義注解 import java.lang.annotation.ElementType; import java.lan…

Spring Boot與Kafka集成實踐:從入門到實戰

Spring Boot與Kafka集成實踐 引言 在現代分布式系統中&#xff0c;消息隊列是不可或缺的組件之一。Apache Kafka作為一種高吞吐量的分布式消息系統&#xff0c;廣泛應用于日志收集、流處理、事件驅動架構等場景。Spring Boot作為Java生態中最流行的微服務框架&#xff0c;提供…

ubuntu的虛擬機上的網絡圖標沒有了

非正常的關機導致虛擬機連接xshell連接不上&#xff0c;ping也ping不通。網絡的圖標也沒有了。 記錄一下解決步驟 1、重啟服務 sudo systemctl restart NetworkManager 2、圖標顯示 sudo nmcli network off sudo nmcli network on 3、sudo dhclient ens33 //(網卡) …

生產者 - 消費者模式實現方法整理

一、Channels &#xff08;一&#xff09;使用場景 適用于高并發、大數據量傳輸&#xff0c;且需要異步操作的場景&#xff0c;如實時數據處理系統。 &#xff08;二&#xff09;使用方法 創建 Channel<T>&#xff08;無界&#xff09;或 BoundedChannel<T>&…