使用PyTorch構建卷積神經網絡(CNN)實現CIFAR-10圖像分類

在計算機視覺領域,卷積神經網絡(CNN)已經成為處理圖像識別任務的事實標準。從人臉識別到醫學影像分析,CNN展現出了驚人的能力。本文將詳細介紹如何使用PyTorch框架構建一個CNN模型,并在經典的CIFAR-10數據集上進行圖像分類任務。

CIFAR-10數據集包含10個類別的60000張32x32彩色圖像,每個類別有6000張圖像,其中50000張用于訓練,10000張用于測試。這個數據集雖然圖像尺寸較小,但包含了足夠的復雜性,是學習計算機視覺和深度學習的理想起點。

一、卷積神經網絡基礎

1.1 卷積層

卷積層是CNN的核心組件,它通過卷積核(濾波器)在輸入圖像上滑動,計算局部區域的點積。PyTorch中的nn.Conv2d實現了這一功能:

self.conv1 = nn.Conv2d(3, 32, 3, padding=1)

這行代碼創建了一個卷積層,參數含義如下:

  • 輸入通道數:3(對應RGB三通道)

  • 輸出通道數:32(即使用32個不同的濾波器)

  • 卷積核大小:3×3

  • padding=1保持空間維度不變

卷積層能夠自動學習從簡單邊緣到復雜模式的各種特征,這種層次化的特征學習是CNN強大性能的關鍵。

1.2 池化層

池化層(通常是最大池化)用于降低特征圖的空間維度:

self.pool = nn.MaxPool2d(2, 2)

最大池化取2×2窗口中的最大值,步長為2,這會使特征圖尺寸減半。池化的作用包括:

  1. 減少計算量和參數數量

  2. 增強特征的位置不變性

  3. 防止過擬合

1.3 全連接層

在多個卷積和池化層之后,我們使用全連接層進行分類:

self.fc1 = nn.Linear(128 * 4 * 4, 512)
self.fc2 = nn.Linear(512, 10)

第一個全連接層將展平的特征向量(128×4×4)映射到512維空間,第二個則輸出10維向量對應10個類別。

二、數據準備與預處理

2.1 數據加載

PyTorch的torchvision.datasets模塊提供了便捷的CIFAR-10加載方式:

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)

2.2 數據預處理

良好的數據預處理對模型性能至關重要:

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

這里進行了兩個關鍵操作:

  1. ToTensor():將PIL圖像轉換為PyTorch張量,并自動將像素值從[0,255]縮放到[0,1]

  2. Normalize:用均值0.5和標準差0.5對每個通道進行標準化

2.3 數據批量加載

使用DataLoader實現高效的批量數據加載:

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True, num_workers=2)

參數說明:

  • batch_size=64:每次迭代處理64張圖像

  • shuffle=True:每個epoch打亂數據順序

  • num_workers=2:使用2個子進程加載數據

三、模型構建

3.1 網絡架構設計

我們構建的CNN包含四個卷積層和兩個全連接層:

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.conv4 = nn.Conv2d(128, 128, 3, padding=1)self.fc1 = nn.Linear(128 * 4 * 4, 512)self.fc2 = nn.Linear(512, 10)self.dropout = nn.Dropout(0.5)

3.2 前向傳播

定義數據在網絡中的流動路徑:

def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = self.pool(F.relu(self.conv3(x)))x = F.relu(self.conv4(x))x = x.view(-1, 128 * 4 * 4)x = self.dropout(x)x = F.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x

關鍵點:

  1. 每個卷積層后接ReLU激活函數引入非線性

  2. 使用view將三維特征圖展平為一維向量

  3. Dropout層以0.5的概率隨機失活神經元,防止過擬合

四、模型訓練

4.1 訓練設置

model = CNN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

我們使用:

  • 交叉熵損失函數:適合多分類問題

  • Adam優化器:自適應學習率,通常比SGD表現更好

  • GPU加速(如果可用)

4.2 訓練循環

for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()

每個epoch中:

  1. 從DataLoader獲取一個batch的數據

  2. 清零梯度(防止梯度累積)

  3. 前向傳播計算輸出和損失

  4. 反向傳播計算梯度

  5. 優化器更新權重

  6. 統計損失和準確率

4.3 訓練可視化

繪制訓練過程中的損失和準確率曲線:

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Training Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

五、模型評估

5.1 測試集評估

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy on test images: {100 * correct / total:.2f}%')

關鍵點:

  1. with torch.no_grad():禁用梯度計算,節省內存和計算資源

  2. 計算模型在未見過的測試集上的準確率

5.2 示例預測

可視化一些測試圖像及其預測結果:

dataiter = iter(testloader)
images, labels = next(dataiter)imshow(torchvision.utils.make_grid(images[:4]))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))outputs = model(images.to(device))
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(4)))

六、性能優化建議

雖然我們的基礎模型已經能達到75-80%的準確率,但還可以通過以下方法進一步提升:

  1. 網絡架構改進

    • 添加批量歸一化層(nn.BatchNorm2d)加速訓練并提高性能

    • 使用更深的網絡結構(如ResNet殘差連接)

  2. 數據增強

    transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
  3. 訓練技巧

    • 使用學習率調度器(如lr_scheduler.StepLR

    • 早停法防止過擬合

    • 嘗試不同的優化器(如AdamW)

  4. 正則化

    • 增加Dropout比例

    • 在優化器中添加權重衰減(L2正則化)

七、總結

本文詳細介紹了使用PyTorch實現CNN進行CIFAR-10圖像分類的完整流程。我們從CNN的基礎組件開始,逐步構建了一個包含卷積層、池化層和全連接層的網絡模型。通過合理的數據預處理、模型訓練和評估,我們實現了一個具有不錯分類性能的圖像識別系統。

CNN之所以在圖像任務中表現優異,關鍵在于它的兩個特性:

  1. 局部連接:卷積核只關注局部區域,大大減少了參數量

  2. 參數共享:同一卷積核在整個圖像上滑動使用,提高了效率

通過本實踐,讀者不僅能夠理解CNN的工作原理,還能掌握PyTorch實現深度學習模型的標準流程。這為進一步探索更復雜的計算機視覺任務(如目標檢測、圖像分割等)奠定了堅實基礎。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/95887.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/95887.shtml
英文地址,請注明出處:http://en.pswp.cn/web/95887.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Springboot3+SpringSecurity6Oauth2+vue3前后端分離認證授權-客戶端

客戶端服務整體流程前端技術棧項目結構代碼后端技術棧項目結構代碼整體流程 #mermaid-svg-G5v6x8CSkEcvglmt {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-G5v6x8CSkEcvglmt .error-icon{fill:#552222;}#mermaid-s…

DEEP THINK WITH CONFIDENCE-Meta-基于置信度的深度思考

原文地址 摘要 大型語言模型(LLM)通過自我一致性和多數投票等測試時間縮放方法,在推理任務中顯示出巨大的潛力。然而,這種方法經常導致精度回報遞減和高計算開銷。為了應對這些挑戰,我們引入了深度自信思考(DeepConf),這是一種簡…

零基礎學習數據采集與監視控制系統SCADA

新晉碼農一枚,小編定期整理一些寫的比較好的代碼,作為自己的學習筆記,會試著做一下批注和補充,轉載或者參考他人文獻會標明出處,非商用,如有侵權會刪改!歡迎大家斧正和討論! 目錄 一…

docker run 命令,不接it選項,run一個centos沒有顯示在運行,而run一個nginx卻可以呢?

docker run 命令,不接it選項,run一個centos沒有顯示在運行,而run一個nginx卻可以呢? ChatGPT said: 你問到的這個現象,其實就是 鏡像默認啟動命令 (ENTRYPOINT / CMD) 的差異導致的。🔍 情況分析 1. docker…

【完整源碼+數據集+部署教程】水培植物病害檢測系統源碼和數據集:改進yolo11-AKConv

背景意義 研究背景與意義 隨著全球人口的不斷增長,農業生產面臨著前所未有的挑戰,尤其是在資源有限的環境中,如何提高作物的產量和質量成為了亟待解決的問題。水培技術作為一種新興的農業生產方式,因其高效的水資源利用和較少的土…

第2課:環境搭建:基于DeepSeek API的開發環境配置

概述 在開始大模型RAG實戰之旅前,一個正確且高效的開發環境是成功的基石。本文將手把手指導您完成從零開始的環境配置過程,涵蓋Python環境設置、關鍵庫安裝、DeepSeek API配置以及開發工具優化。通過詳細的步驟說明、常見問題解答和最佳實踐分享&#x…

Boost電路:穩態和小信號分析

穩態分析 參考張衛平的《開關變換器的建模與控制》的1.3章節內容;伏秒平衡:在穩態下,一個開關周期內電感電流的增量是0,即 dIL(t)dt0\frac{dI_{L}(t)}{dt} 0dtdIL?(t)?0。電荷平衡:在穩態下,一個開關周期…

Vue-25-利用Vue3大模型對話框設計之前端和后端的基礎實現

文章目錄 1 設計思路 1.1 核心布局與組件 1.2 交互設計(Interaction Design) 1.3 視覺與用戶體驗 1.4 高級功能與創新設計 2 vue3前端設計 2.1 項目啟動 2.1.1 創建和啟動項目(vite+vue) 2.1.2 清理不需要的代碼 2.1.3 下載必備的依賴(element-plus) 2.1.4 完整引入并注冊(main…

Elasticsearch面試精講 Day 7:全文搜索與相關性評分

【Elasticsearch面試精講 Day 7】全文搜索與相關性評分 文章標簽:Elasticsearch, 全文搜索, 相關性評分, TF-IDF, BM25, 面試, 搜索引擎, 后端開發, 大數據 文章簡述: 本文是“Elasticsearch面試精講”系列的第7天,聚焦于全文搜索與相關性評…

Vllm-0.10.1:vllm bench serve參數說明

一、KVM 虛擬機環境 GPU:4張英偉達A6000(48G) 內存:128G 海光Cpu:128核 大模型:DeepSeek-R1-Distill-Qwen-32B 推理框架Vllm:0.10.1 二、測試命令(random ) vllm bench serve \ --backend vllm \ --base-url http://127.0.…

B.50.10.11-Spring框架核心與電商應用

Spring框架核心原理與電商應用實戰 核心理念: 本文是Spring框架深度指南。我們將從Spring的兩大基石——IoC和AOP的底層原理出發,詳細拆解一個Bean從定義到銷毀的完整生命周期,并深入探討Spring事務管理的實現機制。隨后,我們將聚焦于Spring …

雅菲奧朗SRE知識墻分享(六):『混沌工程的定義與實踐』

混沌工程不再追求“永不宕機”的童話,而是主動在系統中注入可控的“混亂”,通過實驗驗證系統在真實故障場景下的彈性與自我修復能力。混沌工程不是簡單的“搞破壞”,也不是運維團隊的專屬游戲。它是一種以實驗為導向、以度量為核心、以文化為…

從0死磕全棧第五天:React 使用zustand實現To-Do List項目

代碼世界是現實的鏡像,狀態管理教會我們:真正的控制不在于凝固不變,而在于優雅地引導變化。 這是「從0死磕全棧」系列的第5篇文章,前面我們已經完成了環境搭建、路由配置和基礎功能開發。今天,我們將引入一個輕量級但強大的狀態管理工具 —— Zustand,來實現一個完整的 T…

力扣29. 兩數相除題解

原題鏈接29. 兩數相除 - 力扣(LeetCode) 主要不能用乘除取余,于是用位運算代替: Java題解 class Solution {public int divide(int dividend, int divisor) {//全都轉為負數計算, 避免溢出, flag記錄結果的符號int flag 1;if(…

【工具類】Nuclei YAML POC 編寫以及批量檢測

Nuclei YAML POC 編寫以及批量檢測法律與道德使用聲明前言Nuclei 下載地址下載對應版本的文件關于檢查cpu架構關于hkws的未授權訪問參考資料關于 Neclei Yaml 腳本編寫BP Nuclei Template 插件下載并安裝利用插件編寫 POC YAML 文件1、找到有漏洞的頁面抓包發送給插件2、同時將…

自動化運維之ansible

一、認識自動化運維假如管理很多臺服務器,主要關注以下幾個方面“1.管理機與被管理機的連接(管理機如何將管理指令發送給被管理機)2.服務器信息收集(如果被管理的服務器有centos7.5外還有其它linux發行版,如suse,ubunt…

【溫室氣體數據集】亞洲地區長期空氣污染物和溫室氣體排放數據 REAS

目錄 REAS 數據集概述 REAS 數據版本及特點 數據內容(以 REASv3.2.1 為例) 數據形式 數據下載 參考 REAS 數據集(Regional Emission inventory in ASia,亞洲區域排放清單)是由日本國立環境研究所(NIES)及相關研究人員開發的一個覆蓋亞洲地區長期空氣污染物和溫室氣體排放…

中州養老項目:利用Redis解決權限接口響應慢的問題

目錄 在Java中使用Redis緩存 項目中集成SpringCache 在Java中使用Redis緩存 Redis作為緩存,想要在Java中操作Redis,需要 Java中的客戶端操縱Redis就像JDBC操作數據庫一樣,實際底層封裝了對Redis的基礎操作 如何在Java中使用Redis呢?先導入Redis的依賴,這個依賴導入后相當于把…

MathJax - LaTeX:WordPress 公式精準呈現方案

寫在前面:本博客僅作記錄學習之用,部分圖片來自網絡,如需引用請注明出處,同時如有侵犯您的權益,請聯系刪除! 文章目錄前言安裝 MathJax-LaTeX 插件修改插件文件效果總結互動致謝參考前言 在當今知識傳播與…

詳細解讀Docker

1.概述Docker是一種優秀的開源的容器化平臺。用于部署、運行應用程序,它通過將應用及其依賴打包成輕量級、可移植的容器,實現高效一致的運行效果,簡單來說,Docker就是一種輕量級的虛擬技術。2.核心概念2.1.容器(Contai…