卷積神經網絡實戰:MNIST手寫數字識別

夜漸深,我還在😘

老地方

睡覺了🙌

文章目錄

  • 📚 卷積神經網絡實戰:MNIST手寫數字識別
    • 🧠 4.1 預備知識
      • ?? 4.1.1 `torch.nn.Conv2d()` 三維卷積操作
      • 📏 4.1.2 `nn.MaxPool2d()` 池化層的作用
    • 📥 4.2 數據輸入與處理
      • 🗃? MNIST數據集加載
      • 🔍 數據格式驗證
    • 🚀 4.3 卷積模型構建與訓練
      • 🧩 4.3.1 網絡架構設計
      • ? 4.3.2 GPU加速與模型初始化
      • 📉 4.3.3 訓練與評估函數
      • 🔁 4.3.4 模型訓練循環
    • 🧪 4.4 函數式API
      • 🔌 4.4.1導入函數式模塊
      • ? 4.4.2激活函數應用
      • 🧮 4.4.3池化操作實現

📚 卷積神經網絡實戰:MNIST手寫數字識別

🧠 4.1 預備知識

?? 4.1.1 torch.nn.Conv2d() 三維卷積操作

torch.nn.Conv2d()是PyTorch中實現三維卷積的核心方法,其關鍵參數包括:

  • in_channels:輸入通道數(彩色圖為3,灰度圖為1)
  • out_channels:輸出通道數(卷積核數量)
  • kernel_size:卷積核尺寸(如3×3)
  • stride:步長(默認為1)
  • padding:填充(默認為0)
import torch
from torch import nn# 創建隨機輸入數據 (batch_size=20, 通道=3, 高=256, 寬=356)
input = torch.randn(20, 3, 256, 256) # 定義卷積層:輸入通道3→輸出通道16,3×3卷積核,步長1,填充1
conv_layer = nn.Conv2d(3, 16, (3, 3), stride=1, padding=1)# 執行卷積操作
output = conv_layer(input)
output.shape  # torch.Size([20, 16, 256, 256])

💡 輸出解析:經過卷積后,特征圖尺寸保持256×256不變(因padding=1),通道數從3增加到16

📏 4.1.2 nn.MaxPool2d() 池化層的作用

池化層的重要性

  1. 🎯 增大感受野:小卷積核視野有限,池化間接擴大覆蓋區域
  2. 🛡? 降低過擬合:減少參數量,增強模型泛化能力
  3. ? 加速計算:縮減特征圖尺寸,減少后續計算量

核心參數kernel_size(池化窗口尺寸)

# 創建隨機圖像批次 (64張256×256的RGB圖像)
img_batch = torch.randn(64, 3, 256, 256)# 2×2最大池化操作
pool_out = torch.max_pool2d(img_batch, kernel_size=(2, 2))
pool_out.shape  # torch.Size([64, 3, 128, 128])

💡 輸出解析:池化后圖像尺寸減半(256→128),通道數不變,實現特征降維


📥 4.2 數據輸入與處理

🗃? MNIST數據集加載

使用PyTorch內置工具加載手寫數字數據集:

import torchvision
from torchvision.transforms import ToTensor# 下載并加載訓練集/測試集
train_ds = torchvision.datasets.MNIST("data/", train=True, transform=ToTensor(), download=True
)
test_ds = torchvision.datasets.MNIST("data/", train=False, transform=ToTensor(), download=True
)# 創建數據加載器 (batch_size=64)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=64)

🔍 數據格式驗證

imgs, labels = next(iter(train_dl))
print(imgs.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])

? 數據格式:符合卷積網絡輸入要求(batch_size, 通道, 高, 寬)


🚀 4.3 卷積模型構建與訓練

🧩 4.3.1 網絡架構設計

LeNet風格CNN模型

class Model(nn.Module):def __init__(self):super(Model, self).__init__()# 卷積層1:1→6通道,5×5卷積核self.conv1 = nn.Conv2d(1, 6, 5)  # 卷積層2:6→16通道,5×5卷積核self.conv2 = nn.Conv2d(6, 16, 5)  # 全連接層1:256→256節點self.linear1 = nn.Linear(16*4*4, 256)  # 輸出層:256→10節點 (10個數字類別)self.linear2 = nn.Linear(256, 10)  def forward(self, x):# 卷積→ReLU→池化 (28×28 → 12×12)x = torch.max_pool2d(torch.relu(self.conv1(x)), (2, 2))  # 卷積→ReLU→池化 (12×12 → 4×4)x = torch.max_pool2d(torch.relu(self.conv2(x)), (2, 2))  # 展平特征圖x = x.view(-1, 16*4*4)  # 全連接層→ReLUx = torch.relu(self.linear1(x))  # 輸出層return self.linear2(x)  

? 4.3.2 GPU加速與模型初始化

# 自動檢測GPU加速
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)
model
Model((conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(linear1): Linear(in_features=256, out_features=256, bias=True)(linear2): Linear(in_features=256, out_features=10, bias=True)
)

📉 4.3.3 訓練與評估函數

# 訓練函數
def train(dataloader, model, loss_fn, optimizer):model.train()total_samples = len(dataloader.dataset)total_batches = len(dataloader)train_loss, correct = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)# 前向傳播pred = model(X)loss = loss_fn(pred, y)# 反向傳播optimizer.zero_grad()loss.backward()optimizer.step()# 統計指標with torch.no_grad():correct += (pred.argmax(1) == y).sum().item()train_loss += loss.item()return train_loss/total_batches, correct/total_samples# 測試函數
def test(dataloader, model):model.eval()total_samples = len(dataloader.dataset)total_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).sum().item()return test_loss/total_batches, correct/total_samples

🔁 4.3.4 模型訓練循環

# 超參數設置
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
loss_fn = nn.CrossEntropyLoss()
epochs = 20# 訓練日志
for epoch in range(epochs):train_loss, train_acc = train(train_dl, model, loss_fn, optimizer)test_loss, test_acc = test(test_dl, model)# 打印訓練進度print(f"epoch:{epoch:2d}, train_loss:{train_loss:.5f}, "f"train_acc:{train_acc*100:.1f}%, test_loss:{test_loss:.5f}, "f"test_acc:{test_acc*100:.1f}%")

訓練輸出

epoch: 0, train_loss:0.24543, train_acc:92.8%, test_loss:0.07341, test_acc:97.7%
epoch: 1, train_loss:0.06720, train_acc:97.9%, test_loss:0.04788, test_acc:98.4%
...
epoch:19, train_loss:0.00509, train_acc:99.8%, test_loss:0.04585, test_acc:99.2%
Done

🎯 性能總結:模型在20個epoch內達到**99.2%**的測試準確率,顯著優于全連接網絡


🧪 4.4 函數式API

🔌 4.4.1導入函數式模塊

import torch.nn.functional as F # 行業標準導入方式

? 4.4.2激活函數應用

# 傳統方式
output = torch.relu(input)# 函數式API方式
output = F.relu(input)

🧮 4.4.3池化操作實現

# 傳統方式
pooled = torch.max_pool2d(input, kernel_size=2)# 函數式API方式
pooled = F.max_pool2d(input, kernel_size=2)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/88862.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/88862.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/88862.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

HarmonyOS應用無響應(AppFreeze)深度解析:從檢測原理到問題定位

HarmonyOS應用無響應(AppFreeze)深度解析:從檢測原理到問題定位 在日常應用使用中,我們常會遇到點擊無反應、界面卡頓甚至完全卡死的情況——這些都可能是應用無響應(AppFreeze) 導致的。對于開發者而言&am…

湖北設立100億元人形機器人產業投資母基金

湖北設立100億元人形機器人產業投資母基金 湖北工信 2025年07月08日 12:03 湖北 ,時長01:20 近日,湖北設立100億元人形機器人產業投資母基金,重點支持人形機器人和人工智能相關產業發展。 人形機器人產業投資母基金由湖北省財政廳依托省政府…

時序預測 | Pytorch實現CNN-LSTM-KAN電力負荷時間序列預測模型

預測效果 代碼主要功能 該代碼實現了一個結合CNN(卷積神經網絡)、LSTM(長短期記憶網絡)和KAN(Kolmogorov-Arnold Network)的混合模型,用于時間序列預測任務。主要流程包括: 數據加…

OCR 識別:車牌識別相機的 “火眼金睛”

車牌識別相機在交通管理、停車場收費等場景中,需快速準確識別車牌信息。但實際環境中,車牌可能存在污漬、磨損、光照不均等情況,傳統識別方式易出現誤讀、漏讀。OCR 技術讓車牌識別相機如虎添翼。它能精準提取車牌上的字符,不管是…

Java面試基礎:面向對象(2)

1. 接口里可以定義哪些方法抽象方法:抽象方法是接口的核心部分,所有實現接口的類都必須實現這些方法。抽象方法默認是 public 和 abstract 修飾,這些修飾符可以省略。public interface Animal {void Sound(); }默認方法:默認方法是…

有哪些更加簡潔的for循環?循環語句?

目錄 簡潔的for循環 循環過程修改循環變量 循環語句 不同編程語言支持的循環語句 foreach 無限循環 for循環歷史 break和continue 循環判斷結束值 循環標簽 循環語句優化 循環表達式返回值 簡潔的for循環 如果需要快速枚舉一個集合的元素,盡管C語言可以…

RK3568/3588 Android 12 源碼默認使用藍牙mic錄音

遇到客戶一個需求,如果連接了帶mic的藍牙耳機,默認所有的錄音要走藍牙mic通道。這個功能搞了好久,終于搞定了。1. 向RK尋求幫助,先打通 bt sco能力。此時,還無法默認就切換到藍牙 mic通道,接下來我們需求默…

解鎖HTTP:從理論到實戰的奇妙之旅

目錄一、HTTP 協議基礎入門1.1 HTTP 協議是什么1.2 HTTP 協議的特點1.3 HTTP 請求與響應的結構二、HTTP 應用場景大揭秘2.1 網頁瀏覽2.2 API 調用2.3 文件傳輸2.4 內容分發網絡(CDN)2.5 流媒體服務三、HTTP 應用實例深度剖析3.1 使用 JavaScript 的 fetc…

uvm_config_db examples

通過uvm_config_db類訪問的UVM配置數據庫,是在多個測試平臺組件之間傳遞不同對象的絕佳方式。 methods 有兩個主要函數用于從數據庫中放入和檢索項目,分別是 set() 和 get()。 static function void set ( uvm_component cntxt,string inst_name,string …

(C++)任務管理系統(文件存儲)(正式版)(迭代器)(list列表基礎教程)(STL基礎知識)

目錄 前言: 源代碼: 代碼解析: 一.頭文件和命名空間 1. #include - 輸入輸出功能2. #include - 鏈表容器3. #include - 字符串處理4. using namespace std; - 命名空間 可視化比喻:建造房子 🏠 二.menu()函數 …

Java 中的異步編程詳解

前言 在現代軟件開發中,異步編程(Asynchronous Programming) 已經成為構建高性能、高并發應用程序的關鍵技術之一。Java 作為一門廣泛應用于后端服務開發的語言,在其發展過程中不斷引入和優化異步編程的支持。從最初的 Thread 和…

MySQL邏輯刪除與唯一索引沖突解決

問題背景 在MySQL數據庫設計中,邏輯刪除(軟刪除)是一種常見的實踐,它通過設置標志位(如is_delete)來標記記錄被"刪除",而不是實際刪除數據。然而,當表中存在唯一約束時&am…

php命名空間用正斜杠還是反斜杠?

在PHP中,命名空間使用反斜杠(\)作為分隔符,這是PHP語言規范明確規定的。反斜杠在命名空間中扮演路徑分隔的角色,用于區分不同層級的命名空間。 具體說明:語法規則 PHP命名空間使用反斜杠(\&…

《從依賴糾纏到接口協作:ASP.NET Core注入式開發指南》

在C#的ASP.NET Core開發中,依賴注入絕非簡單的技術技巧,而是重構代碼關系的底層邏輯。它像一套隱形的神經網絡,讓程序模塊擺脫硬編碼的束縛,在運行時實現動態連接,從而為系統注入可測試、可進化的核心生命力。理解其深…

星云ERP本地環境搭建筆記

看到星云ERP兩個比較實用的功能,編號規則和打印模板,如下圖所示,于是本地跑起來學習學習。開發環境必備:1. JDK 1.82. MySQL 5.73. Redis 44. RabbitMQ 3.12.45. nodejs 206. pnpm 9.7.1 (npm install -g pnpm9.7.1)其他開發工具&…

RedisJSON 的 `JSON.ARRAPPEND`一行命令讓數組動態生長

1 、 為什么選擇 JSON.ARRAPPEND 在傳統的鍵值模型里,若要往數組尾部追加元素,通常需要 取→改→寫 三步: GET 整個 JSON;在應用層把元素 push 進數組;SET 回 Redis。 一條 JSON.ARRAPPEND 則可一次完成,具…

14:00開始面試,14:08就出來了,問的問題有點變態。。。

從小廠出來,沒想到在另一家公司又寄了。 到這家公司開始上班,加班是每天必不可少的,看在錢給的比較多的份上,就不太計較了。沒想到4月一紙通知,所有人不準加班,加班費不僅沒有了,薪資還要降40%…

Unity物理系統由淺入深第四節:物理約束求解與穩定性

Unity物理系統由淺入深第一節:Unity 物理系統基礎與應用 Unity物理系統由淺入深第二節:物理系統高級特性與優化 Unity物理系統由淺入深第三節:物理引擎底層原理剖析 Unity物理系統由淺入深第四節:物理約束求解與穩定性 物理引擎的…

深入淺出Kafka Consumer源碼解析:設計哲學與實現藝術

一、Kafka Consumer全景架構 1.1 核心組件交互圖 #mermaid-svg-JDEEOd2M5PzLkYa6 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-JDEEOd2M5PzLkYa6 .error-icon{fill:#552222;}#mermaid-svg-JDEEOd2M5PzLkYa6 .erro…

Matplotlib(一)- 數據可視化與Matplotlib

文章目錄一、數據可視化1. 數據可視化的概念2. 數據可視化流程3. 數據可視化目的4. 常見的可視化圖表4.1 折線圖4.2 柱形圖4.3 條形圖4.4 堆積圖4.4.1 堆積面積圖4.4.2 堆積柱形圖和堆積條形圖4.5 直方圖4.6 箱形圖4.7 餅圖4.8 散點圖4.9 氣泡圖4.10 誤差棒圖4.11 雷達圖二、Py…