深度學習----卷積神經網絡實現數字識別

一、準備工作

導入庫,導入數據集,劃分訓練批次數量,規定訓練硬件(這部分

import torch
from torch import nn  # 導入神經網絡模塊
from torch.utils.data import DataLoader  # 數據包管理工具,打包數據
from torchvision import datasets  # 封裝了很多與圖像相關的模型,和數據集
from torchvision.transforms import ToTensor  # 將其他數據類型轉化為張量train_data = datasets.MNIST(root='data',train=True,  # 是否讀取下載后數據中的訓練集download=True,  # 如果之前下載過則不用下載transform=ToTensor()
)
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor()
)train_dataloader = DataLoader(train_data,batch_size=256)#是一個類,現在初始化了,但沒開始打包,訓練開始才打包
test_dataloader = DataLoader(test_data,batch_size=256)device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')

二、定義神經網絡(重點)

這部分比較重要,我分開講

1、類定義與繼承

class CNN(nn.Module):

這里定義了一個名為CNN的類,它繼承自 PyTorch 的nn.Module類。nn.Module是 PyTorch 中所有神經網絡模塊的基類,通過繼承它,我們可以利用 PyTorch 提供的各種功能,如參數管理、設備遷移等。

2、初始化方法

def __init__(self):super().__init__()

這是類的構造函數,super().__init__()調用了父類nn.Module的構造函數,確保父類得到正確初始化。

3、網絡層定義

  • 第一個卷積塊(conv1)
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,      # 輸入通道數,1表示灰度圖像out_channels=16,    # 輸出通道數/卷積核數量kernel_size=5,      # 卷積核大小5×5stride=1,           # 步長為1padding=2,          # 填充為2,保持特征圖大小不變),nn.ReLU(),               # ReLU激活函數nn.MaxPool2d(kernel_size=2),  # 2×2最大池化
)

這個卷積塊接收 1 通道的輸入,通過 16 個 5×5 的卷積核進行卷積操作,然后經過 ReLU 激活和 2×2 的最大池化。

  • 第二個卷積塊(conv2)
self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),  # 16→32通道,5×5卷積核nn.ReLU(),nn.Conv2d(32,32,5,1,2),  # 32→32通道,5×5卷積核nn.ReLU(),nn.Conv2d(32,32,5,1,2),  # 32→32通道,5×5卷積核nn.ReLU(),nn.Conv2d(32,64,5,1,2),  # 32→64通道,5×5卷積核nn.ReLU(),nn.Conv2d(64,64,5,1,2),  # 64→64通道,5×5卷積核nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 2×2最大池化
)

這個卷積塊包含多個卷積層,逐步增加通道數,并在最后進行一次最大池化。

  • 第三個卷積塊(conv3)
self.conv3 = nn.Sequential(nn.Conv2d(64,64,5,1,2),  # 64→64通道,5×5卷積核nn.ReLU(),
)

這是一個簡單的卷積塊,保持通道數不變。

  • 全連接層(out)
self.out = nn.Linear(64*7*7,10)

這是網絡的輸出層,將卷積得到的特征圖展平后映射到 10 個輸出(可能對應 10 類分類問題)。

4、前向傳播方法

def forward(self, x):x = self.conv1(x)    # 通過第一個卷積塊x = self.conv2(x)    # 通過第二個卷積塊x = self.conv3(x)    # 通過第三個卷積塊x = x.view(x.size(0),-1)  # 展平特征圖,保留批次維度output = self.out(x)  # 通過全連接層得到輸出return output

forward方法定義了數據在網絡中的流動路徑,即前向傳播過程。x.view(x.size(0),-1)將卷積操作得到的多維特征圖展平成一維向量,以便輸入到全連接層。

5、模型實例化

model = CNN().to(device)

創建 CNN 類的實例,并將模型遷移到指定的設備(CPU 或 GPU)上。

完整代碼:

 定義神經網絡,通過類的繼承
class CNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Sequential(#容器,添加網絡層nn.Conv2d(in_channels=1,out_channels = 16,kernel_size = 5,stride = 1,padding = 2,),nn.ReLU(),nn.MaxPool2d(kernel_size = 2),)self.conv2 = nn.Sequential(  # 容器,添加網絡層nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32,64,5,1,2),nn.ReLU(),nn.Conv2d(64, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv3 = nn.Sequential(nn.Conv2d(64,64,5,1,2),nn.ReLU(),)self.out = nn.Linear(64*7*7,10)def forward(self, x):  # 前向傳播x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1)output = self.out(x)return output
model = CNN().to(device)

三、模型的訓練

這一段和上一個博客的步驟一樣這里就不多做講解了

不了解的直接看

深度學習----由手寫數字識別案例來認識PyTorch框架-CSDN博客


def train(dataloader, model, loss_fn, optimizer):model.train()  # 開啟模型訓練模式,像 Dropout、BatchNorm 層會在訓練/測試時表現不同,需此設置batch_size_num = 1  # 用于計數當前處理到第幾個 batchfor X, y in dataloader:  # 從數據加載器中逐個取出 batch 的數據(特征 X、標簽 y )X, y = X.to(device), y.to(device)  # 把數據和標簽放到指定計算設備(CPU/GPU)pred = model(X)  # 將數據輸入模型,得到預測結果(模型自動做前向傳播計算 )loss = loss_fn(pred, y)  # 用損失函數計算預測結果和真實標簽的損失# 以下是反向傳播更新參數的標準流程optimizer.zero_grad()  # 清空優化器里參數的梯度,避免梯度累加影響計算loss.backward()  # 反向傳播,計算參數的梯度optimizer.step()  # 根據梯度,更新模型參數loss = loss.item()  # 取出損失張量的數值(脫離計算圖 )# 打印當前 batch 的損失和 batch 編號if batch_size_num % 100 == 0:print(f"loss:{loss:>7f}  [number:{batch_size_num}")batch_size_num += 1  # batch 計數加一def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for X , y in dataloader:X ,y = X.to(device),y.to(device)pred = model.forward(X)#.forward可以被省略test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizeprint(f"test result: \n Accuracy: {(100*correct)}%,Avg loss:{test_loss}")# print(list(model.parameters()))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)#換優化器可以提高準確率,Adam,SGD等# train(train_dataloader,model,loss_fn,optimizer)
# test(test_dataloader,model,loss_fn)
#epochs = 10
for t in range(epochs):print(f"輪次:{t+1}\n----------------------------")train(train_dataloader,model,loss_fn,optimizer)
print("Done")
test(test_dataloader,model,loss_fn)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/94702.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/94702.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/94702.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

鴻蒙Harmony-從零開始構建類似于安卓GreenDao的ORM數據庫(四)

目錄 一,查詢表的所有數據 二,根據條件查詢數據 三,數據庫升級 前面章節已經講解了數據庫的創建,表的創建,已經增刪改等操作。下面我們來講解一下數據庫的查詢以及升級操作。 一,查詢表的所有數據 先來看看官方文檔: query(predicates: RdbPredicates, callback: Asy…

20250829_編寫10.1.11.213MySQL8.0異地備份傳輸腳本+在服務器上創建cron任務+測試成功

0.已知前提條件: 10.1.11.213 堡壘機訪問 mysql 8.0 版本 密碼在/root/.my.cnf 備份腳本:/data/backup_mysql/mysql_backup.sh alarm_system:動環數據庫 exit_and_entry:出入境數據庫 logs:備份日志 project_cg_view_prod:采購跟蹤系統 all :數據庫整體備份 imip_ecb…

PostgreSQL 流復制與邏輯復制性能優化與故障切換實戰經驗分享

PostgreSQL 流復制與邏輯復制性能優化與故障切換實戰經驗分享 在高可用和數據安全愈發受到重視的生產環境中,PostgreSQL 復制技術是保障業務連續性的重要手段。本文結合真實生產場景,分享流復制(Physical Replication)與邏輯復制&…

Django開發規范:構建可維護的AWS資源管理應用

引言 在現代Web開發中,遵循一致的開發規范對于項目的可維護性和團隊協作至關重要。本文基于實際的AWS資源管理項目,分享一套經過實踐檢驗的Django開發規范,涵蓋模型設計、Admin配置、管理命令和工具類開發等方面。 模型開發規范 數據模型設計原則 良好的數據模型設計是應…

機器學習可解釋庫Shapash的快速使用教程(五)

文章目錄1 快速使用1.1 安裝1.2 三個簡單步驟快速入門1.2.1 步驟 1:準備模型和數據1.2.2 步驟 2:聲明并編譯 SmartExplainer1.2.3 步驟 3:可視化和探索1.2.4 啟動 Web 應用1.2.5 將解釋結果導出為數據2 Shapash的后端集成2.1 方法一&#xff…

如何在emacs中添加imenu插件

在配置文件中添加: ;; 刪除現有的包管理器配置(如果有),然后添加以下:;; 初始化包管理器 (require package);; 清除現有的倉庫列表 (setq package-archives nil);; 添加正確的倉庫(注意:使用 H…

Linux下的網絡編程SQLITE3詳解

常用數據庫關系型數據庫將復雜的數據結構簡化為二維表格形式大型:Oracle、DB2中型:MySql、SQLServer小型:Sqlite非關系型數據庫以鍵值對存儲,且結構不固定JSONRedisMongoDBsqlite數據庫特點開源免費,C語言開發代碼量少…

適配openai

openai 腳本 stream腳本import os from openai import OpenAIclient OpenAI(base_url"http://127.0.0.1:9117/api/v1",api_keyos.environ["ACCESS_TOKEN"], )stream client.chat.completions.create(model "Qwen/Qwen2-7B-Instruct",messages…

一天認識一個神經網絡之--CNN卷積神經網絡

CNN 是一種非常強大的深度學習模型,尤其擅長處理像圖片這樣的網格結構數據。你可以把它想象成一個系統,它能像我們的大腦一樣,自動從圖片中學習并識別出各種特征,比如邊緣、角落、紋理,甚至是更復雜的物體部分&#xf…

13 SQL進階-InnoDB引擎(8.23)

一、邏輯存儲結構(1)表空間(ibd文件):一個mysql實例可以對應多個表空間,用于存儲記錄、索引等數據。cd /var/lib/mysql(2)段,分為數據段(leaf node segment&a…

MTK Linux DRM分析(二十四)- MTK mtk_drm_plane.c

一、代碼分析 mtk_drm_plane.h 和 mtk_drm_plane.c 兩個文件,并生成基于文本的函數調用圖,我將首先解析文件中的主要函數及其功能,然后根據代碼中的調用關系整理出調用圖。由于文件內容較長,我會專注于關鍵函數及其相互調用關系,并以清晰的文本形式呈現。 文件分析 1. …

滾珠導軌如何賦能精密制造?

在智能制造發展的趨勢下,新興行業對高精度、高穩定性的運動控制需求激增。作為直線傳動領域的“精密紐帶”,滾珠導軌憑借低摩擦、長壽命、高剛性優勢,廣泛應用于精密傳動領域,成為產業升級的關鍵。新能源汽車制造領域:…

醫療 AI 的 “破圈” 時刻:輔助診斷、藥物研發、慢病管理,哪些場景已落地見效?

一、引言在科技迅猛發展的當下,醫療領域正經歷著深刻變革,人工智能(AI)技術宛如一顆璀璨新星,強勢 “破圈” 闖入,為醫療行業帶來了前所未有的機遇與活力。從輔助醫生精準診斷病情,到助力藥企高…

【項目思維】編程思維學習路線(推薦)

本篇博客是一份系統性、分階段的 編程思維學習路線圖推薦,從零基礎小白到系統架構級別,幫助你全面建立和提升編程思維能力。 🚦 階段 0:思維準備(理解編程是什么) 🎯 學習目標: 理…

vue3+antd實現華為云OBS文件拖拽上傳詳解

1、文件上傳核心流程 選擇文件??:用戶通過拖拽或點擊選擇文件手動觸發上傳??:點擊"確定"按鈕后開始上傳(阻止自動上傳)??獲取上傳憑證??:從后端獲取華為云OBS的上傳配置構建表單數據??&#xff1…

Mac 開發環境與配置操作速查表

Mac 開發環境與配置操作速查表 安裝和配置 nvm / Node 安裝 Homebrew Homebrew 安裝參考文章 如果沒有VPN,不要使用此命令安裝! /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" brew --v…

【論文簡讀】MuGS

今天讀一篇ICCV 2025的文章,關注的是Generalizable Gaussian Splatting,作者來自華中科技大學。 文章鏈接:arxiv 代碼倉庫:https://github.com/EuclidLou/MuGS(摘要中的鏈接,但暫時404) 文章目…

基于SpringBoot和百度人臉識別API開發的保安門禁系統

角色: 管理員、保安 技術: Spring Boot, MyBatis, MySQL, PageHelper, Bootstrap, jQuery, JavaScript, CSS3, HTML5, JSP, 百度人臉識別API 核心功能: 小區保安門禁系統是一個基于Spring Boot技術棧開發的綜合性平臺,旨在實現小區…

抖音電商首創最嚴珠寶玉石質檢體系,推動行業規范與消費擴容

8月27日,“抖音電商開放日質檢專場”活動在廣州華林國際舉行。活動上,抖音電商首次對外介紹了質檢倉配一體化中心(QIC)的運作流程,并發布了服務升級計劃。這一行業首創的“先鑒定后發貨”模式,被認為推動了…

SpringBoot整合Spring WebFlux棄用自帶的logback,使用log4j2,并啟動異步日志處理

第一步&#xff1a;修改pom文件<!-- Spring Boot Starter WebFlux (排除默認日志) --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-webflux</artifactId><version>${spring-boot.vers…