【深度學習】實驗四 卷積神經網絡CNN

?實驗四 ?卷積神經網絡CNN

一、實驗學時:?2學時

二、實驗目的

  1. 掌握卷積神經網絡CNN的基本結構;
  2. 掌握數據預處理、模型構建、訓練與調參;
  3. 探索CNN在MNIST數據集中的性能表現;

三、實驗內容

實現深度神經網絡CNN。

四、主要實驗步驟及結果

1.搭建一個CNN網絡,使用MNIST手寫數字數據集進行訓練與測試,并體現模型最終結果,CNN網絡的具體框架可參考下圖,也可自己設計:

圖4-1 CNN架構圖

(1)該圖表示輸入層為28*28*1的尺寸,符合MNIST數據集的標準尺寸。

(2)第一個卷積層,使用5*5卷積核,32個濾波器,填充(Padding)為2。輸出尺寸為28*28*32。

(3)第一個池化層,使用2*2池化窗口,步長(stride)為2。輸出尺寸為14*14*32。

(4)第二個卷積層,使用5*5卷積核,64個濾波器,填充(Padding)為2。輸出尺寸為14*14*64。

(5)第二個池化層,使用2*2池化窗口,步長(stride)為2。輸出尺寸為7*7*64。

(6)全連接層包含1024個神經元,輸出尺寸為1*1*1024。

(7)Dropout層用于防止過擬合。

(8)輸出層包含10個神經元,對應手寫數字的0-9。輸出尺寸為1*1*10。

模型實現:

以該架構圖搭建CNN網絡,使用MNIST手寫數字數據集進行訓練與測試,訓練和測試結果如圖4-2所示:

圖4-2?CNN測試結果

2.嘗試使用不同的數據增強方法、優化器、損失函數、學習率、batch size和迭代次數來進行訓練,記錄訓練過程,評估模型性能,保存最佳模型。

編號

batch size

訓練輪次

學習率

數據增強方法

優化器

實驗結果

1

32

2

1e-4

Adam

98.62%

2

64

2

1e-4

Adam

98.56%

3

64

4

1e-4

Adam

99.08%

4

64

4

3e-4

Adam

99.08%

5

64

4

3e-4

旋轉+平移

Adam

98.90%

5

64

4

3e-4

Adam(L2正則化)

99.23%

6

64

4

1e-4

SGD+momentum

97.30%

其中數據增強方法采用隨機旋轉和平移嗎,原始代碼中包含ToTensor()和Normalize(),給原始代碼添加隨機旋轉10度和隨機平移10%,代碼如下:

# 數據加載(歸一化)
transform = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(10),  # 隨機旋轉10度torchvision.transforms.RandomAffine(0, translate=(0.1, 0.1)),  # 隨機平移10%torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

優化器選擇方面使用SGD+momentum(0.9)替代原Adam優化器,

# 使用SGD+momentum
optimizer = torch.optim.SGD(model.parameters(), lr=LEARN_RATE, momentum=0.9)

根據訓練過程記錄的數據,最佳模型尊卻綠為99.23%,最佳模型代碼如下:

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoaderBATCH_SIZE = 64
EPOCHS = 4
LEARN_RATE = 3e-4
DROPOUT_RATE = 0.5device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# 數據加載(歸一化)
transform = torchvision.transforms.Compose([# torchvision.transforms.RandomRotation(10),  # 隨機旋轉10度# torchvision.transforms.RandomAffine(0, translate=(0.1, 0.1)),  # 隨機平移10%torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))
])train_data = torchvision.datasets.MNIST(root='./mnist',train=True,download=True,transform=transform
)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)test_data = torchvision.datasets.MNIST(root='./mnist',train=False,transform=transform
)
test_loader = DataLoader(test_data, batch_size=1000, shuffle=False)class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_layers = nn.Sequential(# 第一層卷積:5x5 卷積核,32 個過濾器,padding=2nn.Conv2d(1, 32, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),  # 池化后 14x14x32# 第二層卷積:5x5 卷積核,64 個過濾器,padding=2nn.Conv2d(32, 64, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2)  # 池化后 7x7x64)self.fc_layers = nn.Sequential(nn.Linear(64 * 7 * 7, 1024),  # 全連接層:7x7x64 → 1024nn.ReLU(),nn.Dropout(DROPOUT_RATE),  # Dropout層nn.Linear(1024, 10)  # 輸出層:1024 → 10)self._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')nn.init.constant_(m.bias, 0)def forward(self, x):x = self.conv_layers(x)x = x.view(x.size(0), -1)  # 展平操作x = self.fc_layers(x)return xmodel = CNN().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE, weight_decay=1e-5)
# optimizer = torch.optim.SGD(model.parameters(), lr=LEARN_RATE, momentum=0.9)  # 使用SGD+momentum
# 訓練循環
for epoch in range(EPOCHS):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = loss_fn(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch {epoch + 1} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.4f}')# 測試
model.eval()
correct = 0
with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)pred = output.argmax(dim=1)correct += pred.eq(target).sum().item()print(f'Test Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)')

3.使用畫圖工具將自己的學號逐個寫出,使用保存的最佳模型對每個數字進行推理,比較模型對每個數字的準確率預測,也可以嘗試實現一個實時識別手寫數字的demo。
(1)使用畫圖工具將自己的學號逐個寫出,進行反色處理,并將圖片命名為“x_001.png”格式。

圖4-3手寫數字

(2)在訓練代碼(CNN.py)中添加模型保存代碼。

torch.save(model.state_dict(), 'mnist_cnn.pth')

(3)編寫推理代碼讀取img文件夾中的手寫圖片并預測,預測代碼如下所示:

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import os# 定義模型結構(需與訓練代碼一致)
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(32, 64, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(2, 2))self.fc_layers = nn.Sequential(nn.Linear(64 * 7 * 7, 1024),nn.ReLU(),nn.Dropout(0.5),nn.Linear(1024, 10))def forward(self, x):x = self.conv_layers(x)x = x.view(x.size(0), -1)x = self.fc_layers(x)return x# 加載模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN().to(device)
model.load_state_dict(torch.load('mnist_cnn.pth', map_location=device))
model.eval()# 定義預處理(與訓練一致)
transform = transforms.Compose([transforms.Resize((28, 28)),  # 確保輸入為28x28transforms.Grayscale(num_output_channels=1),  # 轉換為單通道transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 遍歷img文件夾中的圖片并推理
img_dir = 'img'
digit_stats = {str(i): {'correct': 0, 'total': 0} for i in range(10)}for filename in os.listdir(img_dir):if filename.lower().endswith(('.png', '.jpg', '.jpeg')):# 從文件名中提取真實標簽(假設文件名為 "label_xxx.png")try:true_label = filename.split('_')[0]  # 例如文件名 "3_001.png" → 標簽為3true_label = int(true_label)if true_label < 0 or true_label > 9:continueexcept:print(f"跳過文件 {filename}(文件名格式錯誤)")continue# 加載并預處理圖像img_path = os.path.join(img_dir, filename)image = Image.open(img_path)image = transform(image).unsqueeze(0).to(device)  # 添加batch維度# 推理with torch.no_grad():output = model(image)pred = output.argmax(dim=1).item()# 統計結果digit_stats[str(true_label)]['total'] += 1if pred == true_label:digit_stats[str(true_label)]['correct'] += 1print(f"圖片 {filename} 真實標簽: {true_label}, 預測: {pred} → {'正確' if pred == true_label else '錯誤'}")# 計算每個數字的準確率
accuracies = {}
for digit in digit_stats:if digit_stats[digit]['total'] > 0:acc = digit_stats[digit]['correct'] / digit_stats[digit]['total']accuracies[digit] = accprint(f"數字 {digit} 的準確率: {acc:.2%}")

預測結果如圖4-4所示:

圖4-4預測結果

預測結果顯示“1”和“4”預測結果錯誤,其他均正確。

五、實驗小結(包括問題和解決辦法、心得體會、意見與建議等)

1.問題和解決辦法:

問題1:RuntimeError: Dataset not found. You can use download=True to download it。

解決方法:添加下載訓練集的參數download=True。

問題2:使用SGD+momentum優化器后,準確率反而下降了。

解決方法:因為SGD對學習率比較敏感,學習率沒有適配,使用StepLR梯度衰減,另外也可以增加訓練輪次。

問題3:預測結果全部錯誤。

解決方法:圖片要像素28*28,且黑色背景,白色筆跡,對Windows畫圖的圖片反色處理即可。

2.心得體會:通過本次CNN手寫數字識別實驗的完整實踐,我深刻體會到深度學習模型性能的提升是一個系統工程,需要從數據、模型、訓練策略到結果分析的全流程精細化把控,嘗試使用不同的數據增強方法、優化器、損失函數、學習率、batch size和迭代次數來進行訓練,迭代出最佳模型,再手寫數字進行測試。通過以上的學習和實踐,我對神經網絡的原理和應用有了更深入的理解。神經網絡的發展給人工智能帶來了巨大的影響,它在圖像識別、自然語言處理等領域發揮著重要的作用。我相信,隨著技術的進步,神經網絡將會有更廣泛的應用。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/82311.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/82311.shtml
英文地址,請注明出處:http://en.pswp.cn/web/82311.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SpringBoot高校宿舍信息管理系統小程序

概述 基于SpringBoot的高校宿舍信息管理系統小程序項目&#xff0c;這是一款非常適合高校使用的信息化管理工具。該系統包含了完整的宿舍管理功能模塊&#xff0c;采用主流技術棧開發&#xff0c;代碼結構清晰&#xff0c;非常適合學習和二次開發。 主要內容 這個宿舍管理系…

Redis 難懂命令-- ZINTERSTORE

**背景&#xff1a;**學習的過程中 常用的redis命令都能快速通過官方文檔理解 但是還是有一些比較難懂的命令 **目的&#xff1a;**寫博客記錄一下&#xff08;當然也可以使用AI搜索&#xff09; 在Redis中&#xff0c;ZINTERSTORE 是一個用于計算多個有序集合&#xff08;So…

React 路由管理與動態路由配置實戰

React 路由管理與動態路由配置實戰 前言 在現代單頁應用(SPA)開發中&#xff0c;路由管理已經成為前端架構的核心部分。隨著React應用規模的擴大&#xff0c;靜態路由配置往往難以滿足復雜業務場景的需求&#xff0c;尤其是當應用需要處理權限控制、動態菜單和按需加載等高級…

【學習筆記】深度學習-梯度概念

一、定義 梯度向量不僅表示函數變化的速度&#xff0c;還表示函數增長最快的方向 二、【問】為什么說它表示方向&#xff1f; 三、【問】那在深度學習梯度下降的時候&#xff0c;還要判斷梯度是正是負來更新參數嗎&#xff1f; 假設某個參數是 w&#xff0c;損失函數對它的…

題海拾貝:P8598 [藍橋杯 2013 省 AB] 錯誤票據

Hello大家好&#xff01;很高興我們又見面啦&#xff01;給生活添點passion&#xff0c;開始今天的編程之路&#xff01; 我的博客&#xff1a;<但凡. 我的專欄&#xff1a;《編程之路》、《數據結構與算法之美》、《題海拾貝》 歡迎點贊&#xff0c;關注&#xff01; 1、題…

webpack的安裝及其后序部分

npm install原理 這個其實就是npm從registry下載項目到本地&#xff0c;沒有什么好說的 值得一提的是npm的緩存機制&#xff0c;如果多個項目都需要同一個版本的axios&#xff0c;每一次重新從registry中拉取的成本過大&#xff0c;所以會有緩存&#xff0c;如果緩存里有這個…

百度golang研發一面面經

輸入一個網址&#xff0c;到顯示界面&#xff0c;中間的過程是怎樣的 IP 報文段的結構是什么 Innodb 的底層結構 知道幾種設計模式 工廠模式 簡單工廠模式&#xff1a;根據傳入類型參數判斷創建哪種類型對象工廠方法模式&#xff1a;由子類決定實例化哪個類抽象工廠模式&#…

使用 HTML + JavaScript 實現圖片裁剪上傳功能

本文將詳細介紹一個基于 HTML 和 JavaScript 實現的圖片裁剪上傳功能。該功能支持文件選擇、拖放上傳、圖片預覽、區域選擇、裁剪操作以及圖片下載等功能&#xff0c;適用于需要進行圖片處理的 Web 應用場景。 效果演示 項目概述 本項目主要包含以下核心功能&#xff1a; 文…

GO+RabbitMQ+Gin+Gorm+docker 部署 demo

更多個人筆記見&#xff1a; &#xff08;注意點擊“繼續”&#xff0c;而不是“發現新項目”&#xff09; github個人筆記倉庫 https://github.com/ZHLOVEYY/IT_note gitee 個人筆記倉庫 https://gitee.com/harryhack/it_note 個人學習&#xff0c;學習過程中還會不斷補充&…

【安全】VulnHub靶場 - W1R3S

【安全】VulnHub靶場 - W1R3S 備注一、故事背景二、Web滲透1.主機發現端口掃描2.ftp服務3.web服務 三、權限提升 備注 2025/05/22 星期四 簡單的打靶記錄 一、故事背景 您受雇對 W1R3S.inc 個人服務器進行滲透測試并報告所有發現。 他們要求您獲得 root 訪問權限并找到flag&…

WEB安全--SQL注入--MSSQL注入

一、SQLsever知識點了解 1.1、系統變量 版本號&#xff1a;version 用戶名&#xff1a;USER、SYSTEM_USER 庫名&#xff1a;DB_NAME() SELECT name FROM master..sysdatabases 表名&#xff1a;SELECT name FROM sysobjects WHERE xtypeU 字段名&#xff1a;SELECT name …

工作流引擎-18-開源審批流項目之 plumdo-work 工作流,表單,報表結合的多模塊系統

工作流引擎系列 工作流引擎-00-流程引擎概覽 工作流引擎-01-Activiti 是領先的輕量級、以 Java 為中心的開源 BPMN 引擎&#xff0c;支持現實世界的流程自動化需求 工作流引擎-02-BPM OA ERP 區別和聯系 工作流引擎-03-聊一聊流程引擎 工作流引擎-04-流程引擎 activiti 優…

Docker 筆記 -- 借助AI工具強勢輔助

常用命令 鏡像管理命令&#xff1a; docker images&#xff08;列出鏡像&#xff09; docker pull&#xff08;拉取鏡像&#xff09; docker build&#xff08;構建鏡像&#xff09; docker save/load&#xff08;保存/加載鏡像&#xff09; 容器操作命令 docker run&#…

5G-A時代與p2p

5G-A時代正在走來&#xff0c;那么對P2P的影響有多大。 5G-A作為5G向6G過渡的關鍵技術&#xff0c;將數據下載速率從千兆提升至萬兆&#xff0c;上行速率從百兆提升至千兆&#xff0c;時延降至毫秒級。這種網絡性能的跨越式提升&#xff0c;為P2P提供了更強大的底層支撐&#x…

Redis-6.2.9 主從復制配置和詳解

1 主從架構圖 192.168.254.120 u24-redis-120 #主庫 192.168.254.121 u24-redis-121 #從庫 2 redis軟件版本 rootu24-redis-121:~# redis-server --version Redis server v6.2.9 sha00000000:0 malloclibc bits64 build56edd385f7ce4c9b 3 主庫redis配置文件(192.168.254.1…

004 flutter基礎 初始文件講解(3)

之前&#xff0c;我們正向的學習了一些flutter的基礎&#xff0c;如MaterialApp&#xff0c;Scaffold之類的東西&#xff0c;那么接下來&#xff0c;我們將正式接觸原代碼&#xff1a; import package:flutter/material.dart;void main() {runApp(const MyApp()); }class MyAp…

Linux 系統 Docker Compose 安裝

個人博客地址&#xff1a;Linux 系統 Docker Compose 安裝 | 一張假鈔的真實世界 本文方法是直接下載 GitHub 項目的 release 版本。項目地址&#xff1a;GitHub - docker/compose: Define and run multi-container applications with Docker。 執行以下命令將發布程序加載至…

Tree 樹形組件封裝

整體思路 數據結構設計 使用遞歸的數據結構&#xff08;TreeNode&#xff09;表示樹形數據每個節點包含id、name、可選的children數組和selected狀態 狀態管理 使用useState在組件內部維護樹狀態的副本通過deepCopyTreeData函數進行深拷貝&#xff0c;避免直接修改原始數據 核…

tortoisegit 使用rebase修改歷史提交

在 TortoiseGit 中使用 rebase 修改歷史提交&#xff08;如修改提交信息、合并提交或刪除提交&#xff09;的步驟如下&#xff1a; --- ### **一、修改最近一次提交** 1. **操作**&#xff1a; - 右鍵項目 → **TortoiseGit** → **提交(C)** - 勾選 **"Amend…

中科院報道鐵電液晶:從實驗室突破到多場景應用展望

2020年的時候&#xff0c;相信很多關注科技前沿的朋友都注意到&#xff0c;中國科學院一篇報道聚焦一項有望改寫顯示產業格局的新技術 —— 鐵電液晶&#xff08;FeLC&#xff09;。這項被業內稱為 "下一代顯示核心材料" 的研究&#xff0c;究竟取得了哪些實質性進展…