PyTorch:學習 CIFAR-10 分類

🔍 開始你的圖像分類之旅:一步一步學習 CIFAR-10 分類

圖像分類是計算機視覺中最基礎的任務之一,如果你是初學者,那么以 CIFAR-10 為訓練場是一個不錯的選擇。本文一步一步帶你從零開始,學習如何用深度學習模型實現圖像分類。


一、CIFAR-10 數據集是什么?

CIFAR-10 是一個小型圖像分類數據集,共包括 10 個類別:? 飛機(airplane)🚗 汽車(automobile)🐦 鳥(bird)🐱 貓(cat)🦌 鹿(deer)🐶 狗(dog)🐸 青蛙(frog)🐴 馬(horse)🚢 船(ship)🚚 卡車(truck)

每張圖片都是 32x32 的小圖,有 RGB 三個顏色通道。

總共有 60000 張圖,其中:

  • 訓練集: 50000 張
  • 測試集: 10000 張

這些圖片內容豐富,分辨率低,適合初學者練手。


二、模型訓練的整體流程

我們用以下流程完成圖像分類:

  1. 數據加載和預處理
  2. 構建模型(CNN)
  3. 設置損失函數和優化器
  4. 訓練模型(前向 + 反向傳播 + 更新參數)
  5. 測試模型效果

🧠 類比理解: 把整個過程比作“學會識別水果”:

  • 數據加載:收集不同水果的照片
  • 模型:像是大腦處理這些圖像的神經元網絡
  • 損失函數:告訴我們判斷錯誤的嚴重程度
  • 優化器:幫助我們不斷修正錯誤,直到準確

三、數據加載和預處理

我們使用 PyTorch 中的 transforms 來將圖片:

  • 轉換成 Tensor(張量)
  • 正則化顏色值到 -1~1 之間,加快模型收斂
import torch
import torchvision
import torchvision.transforms as transforms# 定義圖像轉換操作:將圖片轉換為 Tensor,并進行標準化
transform = transforms.Compose([transforms.ToTensor(),  # 將圖片轉為 Tensor 類型,方便計算transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 將圖片標準化,均值和標準差都設為0.5
])# 加載 CIFAR-10 訓練數據集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
# 使用 DataLoader 進行批量加載數據,batch_size 是每次加載的圖片數量
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True)  # shuffle=True 表示打亂數據# 加載 CIFAR-10 測試數據集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=False)  # 測試集不需要打亂數據

🎓 例子解釋

  • 如果 batch_size=64,就意味著每次訓練取 64 張圖片
  • shuffle=True 可以打亂圖片順序,防止模型記住順序而不是學習特征

🍎 通俗類比

  • ToTensor 就像是把一張照片轉成表格(數字表示顏色)
  • Normalize 就像是統一標準,把所有顏色亮度調成統一區間,好比較

四、構建 CNN 模型

CNN(卷積神經網絡)特別適合處理圖像。我們構建一個簡單 CNN 模型,包含:

  • 兩個卷積層 + ReLU 激活
  • 兩個最大池化層(縮小圖片尺寸)
  • 一個全連接隱藏層 + 一個輸出層(10類)
import torch.nn as nn
import torch.nn.functional as F# 定義簡單的卷積神經網絡
class SimpleCNN(nn.Module):def __init__(self):super().__init__()# 第一層卷積層,輸入通道為3(RGB圖像),輸出通道為32,卷積核大小為3x3,padding=1 保證輸出尺寸不變self.conv1 = nn.Conv2d(3, 32, 3, padding=1)# 最大池化層:2x2池化,用來減少圖像尺寸self.pool = nn.MaxPool2d(2, 2)# 第二層卷積層,輸入通道為32,輸出通道為64self.conv2 = nn.Conv2d(32, 64, 3, padding=1)# 全連接層:將卷積層輸出展平為一維向量,連接到一個128維的隱藏層self.fc1 = nn.Linear(64 * 8 * 8, 128)  # CIFAR-10圖像尺寸32x32,經過兩次池化后尺寸為8x8# 輸出層:10類,CIFAR-10數據集包含10個類別self.fc2 = nn.Linear(128, 10)def forward(self, x):# 第一層卷積 + ReLU 激活 + 池化x = self.pool(F.relu(self.conv1(x)))# 第二層卷積 + ReLU 激活 + 池化x = self.pool(F.relu(self.conv2(x)))# 展平特征圖為一維向量,便于輸入全連接層x = x.view(-1, 64 * 8 * 8)# 全連接層 + ReLU 激活x = F.relu(self.fc1(x))# 輸出層,返回每個類別的預測概率x = self.fc2(x)return x

📷 類比

  • 卷積操作就像“掃描照片”的濾鏡,用來提取邊緣、顏色塊等圖像特征
  • 最大池化像是“縮略圖”,保留最顯著的特征,減少計算量

五、設置損失函數和優化器

import torch.optim as optim# 初始化模型并將其放到 GPU 或 CPU 上
model = SimpleCNN()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 損失函數:交叉熵損失函數,用于多分類問題
criterion = nn.CrossEntropyLoss()# 優化器:Adam 優化器,適用于大部分情況,學習率設置為0.001
optimizer = optim.Adam(model.parameters(), lr=0.001)

🧠 例子類比

  • 損失函數就像考試成績:越高說明你錯越多
  • 優化器就像“老師指出你的錯誤”并教你怎么改正

📘 數值示例

  • 如果模型預測飛機的概率是 [0.1, 0.05, …, 0.7](第10類)
  • 但真實標簽是第1類(飛機),交叉熵損失會很大
  • 優化器就會調整參數,使下一次飛機的概率盡量靠近第1類

六、訓練模型:前向、反向、更新

# 訓練過程:迭代多個 epoch,每個 epoch 會遍歷所有訓練數據
for epoch in range(10):  # 總共訓練 10 個 epochrunning_loss = 0.0  # 每個 epoch 初始化損失值為 0for inputs, labels in trainloader:  # 遍歷每個 batchinputs, labels = inputs.to(device), labels.to(device)  # 將數據移到 GPU 或 CPUoptimizer.zero_grad()  # 清除上一次的梯度outputs = model(inputs)  # 前向傳播,得到每張圖片的預測結果loss = criterion(outputs, labels)  # 計算損失值loss.backward()  # 反向傳播,計算梯度optimizer.step()  # 更新模型參數running_loss += loss.item()  # 累加損失值print(f"Epoch {epoch + 1}, Loss: {running_loss:.3f}")  # 打印當前 epoch 的損失

📐 具體數值例子

  • 模型初始權重是 0.3,預測錯誤 → loss = 2.5
  • 反向傳播算出權重梯度是 -0.8
  • 學習率為 0.01,更新后權重 = 0.3 - 0.01 × (-0.8) = 0.308

七、測試模型效果

correct = 0  # 初始化正確預測的個數
total = 0  # 初始化總預測的個數
model.eval()  # 設置模型為評估模式,關閉 Dropout 等訓練時的特殊操作with torch.no_grad():  # 在測試時,不需要計算梯度,減少計算量for data in testloader:  # 遍歷測試集中的數據images, labels = dataimages, labels = images.to(device), labels.to(device)  # 將數據移到 GPU 或 CPUoutputs = model(images)  # 獲取模型的輸出_, predicted = torch.max(outputs, 1)  # 獲取預測結果,torch.max 返回最大值和其索引,這里我們只取索引total += labels.size(0)  # 累加總的測試樣本數量correct += (predicted == labels).sum().item()  # 統計預測正確的樣本數量print(f"測試準確率: {100 * correct / total:.2f}%")  # 打印測試集上的準確率

在這里插入圖片描述

📊 例子

  • 如果 total=10000,correct=7000,準確率就是 70%

在這里插入圖片描述

🏁 完整代碼快速運行包

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt  # 添加matplotlib用于可視化
from matplotlib import rcParams  # 用于設置字體# 1. 數據加載與預處理
transform = transforms.Compose([transforms.ToTensor(),  # 將圖像轉換為Tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 歸一化到[-1, 1]之間
])# 加載訓練集與測試集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)# 2. 定義模型:簡單的CNN模型
class SimpleCNN(nn.Module):def __init__(self):super().__init__()# 第一個卷積層:3個輸入通道,32個輸出通道,卷積核大小3x3,padding為1self.conv1 = nn.Conv2d(3, 32, 3, padding=1)# 最大池化層:2x2池化self.pool = nn.MaxPool2d(2, 2)# 第二個卷積層:32個輸入通道,64個輸出通道,卷積核大小3x3,padding為1self.conv2 = nn.Conv2d(32, 64, 3, padding=1)# 全連接層,輸入大小64x8x8,輸出128self.fc1 = nn.Linear(64 * 8 * 8, 128)# 最后一層全連接層,輸出10個類別self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))  # 卷積+ReLU+池化x = self.pool(F.relu(self.conv2(x)))  # 卷積+ReLU+池化x = x.view(-1, 64 * 8 * 8)  # 展平數據,準備全連接x = F.relu(self.fc1(x))  # 全連接+ReLUx = self.fc2(x)  # 最后一層輸出return x# 3. 初始化模型、損失函數與優化器
model = SimpleCNN()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 判斷是否使用GPU
model.to(device)  # 將模型轉移到GPU或CPU上criterion = nn.CrossEntropyLoss()  # 使用交叉熵作為損失函數
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam優化器,學習率為0.001# 4. 訓練模型
for epoch in range(10):  # 訓練10個epochrunning_loss = 0.0for inputs, labels in trainloader:inputs, labels = inputs.to(device), labels.to(device)  # 數據轉移到GPU或CPUoptimizer.zero_grad()  # 清除上一次的梯度outputs = model(inputs)  # 前向傳播loss = criterion(outputs, labels)  # 計算損失loss.backward()  # 反向傳播optimizer.step()  # 優化器更新參數running_loss += loss.item()  # 累加損失print(f"Epoch {epoch + 1}, Loss: {running_loss:.3f}")  # 打印每個epoch的損失# 5. 測試模型效果
correct = 0
total = 0
model.eval()  # 將模型設置為評估模式with torch.no_grad():  # 禁止計算梯度,提高效率for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)  # 數據轉移到GPU或CPUoutputs = model(images)_, predicted = torch.max(outputs, 1)  # 獲取最大概率的類別total += labels.size(0)  # 累加總樣本數correct += (predicted == labels).sum().item()  # 統計正確的樣本數# 可視化一個批次的預測結果plt.figure(figsize=(10, 10))for i in range(8):  # 顯示前8張圖片plt.subplot(2, 4, i + 1)plt.imshow(images[i].cpu().permute(1, 2, 0) * 0.5 + 0.5)  # 反歸一化并顯示圖片plt.title(f"Label: {labels[i].item()}\nPrediction: {predicted[i].item()}")plt.axis('off')plt.show()break  # 僅顯示一個批次# 輸出測試準確率
print(f"測試準確率: {100 * correct / total:.2f}%")

你可以將以下內容保存為 train_cifar10.py 并運行:

python train_cifar10.py

💡 不需要修改任何內容就能開始訓練和測試!有 CUDA 就用 GPU,否則自動使用 CPU。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/79337.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/79337.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/79337.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

3.學習筆記--Spring-AOP總結(p39)-Spring事務簡介(P40)-Spring事務角色(P41)-Spring事務屬性(P42)

1.AOP總結:面向切面編程,在不驚動原始基礎上為方法進行功能增強。 2.AOP核心概念: (1)代理:SpringAOP的核心是采用代理模式 (2)連接點:在SpringAOP中,理解為任…

數據庫-day06

一、實驗名稱和性質 分類查詢 驗證 綜合 設計 二、實驗目的 1.掌握數據查詢的Group by ; 2. 掌握聚集函數的使用方法。 三、實驗的軟硬件環境要求 硬件環境要求: PC機(單機) 使用的軟件名稱、版本號以及模塊: …

看門狗定時器(WDT)超時

一、問題 Arduino 程序使用<Ticker.h>包時&#xff0c;使用不當情況下&#xff0c;會導致“看門狗WDT超時” 1.1問題控制臺報錯 在串口監視器顯示 --------------- CUT HERE FOR EXCEPTION DECODER ---------------Soft WDT resetException (4): epc10x402077cb epc2…

AI在多Agent協同領域的核心概念、技術方法、應用場景及挑戰 的詳細解析

以下是 AI在多Agent協同領域的核心概念、技術方法、應用場景及挑戰 的詳細解析&#xff1a; 1. 多Agent協同的定義與核心目標 多Agent系統&#xff08;MAS, Multi-Agent System&#xff09;&#xff1a; 由多個獨立或協作的智能體&#xff08;Agent&#xff09;組成&#xff…

Wireshark TS | 異常 ACK 數據包處理

問題背景 來自于學習群里群友討論的一個數據包跟蹤文件&#xff0c;在其中涉及到兩處數據包異常現象&#xff0c;而產生這些現象的實際原因是數據包亂序。由于這兩處數據包異常&#xff0c;都有點特別&#xff0c;本篇也就其中一個異常現象單獨展開說明。 問題信息 數據包跟…

【React】項目的搭建

create-react-app 搭建vite 搭建相關下載 在Vue中搭建項目的步驟&#xff1a;1.首先安裝腳手架的環境&#xff0c;2.通過腳手架的指令創建項目 在React中有兩種方式去搭建項目&#xff1a;1.和Vue一樣&#xff0c;先安裝腳手架然后通過腳手架指令搭建&#xff1b;2.npx create-…

深入淺出 NVIDIA CUDA 架構與并行計算技術

&#x1f407;明明跟你說過&#xff1a;個人主頁 &#x1f3c5;個人專欄&#xff1a;《深度探秘&#xff1a;AI界的007》 &#x1f3c5; &#x1f516;行路有良友&#xff0c;便是天堂&#x1f516; 目錄 一、引言 1、CUDA為何重要&#xff1a;并行計算的時代 2、NVIDIA在…

pytorch學習02

自動微分 自動微分模塊torch.autograd負責自動計算張量操作的梯度&#xff0c;具有自動求導功能。自動微分模塊是構成神經網絡訓練的必要模塊&#xff0c;可以實現網絡權重參數的更新&#xff0c;使得反向傳播算法的實現變得簡單而高效。 1. 基礎概念 張量 Torch中一切皆為張…

Java虛擬機(JVM)平臺無關?相關?

計算機的概念模型 計算機實際上就是實現了一個圖靈機模型。即&#xff0c;輸入參數&#xff0c;根據程序計算&#xff0c;輸出結果。圖靈機模型如圖。 Tape是輸入數據&#xff0c;Program是針對這些數據進行計算的程序&#xff0c;中間橫著的方塊表示的是機器的狀態。 目前使…

satoken的奇奇怪怪的錯誤

發了 /user/getBrowseDetail和/user/getResponDetail&#xff0c;但為什么進入handle里面有三次&#xff1f;且第一次的handle類型是AbstractHandleMapping$PreFlightHttpRequestHandlerxxx,這一次進來的時候flag為false&#xff0c;StpUtils.checkLogin拋出了異常 第二次進來的…

【KWDB 創作者計劃】_上位機知識篇---SDK

文章目錄 前言一、SDK的核心組成API(應用程序接口)庫文件(Libraries)開發工具文檔與示例依賴項與環境配置二、SDK的作用簡化開發流程確保兼容性與穩定性加速產品迭代功能擴展與定制三、SDK的典型應用場景硬件設備開發操作系統與平臺云服務與API集成游戲與圖形開發四、SDK與…

golang處理時間的包time一次性全面了解

本文旨在對官方time包有個全面學習了解。不鉆摳細節&#xff0c;但又有全面了解&#xff0c;重點介紹常用的內容&#xff0c;一些低頻的可能這輩子可能都用不上。主打一個花最少時間辦最大事。 Duration對象: 兩個time實例經過的時間,以長度為int64的納秒來計數。 常見的durati…

PyCharm Flask 使用 Tailwind CSS 配置

使用 Tailwind CSS 步驟 1&#xff1a;初始化項目 在 PyCharm 終端運行&#xff1a;npm init -y安裝 Tailwind CSS&#xff1a;npm install -D tailwindcss postcss autoprefixer初始化 Tailwind 配置文件&#xff1a;npx tailwindcss init這會生成 tailwind.config.js。 步…

【英語語法】基本句型

目錄 前言一&#xff1a;主謂二&#xff1a;主謂賓三&#xff1a;主系表四&#xff1a;主謂雙賓五&#xff1a;主謂賓補 前言 英語基本句型是語法體系的基石&#xff0c;以下是英語五大基本句型。 一&#xff1a;主謂 結構&#xff1a;主語 不及物動詞 例句&#xff1a; T…

隔離DCDC輔助電源解決方案與產品應用科普

**“隔離”與“非隔離的區別** 隔離&#xff1a; 1、AC-DC&#xff0c;也叫“一次電源”&#xff0c;人可能會碰到的應用場合&#xff0c;起安全保護作用&#xff1b; 2、為了抗干擾&#xff0c;通過隔離能有效隔絕干擾信號傳輸。 非隔離&#xff1a; 1、“安全特低電壓&#…

DS-SLAM 運動一致性檢測的源碼解讀

運動一致性檢測是Frame.cc的Frame::ProcessMovingObject(const cv::Mat &imgray)函數。 對應DS-SLAM流程圖Moving consistency check的部分 把這個函數單獨摘出來&#xff0c;寫了一下對兩幀檢測&#xff0c;查看效果的程序&#xff1a; #include <opencv2/opencv.hpp…

安全測試的全面知識體系及實現路徑

以下是安全測試的全面知識體系及實現路徑,結合最新工具和技術趨勢(截至2025年): 一、安全測試核心類型與工具 1. 靜態應用安全測試(SAST) 知識點: 通過分析源代碼、字節碼或二進制文件識別漏洞(如SQL注入、緩沖區溢出)支持早期漏洞發現,減少修復成本,適合白盒測試場…

GPT-4o Image Generation Capabilities: An Empirical Study

GPT-4o 圖像生成能力:一項實證研究 目錄 介紹研究背景方法論文本到圖像生成圖像到圖像轉換圖像到 3D 能力主要優勢局限性與挑戰對比性能影響與未來方向結論介紹 近年來,圖像生成領域發生了巨大的變化,從生成對抗網絡 (GAN) 發展到擴散模型,再到可以處理多種模態的統一生成架…

Redis之全局唯一ID

全局ID生成器 文章目錄 全局ID生成器一、全局ID生成器的定義定義核心作用 二、全局ID生成器需滿足的特征1. 唯一性&#xff08;Uniqueness&#xff09;?2. 高性能&#xff08;High Performance&#xff09;?3. 可擴展性&#xff08;Scalability&#xff09;?4. 有序性&#…

nginx中的代理緩存

1.緩存存放路徑 對key取哈希值之后&#xff0c;設置cache內容&#xff0c;然后得到的哈希值的倒數第一位作為第一個子目錄&#xff0c;倒數第三位和倒數第二位組成的字符串作為第二個子目錄&#xff0c;如圖。 proxy_cache_path /xxxx/ levels1:2 2.文件名哈希值