使用 PyTorch 實現標準卷積神經網絡(CNN)

卷積神經網絡(CNN)是深度學習中的重要組成部分,廣泛應用于圖像處理、語音識別、視頻分析等任務。在這篇博客中,我們將使用 PyTorch 實現一個標準的卷積神經網絡(CNN),并介紹各個部分的作用。

什么是卷積神經網絡(CNN)?

卷積神經網絡(CNN)是一種專門用于處理圖像數據的深度學習模型,它通過卷積層提取圖像的特征。CNN 由多個層次組成,其中包括卷積層(Conv2d)、池化層(MaxPool2d)、全連接層(Linear)、激活函數(ReLU)等。這些層級合作,使得模型能夠從原始圖像中自動學習到重要特征。

CNN 的核心組成部分

  1. 卷積層(Conv2d):用于提取輸入圖像的局部特征,通過多個卷積核對圖像進行卷積運算。
  2. 激活函數(ReLU):增加非線性,使得模型能夠學習更復雜的特征。
  3. 池化層(MaxPool2d):通過對特征圖進行下采樣來減少空間尺寸,降低計算復雜度,同時保留重要的特征。
  4. 全連接層(Linear):將卷積和池化后得到的特征圖展平,送入全連接層進行分類或回歸預測。

PyTorch 實現 CNN

下面是我們實現的標準卷積神經網絡模型。它包含三個卷積層和兩個全連接層,適用于圖像分類任務,如 MNIST 數據集。

代碼實現

import torch
import torch.nn as nn
import torch.nn.functional as Fclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷積層1: 輸入1個通道(灰度圖像),輸出32個通道self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)# 卷積層2: 輸入32個通道,輸出64個通道self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)# 卷積層3: 輸入64個通道,輸出128個通道self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)# 全連接層1: 輸入128*7*7,輸出1024個節點self.fc1 = nn.Linear(128 * 7 * 7, 1024)# 全連接層2: 輸入1024個節點,輸出10個節點(假設是10分類問題)self.fc2 = nn.Linear(1024, 10)# Dropout層: 避免過擬合self.dropout = nn.Dropout(0.5)def forward(self, x):# 第一層卷積 + ReLU 激活 + 最大池化x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)  # 使用2x2的最大池化# 第二層卷積 + ReLU 激活 + 最大池化x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)# 第三層卷積 + ReLU 激活 + 最大池化x = F.relu(self.conv3(x))x = F.max_pool2d(x, 2, 2)# 展平層(將卷積后的特征圖展平成1D向量)x = x.view(-1, 128 * 7 * 7)  # -1代表自動推算batch size# 第一個全連接層 + ReLU 激活 + Dropoutx = F.relu(self.fc1(x))x = self.dropout(x)# 第二個全連接層(輸出最終分類結果)x = self.fc2(x)return x# 創建CNN模型
model = CNN()# 打印模型架構
print(model)

代碼解析

  1. 卷積層(Conv2d)

    • self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1):該層的輸入為 1 個通道(灰度圖像),輸出 32 個通道,卷積核大小為 3x3,步幅為 1,填充為 1,保持輸出特征圖的大小與輸入相同。
    • 后續的卷積層類似,只是輸出通道數量逐漸增多。
  2. 激活函數(ReLU)

    • F.relu(self.conv1(x)):ReLU 激活函數將輸入的負值轉為 0,并保留正值,增加了模型的非線性。
  3. 池化層(MaxPool2d)

    • F.max_pool2d(x, 2, 2):使用 2x2 的池化窗口和步幅為 2 進行池化,將特征圖尺寸縮小一半,減少計算復雜度。
  4. 展平(Flatten)

    • x = x.view(-1, 128 * 7 * 7):在經過卷積和池化操作后,我們將多維的特征圖展平成一維向量,供全連接層輸入。
  5. Dropout

    • self.dropout = nn.Dropout(0.5):Dropout 正則化技術在訓練時隨機丟棄一些神經元,防止過擬合。
  6. 全連接層(Linear)

    • self.fc1 = nn.Linear(128 * 7 * 7, 1024):第一個全連接層的輸入是卷積后得到的特征,輸出 1024 個節點。
    • self.fc2 = nn.Linear(1024, 10):最后的全連接層將 1024 個節點壓縮為 10 個輸出,代表分類結果。

訓練 CNN 模型

要訓練該模型,我們需要加載一個數據集、定義損失函數和優化器,然后進行訓練。以下是如何使用 MNIST 數據集進行訓練的示例。

import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 加載 MNIST 數據集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練循環
num_epochs = 5
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 100 == 99:  # 每100個batch輸出一次損失print(f'Epoch {epoch+1}, Batch {batch_idx+1}, Loss: {running_loss / 100:.4f}')running_loss = 0.0print("Finished Training")

訓練過程說明

  • 數據加載器(DataLoader):用于批量加載訓練數據,支持數據的隨機打亂(shuffle)。
  • 損失函數(CrossEntropyLoss):用于多分類問題,計算預測和真實標簽之間的交叉熵損失。
  • 優化器(Adam):Adam 優化器自適應調整學習率,通常在深度學習中表現良好。
  • 訓練循環:每個 epoch 處理整個數據集,通過前向傳播、計算損失、反向傳播和優化步驟,更新網絡參數。

總結

在這篇文章中,我們實現了一個標準的卷積神經網絡(CNN),并使用 PyTorch 對其進行了定義和訓練。通過使用卷積層、池化層和全連接層,模型能夠自動學習圖像的特征并進行分類。我們還介紹了如何訓練模型、加載數據集以及使用常見的優化器和損失函數。希望這篇文章能幫助你理解 CNN 的基本架構及其實現方式!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/70209.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/70209.shtml
英文地址,請注明出處:http://en.pswp.cn/web/70209.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SpringBoot2.0整合Redis(Lettuce版本)

前言: 目前java操作redis的客戶端有jedis跟Lettuce。在springboot1.x系列中,其中使用的是jedis, 但是到了springboot2.x其中使用的是Lettuce。 因為我們的版本是springboot2.x系列,所以今天使用的是Lettuce。關于jedis跟lettuce的區別&#…

qt + opengl 給立方體增加陰影

在前幾篇文章里面學會了通過opengl實現一個立方體,那么這篇我們來學習光照。 風氏光照模型的主要結構由3個分量組成:環境(Ambient)、漫反射(Diffuse)和鏡面(Specular)光照。下面這張圖展示了這些光照分量看起來的樣子: 1 環境光照(Ambient …

大模型工具大比拼:SGLang、Ollama、VLLM、LLaMA.cpp 如何選擇?

簡介:在人工智能飛速發展的今天,大模型已經成為推動技術革新的核心力量。無論是智能客服、內容創作,還是科研輔助、代碼生成,大模型的身影無處不在。然而,面對市場上琳瑯滿目的工具,如何挑選最適合自己的那…

stream流常用方法

1.reduce 在Java中,可以使用Stream API的reduce方法來計算一個整數列表的乘積。reduce方法是一種累積操作,它可以將流中的元素組合起來,返回單個結果。對于計算乘積,你需要提供一個初始值(通常是1,因為乘法…

pgAdmin4在mac m1上面簡單使用(Docker)

問題 想要在本地簡單了解一下pgAdmin4一些簡單功能。故需要在本機先安裝看一看。 安裝步驟 拉取docker鏡像 docker pull dpage/pgadmin4直接簡單運行pgAdmin4 docker run --name pgAdmin4 -p 5050:80 \-e "PGADMIN_DEFAULT_EMAILuserdomain.com" \-e "PGAD…

ubuntu下安裝TFTP服務器

在 Ubuntu 系統下安裝和配置 TFTP(Trivial File Transfer Protocol)服務器可以按照以下步驟進行: 1. 安裝 TFTP 服務器軟件包 TFTP 服務器通常使用 tftpd-hpa 軟件包,你可以使用以下命令進行安裝: sudo apt update …

Softing線上研討會 | 自研還是購買——用于自動化產品的工業以太網

| 線上研討會時間:2025年1月27日 16:00~16:30 / 23:00~23:30 基于以太網的通信在工業自動化網絡中的重要性日益增加。設備制造商正面臨著一大挑戰——如何快速、有效且經濟地將工業以太網協議集成到其產品中。其中的關鍵問題包括:是否只需集成單一的工…

vscode創建java web項目

一.項目部署 1.shiftctrlp,選擇java項目 2.選擇maven create from arcetype 3.選擇webapp 4.目錄結構如下,其中index.jsp是首頁 5.找到左下角的servers,添加tomcat服務器 選擇 再選擇: 找到你下載的tomcat 的bin目錄的上一級目錄&#x…

C語言指針學習筆記

1. 指針的定義 指針(Pointer)是存儲變量地址的變量。在C語言中,指針是一種非常重要的數據類型,通過指針可以直接訪問和操作內存。 2. 指針的聲明與初始化 2.1 指針聲明 指針變量的聲明格式為:數據類型 *指針變量名…

DeepSeek R1生成圖片總結2(雖然本身是不能直接生成圖片,但是可以想辦法利用別的工具一起實現)

DeepSeek官網 目前階段,DeepSeek R1是不能直接生成圖片的,但可以通過優化文本后轉換為SVG或HTML代碼,再保存為圖片。另外,Janus-Pro是DeepSeek的多模態模型,支持文生圖,但需要本地部署或者使用第三方工具。…

什么是Dubbo?Dubbo框架知識點,面試題總結

本篇包含什么是Dubbo,Dubbo的實現原理,節點角色說明,調用關系說明,在實際開發的場景中應該如何選擇RPC框架,Dubbo的核心架構,Dubbo的整體架構設計及分層。 主頁還有其他的面試資料,有需要的可以…

kafka消費能力壓測:使用官方工具

背景 在之前的業務場景中,我們發現Kafka的實際消費能力遠低于預期。盡管我們使用了kafka-go組件并進行了相關測試,測試情況見《kafka-go:性能測試》這篇文章。但并未能準確找出消費能力低下的原因。 我們曾懷疑這可能是由我的電腦網絡帶寬問題或Kafka部…

【大學生職業規劃大賽備賽PPT資料PDF | 免費共享】

自取鏈接: 鏈接:https://pan.quark.cn/s/4fa45515325e 📢 同學,你是不是正在為職業規劃大賽發愁? 想展示獨特思路卻不知如何下手? 想用專業模板卻找不到資源? 別擔心!我整理了全網…

ubuntu20動態修改ip,springboot中yaml的內容的讀取,修改,寫入

文章目錄 前言引入包yaml原始內容操作目標具體代碼執行查看結果總結: 前言 之前有個需求,動態修改ubuntu20的ip,看了下: 本質上是修改01-netcfg.yaml文件,然后執行netplan apply就可以了。 所以,需求就變成了 如何對ya…

【算法】雙指針(下)

目錄 查找總價格為目標值的兩個商品 暴力解題 雙指針解題 三數之和 雙指針解題(左右指針) 四數之和 雙指針解題 雙指針關鍵點 注意事項 查找總價格為目標值的兩個商品 題目鏈接:LCR 179. 查找總價格為目標值的兩個商品 - 力扣(LeetCode&#x…

Windows 圖形顯示驅動開發-IoMmu 模型

輸入輸出內存管理單元 (IOMMU) 是一個硬件組件,它將支持具有 DMA 功能的 I/O 總線連接到系統內存。 它將設備可見的虛擬地址映射到物理地址,使其在虛擬化中很有用。 在 WDDM 2.0 IoMmu 模型中,每個進程都有一個虛擬地址空間,即&a…

軟件測評報告包括哪些內容?第三方軟件測評機構推薦

在當今信息技術飛速發展的時代,軟件的品質與性能直接影響到企業的運營效率和市場競爭力。為了確保軟件的可用性和可靠性,軟件測評成為一個不可或缺的環節,軟件測評報告也是對軟件產品進行全面評估后形成的一份文檔,旨在系統地紀錄…

深淺拷貝區別,怎么區別使用

在 JavaScript 中,深拷貝(Deep Copy) 和 淺拷貝(Shallow Copy) 是兩種不同的對象復制方式,它們的區別主要體現在對嵌套對象的處理上。以下是它們的詳細對比及使用場景: 1. 淺拷貝(Sh…

tailscale + derp中繼 + 阿里云服務器 (無域名版)

使用tailscale默認的中轉節點延遲很高,因為服務器都在國外。 感謝大佬提供的方案:Tailscale 搭建derp中繼節點,不需要域名,不需要備案,不需要申請證書(最新) - yafeng - 博客園 基于這個方案&…

【異常錯誤】pycharm debug view變量的時候顯示不全,中間會以...顯示

異常問題: 這個是在新版的pycharm中出現的,出現的問題,點擊view后不全部顯示,而是以...折疊顯示 在setting中這么設置一下就好了: 解決辦法: https://youtrack.jetbrains.com/issue/PY-75568/Large-stri…