PyTorch 系列教程:使用CNN實現圖像分類

圖像分類是計算機視覺領域的一項基本任務,也是深度學習技術的一個常見應用。近年來,卷積神經網絡(cnn)和PyTorch庫的結合由于其易用性和魯棒性已經成為執行圖像分類的流行選擇。

理解卷積神經網絡(cnn)

卷積神經網絡是一類深度神經網絡,對分析視覺圖像特別有效。他們利用多層構建一個可以直接從圖像中識別模式的模型。這些模型對于圖像識別和分類等任務特別有用,因為它們不需要手動提取特征。

cnn的關鍵組成部分

  • 卷積層:這些層對輸入應用卷積操作,將結果傳遞給下一層。每個過濾器(或核)可以捕獲不同的特征,如邊緣、角或其他模式。
  • 池化層:這些層減少了表示的空間大小,以減少參數的數量并加快計算速度。池化層簡化了后續層的處理。
  • 完全連接層:在這些層中,神經元與前一層的所有激活具有完全連接,就像傳統的神經網絡一樣。它們有助于對前一層識別的對象進行分類。
    在這里插入圖片描述

使用PyTorch進行圖像分類

PyTorch是開源的深度學習庫,提供了極大的靈活性和多功能性。研究人員和從業人員廣泛使用它來輕松有效地實現尖端的機器學習模型。

設置PyTorch

首先,確保在開發環境中安裝了PyTorch。你可以通過pip安裝它:

pip install torch torchvision

用PyTorch創建簡單的CNN示例

下面是如何定義簡單的CNN來使用PyTorch對圖像進行分類的示例。

import torch
import torch.nn as nn
import torch.nn.functional as F# 定義CNN模型(修復了變量引用問題)
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)      # 第一個卷積層:3輸入通道,6輸出通道,5x5卷積核self.pool = nn.MaxPool2d(2, 2)        # 最大池化層:2x2窗口,步長2self.conv2 = nn.Conv2d(6, 16, 5)     # 第二個卷積層:6輸入通道,16輸出通道,5x5卷積核self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全連接層1:400輸入 -> 120輸出self.fc2 = nn.Linear(120, 84)      # 全連接層2:120輸入 -> 84輸出self.fc3 = nn.Linear(84, 10)       # 輸出層:84輸入 -> 10類 logitsdef forward(self, x):# 輸入形狀:[batch_size, 3, 32, 32]x = self.pool(F.relu(self.conv1(x)))  # -> [batch, 6, 14, 14](池化后尺寸減半)x = self.pool(F.relu(self.conv2(x)))  # -> [batch, 16, 5, 5] x = x.view(-1, 16 * 5 * 5)            # 展平為一維向量:16 * 5 * 5=400x = F.relu(self.fc1(x))             # -> [batch, 120]x = F.relu(self.fc2(x))             # -> [batch, 84]x = self.fc3(x)                     # -> [batch, 10](未應用softmax,配合CrossEntropyLoss使用)return x

這個特殊的網絡接受一個輸入圖像,通過兩組卷積和池化層,然后是三個完全連接的層。根據數據集的復雜性和大小調整網絡的架構和超參數。

模型定義

  • SimpleCNN 繼承自 nn.Module
  • 使用兩個卷積層提取特征,三個全連接層進行分類
  • 最終輸出未應用 softmax,而是直接輸出 logits(與 CrossEntropyLoss 配合使用)

訓練網絡

對于訓練,你需要一個數據集。PyTorch通過torchvision包提供了用于數據加載和預處理的實用程序。

import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader# 初始化模型、損失函數和優化器
net = SimpleCNN()               # 實例化模型
criterion = nn.CrossEntropyLoss()  # 使用交叉熵損失函數(自動處理softmax)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001,      # 學習率momentum=0.9)   # 動量參數# 數據預處理和加載
transform = transforms.Compose([transforms.ToTensor(),          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 加載CIFAR-10訓練集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True,  # 自動下載數據集transform=transform
)trainloader = DataLoader(trainset, batch_size=4,   # 每個batch包含4張圖像shuffle=True)  # 打亂數據順序

模型配置

  • 損失函數CrossEntropyLoss(自動包含 softmax 和 log_softmax)
  • 優化器:SGD with momentum,學習率 0.001

數據加載

  • 使用 torchvision.datasets.CIFAR10 加載數據集

  • batch_size:4(根據 GPU 內存調整,CIFAR-10 建議 batch size ≥ 32)

  • transforms.Compose 定義數據預處理流程:

    • ToTensor():將圖像轉換為 PyTorch Tensor
    • Normalize():標準化圖像像素值到 [-1, 1]

加載數據后,訓練過程包括通過數據集進行多次迭代,使用反向傳播和合適的損失函數:

# 訓練循環
for epoch in range(2):  # 進行2個epoch的訓練running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data# 前向傳播outputs = net(inputs)loss = criterion(outputs, labels)# 反向傳播和優化optimizer.zero_grad()   # 清空梯度loss.backward()         # 計算梯度optimizer.step()       # 更新參數running_loss += loss.item()# 每2000個batch打印一次if i % 2000 == 1999:avg_loss = running_loss / 2000print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')running_loss = 0.0print("訓練完成!")

訓練循環

  • epoch:完整遍歷數據集一次
  • batch:數據加載器中的一個批次
  • 梯度清零:每次反向傳播前需要清空梯度
  • 損失計算outputs 的形狀為 [batch_size, 10]labels 為整數標簽

完整代碼

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader# 定義CNN模型(修復了變量引用問題)
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)      # 第一個卷積層:3輸入通道,6輸出通道,5x5卷積核self.pool = nn.MaxPool2d(2, 2)        # 最大池化層:2x2窗口,步長2self.conv2 = nn.Conv2d(6, 16, 5)     # 第二個卷積層:6輸入通道,16輸出通道,5x5卷積核self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全連接層1:400輸入 -> 120輸出self.fc2 = nn.Linear(120, 84)      # 全連接層2:120輸入 -> 84輸出self.fc3 = nn.Linear(84, 10)       # 輸出層:84輸入 -> 10類 logitsdef forward(self, x):# 輸入形狀:[batch_size, 3, 32, 32]x = self.pool(F.relu(self.conv1(x)))  # -> [batch, 6, 14, 14](池化后尺寸減半)x = self.pool(F.relu(self.conv2(x)))  # -> [batch, 16, 5, 5] x = x.view(-1, 16 * 5 * 5)            # 展平為一維向量:16 * 5 * 5=400x = F.relu(self.fc1(x))             # -> [batch, 120]x = F.relu(self.fc2(x))             # -> [batch, 84]x = self.fc3(x)                     # -> [batch, 10](未應用softmax,配合CrossEntropyLoss使用)return x# 初始化模型、損失函數和優化器
net = SimpleCNN()               # 實例化模型
criterion = nn.CrossEntropyLoss()  # 使用交叉熵損失函數(自動處理softmax)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001,      # 學習率momentum=0.9)   # 動量參數# 數據預處理和加載
transform = transforms.Compose([transforms.ToTensor(),            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  
])# 加載CIFAR-10訓練集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True,  # 自動下載數據集transform=transform
)
trainloader = DataLoader(trainset, batch_size=4,   # 每個batch包含4張圖像shuffle=True)  # 打亂數據順序# 訓練循環
for epoch in range(2):  # 進行2個epoch的訓練running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data# 前向傳播outputs = net(inputs)loss = criterion(outputs, labels)# 反向傳播和優化optimizer.zero_grad()   # 清空梯度loss.backward()         # 計算梯度optimizer.step()       # 更新參數running_loss += loss.item()# 每2000個batch打印一次if i % 2000 == 1999:avg_loss = running_loss / 2000print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')running_loss = 0.0print("訓練完成!")

最后總結

通過PyTorch和卷積神經網絡,你可以有效地處理圖像分類任務。借助PyTorch的靈活性,可以根據特定的數據集和應用程序構建、訓練和微調模型。示例代碼僅為理論過程,實際項目中還有大量優化空間。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/72386.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/72386.shtml
英文地址,請注明出處:http://en.pswp.cn/web/72386.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Spring Cloud Stream - 構建高可靠消息驅動與事件溯源架構

一、引言 在分布式系統中,傳統的 REST 調用模式往往導致耦合,難以滿足高并發和異步解耦的需求。消息驅動架構(EDA, Event-Driven Architecture)通過異步通信、事件溯源等模式,提高了系統的擴展性與可觀測性。 作為 S…

王者榮耀道具頁面爬蟲(json格式數據)

首先這個和英雄頁面是不一樣的,英雄頁面的圖片鏈接是直接放在源代碼里面的,直接就可以請求到,但是這個源代碼里面是沒有的 雖然在檢查頁面能夠搜索到,但是應該是動態加載的,源碼中搜不到該鏈接 然后就去看看是不是某…

【一起來學kubernetes】12、k8s中的Endpoint詳解

一、Endpoint的定義與作用二、Endpoint的創建與管理三、Endpoint的查看與組成四、EndpointSlice五、Endpoint的使用場景六、Endpoint與Service的關系1、定義與功能2、創建與管理3、關系與交互4、使用場景與特點 七、Endpoint的kubectl命令1. 查看Endpoint2. 創建Endpoint3. 編輯…

結構型模式之橋接模式:解耦抽象和實現

在面向對象設計中,我們經常遇到需要擴展某些功能,但又不能修改現有代碼的情況。為了避免繼承帶來的復雜性和維護難度,橋接模式(Bridge Pattern)應運而生。橋接模式是一種結構型設計模式,旨在解耦抽象部分和…

如何用Java將實體類轉換為JSON并輸出到控制臺?

在軟件開發的過程中,Java是一種廣泛使用的編程語言,而在眾多應用中,數據的傳輸和存儲經常需要使用JSON格式。JSON(JavaScript Object Notation)是一種輕量級的數據交換格式,易于人類閱讀和編寫,…

Vue3 開發的 VSCode 插件

1. Volar Vue3 正式版發布,Vue 團隊官方推薦 Volar 插件來代替 Vetur 插件,不僅支持 Vue3 語言高亮、語法檢測,還支持 TypeScript 和基于 vue-tsc 的類型檢查功能。 2. Vue VSCode Snippets 為開發者提供最簡單快速的生成 Vue 代碼片段的方…

C# Enumerable類 之 集合操作

總目錄 前言 在 C# 中,System.Linq.Enumerable 類是 LINQ(Language Integrated Query)的核心組成部分,它提供了一系列靜態方法,用于操作實現了 IEnumerable 接口的集合。通過這些方法,我們可以輕松地對集合…

51c自動駕駛~合集54

我自己的原文哦~ https://blog.51cto.com/whaosoft/13517811 #Chameleon 快慢雙系統!清華&博世最新:無需訓練即可解決復雜道路拓撲 在自動駕駛技術中,車道拓撲提取是實現無地圖導航的核心任務之一。它要求系統不僅能檢測出車道和交…

Spring Cloud Eureka - 高可用服務注冊與發現解決方案

在微服務架構中,服務注冊與發現是確保系統動態擴展和高效通信的關鍵。Eureka 作為 Spring Cloud 生態的核心組件,不僅提供去中心化的服務治理能力,還通過自我保護、健康檢查等機制提升系統的穩定性,使其成為微服務架構中的重要支撐…

Unity屏幕適配——立項時設置

項目類型:2D游戲、豎屏、URP 其他類型,部分原理類似。 1、確定設計分辨率:750*1334 為什么是它? 因為它是 iphone8 的尺寸,寬高比適中。 方便后續適配到真機的 “更長屏” 或 “更寬屏” 2、在場景…

深度學習中LayerNorm與RMSNorm對比

LayerNorm不同于BatchNorm,其與batch大小無關,均值和方差 在 每個樣本的特征維度 C 內計算, 適用于 變長輸入(如 NLP 任務中的 Transformer) 詳細的BatchNorm在之前的一篇文章進行了詳細的介紹:深度學習中B…

使用WireShark解密https流量

概述 https協議是在http協議的基礎上,使用TLS協議對http數據進行了加密,使得網絡通信更加安全。一般情況下,使用WireShark抓取的https流量,數據都是加密的,無法直接查看。但是可以通過以下兩種方法,解密抓…

數字化轉型 - 數據驅動

數字化轉型 一、 數據驅動1.1 監控1.2 分析1.3 挖掘1.4 賦能 二、數據驅動案例2.1 能源工業互聯網:綠色節能的數字化路徑2.2 光伏產業的數字化升級2.3 數據中心的綠色轉型2.4云遷移的質效優化2.5 企業數字化運營的實踐2.6數字化轉型的最佳實踐 一、 數據驅動 從數…

解決 Docker 鏡像拉取超時問題:配置國內鏡像源

在使用 Docker 的過程中,經常會遇到鏡像拉取超時的問題,尤其是在國內網絡環境下。這不僅會浪費大量的時間,還可能導致一些項目無法順利進行。今天,我將分享一個簡單而有效的解決方法:配置國內鏡像源。 環境 操作系統 c…

Linux命令基礎,創建,輸入,輸出,查看,查詢

什么是命令、命令行 命令行:即:Linux終端(Terminal),是一種命令提示符頁面。以純“字符”的形式操作操作系統,可以使用各種字符化命令對操作系統發出操作指令。 命令:即Linux程序。一個命令就…

【GNU Radio】ZMQ模塊學習

【GNU Radio】ZMQ模塊學習 ZMQ 介紹前置知識Socket通信模型PUB/SUB(發布/訂閱)模型PUSH/PULL(推/拉)模型REQ/REP(請求/響應)模型 ZMQ 詳解基于通信模型分析基于數據格式分析Data BlocksMessage Blocks ZMQ …

【筆記】深度學習模型訓練的 GPU 內存優化之旅:綜述篇

開設此專題,目的一是梳理文獻,目的二是分享知識。因為筆者讀研期間的研究方向是單卡上的顯存優化,所以最初思考的專題名稱是“顯存突圍:深度學習模型訓練的 GPU 內存優化之旅”,英文縮寫是 “MLSys_GPU_Memory_Opt”。…

Vue 3 Diff 算法深度解析:與 Vue 2 雙端比對對比

文章目錄 1. 核心算法概述1.1 Vue 2 雙端比對算法1.2 Vue 3 快速 Diff 算法 2. 算法復雜度分析2.1 時間復雜度對比2.2 空間復雜度對比 3. 核心實現解析3.1 Vue 2 雙端比對代碼3.2 Vue 3 快速 Diff 代碼 4. 性能優化分析4.1 性能測試數據4.2 內存使用對比 5. 使用場景分析5.1 Vu…

神經網絡的基本知識

感知機 輸入:來自其他 n 個神經元傳遞過來的輸入信號 處理:輸入信號通過帶權重的連接進行傳遞, 神經元接受到總輸入值將與神經元的閾值進行比較 輸出:通過激活函數的處理以得到輸出 感知機由兩層神經元組成, 輸入層接受外界輸入信號傳遞給…

UE5與U3D引擎對比分析

Unreal Engine 5(UE5)和Unity 3D(U3D)是兩款主流的游戲引擎,適用于不同類型的項目開發。以下是它們的主要區別,分點整理: 1. 核心定位 UE5: 主打3A級高畫質項目(如主機/P…