卷積神經網絡(二):手寫數字識別項目(一)

文章目錄

  • 手寫數字識別項目
    • 一、準備數據集
    • 二、定義模型
    • 三、模型訓練
      • 3.1 導入依賴庫
      • 3.2 設備設置(CPU/GPU 自動選擇)
      • 3.3 超參數定義
      • 3.4數據集準備
        • 1.獲取數據集
        • 2.劃分訓練集與驗證集
        • 3.創建 DataLoader(按批次加載數據)
      • 3.5模型初始化與斷點續訓
      • 3.6損失函數與優化器定義
      • 3.7訓練函數(train ())
      • 3.8驗證函數(valid ())
      • 3.9主訓練循環(多輪訓練與驗證)
    • 四、模型訓練完整代碼
    • 五、總結流程

手寫數字識別項目

一、準備數據集

首先我們創建一個卷積模型,訓練的時候就需要一個原始的數據集,那么數據集哪里來?Pytorch官網其實有一些數據集,數據集地址
在這里插入圖片描述

我們使用到的數據集是MNIST

導入包

import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

使用數據集,所有的官方數據集都繼承 torch.utils.data.Dataset,如果你沒有數據集,那download = True,它會聯網下載到你本地。

# label: 數據集傳入的標簽值
def target_transform(label):return torch.tensor(label)ds = MNIST(root='./data',  # 保存或讀取數據的目錄train=True,  # 是否加載訓練數據集download=False,  # 是否下載數據集transform=ToTensor(),  # 用于轉換圖片的函數# target_transform=target_transform  # 用于轉換標簽的函數target_transform=lambda label: torch.tensor(label)  # 直接匿名函數轉換成張量
)

測試打印數據

print(len(ds))
print(ds[0])
print(ds[0][0].shape)

二、定義模型

簡單的圖像識別模型的套路:卷積 -> 激活 -> 池化 -> … -> 卷積 -> 激活 -> 池化 ->展平 -> 全連接層 -> 激活-> … -> 全連接層輸出,會將圖片縮小的同時增加通道數,當特征圖縮小到 10 以內,就結束卷積過程。之后我們會講到LeNet5模型,這兒我們簡單的定義一個模型進行訓練。

from torch import nn# 卷積激活池化 模塊
class ConvActivatePool(nn.Module):def __init__(self, in_channels, out_channels, kernel_size):super().__init__()# 一般卷積后會選擇讓圖片大小保持不變 進行填充self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')self.relu = nn.ReLU()# 池化在此處提取了特征的同時,讓圖片下采樣了self.pool = nn.MaxPool2d(2)def forward(self, x):x = self.conv(x)x = self.relu(x)y = self.pool(x)return yclass NumberRecognition(nn.Module):def __init__(self):super().__init__()self.cap1 = ConvActivatePool(1, 64, 11)self.cap2 = ConvActivatePool(64, 128, 5)# 分類層self.classifier = nn.Sequential(# 展平nn.Flatten(start_dim=1),# 全連接層nn.Linear(128 * 7 * 7, 2048),nn.ReLU(),nn.Dropout(p=0.3),nn.Linear(2048, 1024),nn.ReLU(),# 輸出結果為 10 分類,所以輸出層全連接輸出 10nn.Linear(1024, 10))# x 形狀 (N, C=1, H=28, W=28)def forward(self, x):x = self.cap1(x)# N x 64 x 14 x 14x = self.cap2(x)# N x 128 x 7 x 7# 圖片縮小到 10 以內,則停止卷積# 調用分類器,對圖片進行分類y = self.classifier(x)return yif __name__ == '__main__':import torchmodel = NumberRecognition()x = torch.rand(16, 1, 28, 28)y = model(x)print(y.shape)

三、模型訓練

3.1 導入依賴庫

import math
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from model import NumberRecognition

3.2 設備設置(CPU/GPU 自動選擇)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

3.3 超參數定義

EPOCH = 10          # 訓練輪次:整個訓練集遍歷10次
LR = 1e-2           # 學習率:控制參數更新的步長(1e-2 = 0.01)
BATCH_SIZE = 10     # 批次大小:每次訓練用10個樣本更新一次參數
val_rate = 0.2      # 驗證集比例:從訓練集中劃分20%作為驗證集

3.4數據集準備

1.獲取數據集
ds = MNIST(root='./data',        # 數據集保存路徑(若不存在會自動創建)train=True,           # 加載訓練集(False則加載測試集)download=False,       # 是否自動下載數據集(首次運行需設為True)transform=ToTensor(), # 對圖像的變換:PIL→Tensor(0-1歸一化+維度調整)target_transform=lambda label: torch.tensor(label) # 對標簽的變換:int→Tensor
)
2.劃分訓練集與驗證集
ds_total_len = len(ds)          # 總樣本數:MNIST訓練集共60000個樣本
train_len = int(ds_total_len * (1 - val_rate)) # 訓練集樣本數:60000×0.8=48000
val_len = ds_total_len - train_len             # 驗證集樣本數:60000×0.2=12000
train_ds, val_ds = random_split(ds, [train_len, val_len]) # 隨機劃分
3.創建 DataLoader(按批次加載數據)
# 計算總批次數(向上取整,避免最后一批樣本被丟棄)
train_total_batch = math.ceil(train_len / BATCH_SIZE) # 48000/10=4800批
val_total_batch = math.ceil(val_len / BATCH_SIZE)     # 12000/10=1200批# 訓練集DataLoader
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True  # 訓練集每次epoch前打亂樣本順序(避免模型記憶樣本順序,提升泛化)
)# 驗證集DataLoader
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True  # 驗證集打亂無意義(僅計算損失),建議設為False以提高效率
)

3.5模型初始化與斷點續訓

# 初始化自定義模型(NumberRecognition在model.py中定義,需確保輸入輸出維度匹配)
model = NumberRecognition()# 嘗試加載歷史模型參數(支持斷點續訓)
try:# 加載參數文件(weights_only=True)state_dict = torch.load('./weights/model.pth', weights_only=True)model.load_state_dict(state_dict) # 將參數加載到模型中print('加載模型參數成功')
except:# 若文件不存在(首次訓練),打印提示print('未找到模型參數')# 將模型遷移到指定設備(CPU/GPU)
model.to(device)

3.6損失函數與優化器定義

# 損失函數:交叉熵損失(適合多分類任務,如MNIST的10類數字)
loss_fn = nn.CrossEntropyLoss()# 優化器:Adam優化器(常用優化器,結合SGD的動量和RMSprop的自適應學習率)
optimizer = torch.optim.Adam(model.parameters(),  # 需優化的參數(模型的所有權重和偏置)lr=LR,               # 學習率(與超參數一致)weight_decay=1e-4    # L2正則化(權重衰減,防止模型參數過大導致過擬合)
)

3.7訓練函數(train ())

# 全局變量:累計訓練損失和批次數量(用于計算平均損失)
train_total_loss = 0.
train_count = 0def train():global train_total_loss, train_count # 聲明使用全局變量print('開始訓練')model.train() # 將模型設為“訓練模式”(關鍵!啟用Dropout/BatchNorm更新)# 遍歷訓練集DataLoader,每次取一個批次for i, (images, labels) in enumerate(train_dl):# 1. 將數據遷移到指定設備(與模型設備一致)images, labels = images.to(device), labels.to(device)# 2. 清空上一輪的梯度(PyTorch梯度會累加,不清空會導致梯度錯誤)optimizer.zero_grad()# 3. 前向傳播:模型預測輸出y_pred = model(images) # 輸出形狀:(BATCH_SIZE, 10),每一行是10個類的得分# 4. 計算損失(預測值與真實標簽的差距)loss = loss_fn(y_pred, labels)# 5. 累計損失和批次數量(用于后續計算平均損失)train_total_loss += loss.item() # loss是Tensor,用.item()轉為Python數值train_count += 1# 6. 反向傳播:計算參數梯度(自動微分核心)loss.backward()# 7. 優化器更新參數(根據梯度調整權重和偏置)optimizer.step()# 每100個批次打印一次訓練進度(避免打印過于頻繁)if (i + 1) % 100 == 0:avg_loss = train_total_loss / train_countprint(f'BATCH: [{i + 1}/{train_total_batch}]; loss: {avg_loss:.4f}')# 返回本輪訓練的平均損失(用于epoch結束時打印)return train_total_loss / train_count

3.8驗證函數(valid ())

def valid():# 局部變量:累計驗證損失和批次數量(每輪驗證重新初始化,避免與訓練混淆)val_total_loss = 0.val_count = 0print('開始驗證')model.eval() # 將模型設為“評估模式”(關鍵!禁用Dropout/BatchNorm更新)# 禁用梯度計算(驗證階段無需反向傳播,節省內存和時間)with torch.no_grad():# 遍歷驗證集DataLoaderfor i, (images, labels) in enumerate(val_dl):# 1. 數據遷移到指定設備images, labels = images.to(device), labels.to(device)# 2. 前向傳播(無梯度計算)y_pred = model(images)# 3. 計算驗證損失loss = loss_fn(y_pred, labels)val_total_loss += loss.item()val_count += 1# 每100個批次打印驗證進度if (i + 1) % 100 == 0:avg_loss = val_total_loss / val_countprint(f'BATCH: [{i + 1}/{val_total_batch}]; loss: {avg_loss:.4f}')# 返回本輪驗證的平均損失return val_total_loss / val_count

3.9主訓練循環(多輪訓練與驗證)

# 遍歷所有訓練輪次
for epoch in range(EPOCH):print(f'\nEPOCH: [{epoch + 1}/{EPOCH}]') # 打印當前輪次(從1開始更直觀)# 1. 訓練本輪并獲取訓練平均損失train_loss = train()# 2. 驗證本輪并獲取驗證平均損失val_loss = valid()# 3. 打印本輪訓練結果print(f'EPOCH END; train loss: {train_loss:.4f}; val loss: {val_loss:.4f}')# 訓練結束后,保存最終模型參數(覆蓋原有文件)
torch.save(model.state_dict(), './weights/model.pth')
print('\n模型參數已保存至 ./weights/model.pth')

四、模型訓練完整代碼

import math
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from model import NumberRecognitiondevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')EPOCH = 10
LR = 1e-2
BATCH_SIZE = 10
val_rate = 0.2ds = MNIST(root='./data', train=True, download=False, transform=ToTensor(),target_transform=lambda label: torch.tensor(label))ds_total_len = len(ds)
train_len = int(ds_total_len * (1 - val_rate))
val_len = len(ds) - train_len
train_ds, val_ds = random_split(ds, [train_len, val_len])train_total_batch = math.ceil(train_len / BATCH_SIZE)
val_total_batch = math.ceil(val_len / BATCH_SIZE)train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)model = NumberRecognition()
try:state_dict = torch.load('./weights/model.pth', weights_only=True)model.load_state_dict(state_dict)print('加載模型參數成功')
except:print('未找到模型參數')model.to(device)loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)train_total_loss = 0.
train_count = 0def train():global train_total_loss, train_countprint('開始訓練')model.train()for i, (images, labels) in enumerate(train_dl):# 3. 將數據放到設備上images, labels = images.to(device), labels.to(device)optimizer.zero_grad()y = model(images)loss = loss_fn(y, labels)train_total_loss += loss.item()train_count += 1loss.backward()optimizer.step()if (i + 1) % 100 == 0:print(f'BATCH: [{i + 1}/{train_total_batch}]; loss: {train_total_loss / train_count}')return train_total_loss / train_countdef valid():val_total_loss = 0.val_count = 0print('開始驗證')model.eval()with torch.no_grad():for i, (images, labels) in enumerate(val_dl):images, labels = images.to(device), labels.to(device)y = model(images)loss = loss_fn(y, labels)val_total_loss += loss.item()val_count += 1if (i + 1) % 100 == 0:print(f'BATCH: [{i + 1}/{val_total_batch}]; loss: {val_total_loss / val_count}')return val_total_loss / val_countfor epoch in range(EPOCH):print(f'EPOCH: [{epoch + 1}/{EPOCH}]')train_loss = train()val_loss = valid()print(f'EPOCH END; train loss: {train_loss}; val loss: {val_loss}')torch.save(model.state_dict(), './weights/model.pth')

五、總結流程

  1. 加載 MNIST 公開手寫數字數據集(訓練集)
  2. 劃分訓練集與驗證集(用于監控過擬合)
  3. 加載自定義的數字識別模型(NumberRecognition),支持斷點續訓(加載歷史參數)
  4. 定義訓練 / 驗證流程,使用交叉熵損失和 Adam 優化器訓練模型
  5. 訓練完成后保存模型參數,便于后續推理或繼續訓練。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/95418.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/95418.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/95418.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

批量給文件夾添加文件v2【件批量復制工具】

代碼功能介紹 這個代碼的功能就是一個,給某個文件夾里面添加某個文件(含父級文件夾下的每一個子文件夾) 舉個例子,父級文件夾是:“D:\Desktop\1,要添加的文件路徑是:D:\1.txt” 則最后會把文件…

Qt實現2048小游戲:看看AI如何評估棋盤策略實現“人機合一

2048 是一款經典的數字益智游戲,其簡單的規則背后蘊含著豐富的策略性。該項目不僅完整實現了 2048 的核心玩法,還包含了一個基于啟發式評估和蒙特卡洛方法的智能 AI 玩家。 我們將從項目整體架構入手,逐一解析游戲核心邏輯、UI 渲染、事件處理、AI 策略等關鍵模塊,并通過展…

封裝紅黑樹實現mysetmymap

1. 源碼分析 set實例化rb_tree時第二個模板參數給的是key&#xff0c;map實例化rb_tree時第?個模板參數給的是 pair<const key,T>&#xff0c;這樣一顆紅黑樹既可以實現key搜索場景的set&#xff0c;也可以實現key/value搜索場 景的map源碼里面模板參數是用T代表value&…

以OWTB為核心以客戶為基礎的三方倉運配一體化平臺分析V0.2

一、系統概述以OWTB&#xff08;Order-Warehouse-Transportation-Billing&#xff0c;訂單-倉儲-運輸-結算&#xff09;為核心的三方倉運配一體化平臺&#xff0c;是專為第三方物流企業打造的深度定制化解決方案。該平臺以第三方倉運配為主線&#xff0c;以多客戶/多SKU/個性化…

技術框架之腳手架實現

一、 序言在日常的企業級Java開發中&#xff0c;我們經常會發現自己在重復地做著一些項目初始化工作&#xff1a;創建相似的項目結構、引入一堆固定的依賴包、編寫通用的配置文件、拷貝那些幾乎每個項目都有的基礎工具類和日志配置。這些工作不僅枯燥乏味&#xff0c;而且容易出…

小迪安全v2023學習筆記(七十七講)—— 業務設計篇隱私合規檢測重定向漏洞資源拒絕服務

文章目錄前記WEB攻防——第七十七天業務設計篇&隱私合規檢測&URL重定向&資源拒絕服務&配合項目隱私合規 - 判斷規則&檢測項目介紹案例演示URL重定向 - 檢測判斷&釣魚配合介紹黑盒測試看業務功能看參數名goole語法搜索白盒測試跳轉URL繞過思路釣魚配合資…

用AI做旅游攻略,真能比人肉整理靠譜?

大家好&#xff0c;我是極客團長&#xff01; 作為一個沉迷研究 “AI 工具怎么滲透日常生活” 的科技博主&#xff0c;我開了個 AI 解決生活小事系列。 前兩期聊了用 AI 寫新聞博客、扒商業報告&#xff0c;后臺一堆人催更&#xff1a;能不能搞點接地氣的&#xff1f;比如&am…

Axure RP 9 Mac 交互原型設計

原文地址&#xff1a;Axure RP 9 Mac 交互原型設計 安裝教程 Axure RP 9是一款功能強大的原型設計和協作工具。 它不僅能夠幫助用戶快速創建出高質量的原型設計&#xff0c;還能促進團隊成員之間的有效協作&#xff0c;從而極大地提高數字產品開發的效率和質量。 擁有直觀易…

多線程——線程狀態

目錄 1.線程的狀態 1.1 NEW 1.2 RUNNABLE 1.3 BLOCKED 1.4 WAITING 1.5 TIMED_WAITING 1.6 TERMINATED 2.線程狀態的相互轉換 在上期的學習中&#xff0c;已經理解線程的啟動&#xff08;start()&#xff09;、休眠&#xff08;sleep()&#xff09;、中斷&#xff08;i…

IMX6ULL的設備樹文件簡析

先分析一個完整的設備樹&#xff0c;是怎么表達各種外設信息的。以imux6ull開發板為例進行說明。這個文件里就一個設備信息才這么點內容&#xff0c;是不是出問題了&#xff1f;當然不是&#xff0c;我們知道dts文件是可包含的&#xff0c;所以&#xff0c;最終形成的一個完整文…

【ARM】PACK包管理

1、 文檔目標對 pack 包的管理有更多的了解。2、 問題場景客戶在安裝了過多的 pack 包導致軟件打開比較慢&#xff0c;各種 pack 包顏色的區別&#xff0c;及圖標不同。3、軟硬件環境1&#xff09;、軟件版本&#xff1a;Keil MDK 5.392&#xff09;、電腦環境&#xff1a;Wind…

【Kubernetes】知識點4

36. 說明K8s中Pod級別的Graceful Shutdown。答&#xff1a;Graceful Shutdown&#xff08;優雅關閉&#xff09;是指當 Pod 需要終止時&#xff0c;系統給予運行中的容器一定的時間來等待業務的應用的正常關閉&#xff08;如保存數據、關閉連接、釋放資源等&#xff09;&#x…

Paraverse平行云實時云渲染助力第82屆威尼斯電影節XR沉浸式體驗

今年&#xff0c;Paraverse平行云實時云渲染平臺LarkXR&#xff0c;為享有盛譽的第82屆威尼斯國際電影節&#xff08;8月27日至9月6日&#xff09;帶來沉浸式體驗。 LarkXR助力我們的生態伙伴FRENCH TOUCH FACTORY&#xff0c;實現ITHACA容積視頻的XR交互演示&#xff0c;從意大…

大數據開發計劃表(實際版)

太好了&#xff01;我將為你生成一份可打印的PDF版學習計劃表&#xff0c;并附上項目模板與架構圖示例&#xff0c;幫助你更直觀地執行計劃。 由于當前環境無法直接生成和發送文件&#xff0c;我將以文本格式為你完整呈現&#xff0c;你可以輕松復制到Word或Markdown中&#xf…

GitLab 18.3 正式發布,更新多項 DevOps、CI/CD 功能【二】

沿襲我們的月度發布傳統&#xff0c;極狐GitLab 發布了 18.3 版本&#xff0c;該版本帶來了通過直接轉移進行遷移、CI/CD 作業令牌的細粒度權限控制、自定義管理員角色、Kubernetes 1.33 支持、通過 API 讓流水線執行策略訪問 CI/CD 配置等幾十個重點功能的改進。下面是對部分重…

Docker學習筆記(二):鏡像與容器管理

Docker 鏡像 最小的鏡像 hello-world 是 Docker 官方提供的一個鏡像&#xff0c;通常用來驗證 Docker 是否安裝成功。 先通過 docker pull 從 Docker Hub 下載它。 [rootdocker ~]# docker pull hello-world Using default tag: latest latest: Pulling from library/hello-wor…

STM32F103C8T6開發板入門學習——寄存器和庫函數介紹

學習目標&#xff1a;STM32F103C8T6開發板入門學習——寄存器和庫函數介紹學習內容&#xff1a; 1. 寄存器介紹 1.1 存儲器映射 存儲器本身無固有地址&#xff0c;是具有特定功能的內存單元。它的地址是由芯片廠商或用戶分配&#xff0c;給存儲器分配地址的過程就叫做存儲區映射…

【CouponHub項目開發】使用RocketMQ5.x實現延時修改優惠券狀態,并通過使用模板方法模式重構消息隊列發送功能

在上個章節中我實現了創建優惠券模板的功能&#xff0c;但是&#xff0c;優惠券總會有過期時間&#xff0c;我們怎么去解決到期自動修改優惠券狀態這樣一個功能呢&#xff1f;我們可以使用RocketMQ5.x新出的任意定時發送消息功能來解決。 初始方案&#xff1a;首先在創建優惠券…

Claude Code SDK 配置Gitlab MCP服務

一、MCP配置前期準備 &#xff08;一&#xff09;創建個人令牌/群組令牌 我這里是創建個人令牌&#xff0c;去到首頁左上角&#xff0c;點擊頭像——>偏好設置——>訪問令牌——>添加新令牌 &#xff08;二&#xff09;配置mcp信息 去到魔塔社區&#xff0c;點擊mc…

Eclipse 常用搜索功能匯總

Eclipse 常用搜索功能匯總 Eclipse 提供了多種搜索功能&#xff0c;幫助開發者快速定位代碼、文件、類、方法、API 等資源。以下是詳細的使用方法和技巧。 一、常用搜索快捷鍵快捷鍵功能描述Ctrl H打開全局搜索對話框&#xff0c;支持文件、Java 代碼、任務等多種搜索。Ctrl …