MNIST 手寫數字分類

轉自我的個人博客: https://shar-pen.github.io/2025/05/04/torch-distributed-series/1.MNIST/

基礎的單卡訓練

本筆記本演示了訓練一個卷積神經網絡(CNN)來對 MNIST 數據集中的手寫數字進行分類的過程。工作流程包括:

  1. 數據準備:加載和預處理 MNIST 數據集。
  2. 模型定義:使用 PyTorch 構建 CNN 模型。
  3. 模型訓練:在 MNIST 訓練數據集上訓練模型。
  4. 模型評估:在 MNIST 測試數據集上測試模型并評估其性能。
  5. 可視化:展示樣本圖像及其對應的標簽。

參考 pytorch 官方示例 https://github.com/pytorch/examples/blob/main/mnist/main.py 。

至于為什么選擇 MNIST 分類任務, 因為它就是深度學習里的 Hello World.

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms
from time import time

深度學習里,真正必要的超參數,大致是下面這些:

  1. 學習率(learning rate)

    • 最最核心的超參數。
    • 決定每次參數更新的步幅大小。
    • 學習率不合適,訓練幾乎一定失敗。
  2. 優化器(optimizer)

    • 比如 SGDAdamAdamW 等。
    • 不同優化器,收斂速度、最終效果差異很大。
    • 有時也需要設置優化器內部超參(比如 Adam 的 β 1 , β 2 \beta_1, \beta_2 β1?,β2?)。
  3. 批大小(batch size)

    • 多少樣本合成一批送進模型訓練。
    • 影響訓練穩定性、收斂速度、硬件占用。
  4. 訓練輪次(epoch)最大步數(max steps)

    • 總共訓練多久。
    • 如果訓練不夠長,模型欠擬合;太久則過擬合或資源浪費。
  5. 損失函數(loss function)

    • 明確訓練目標,比如分類用 CrossEntropyLoss,回歸用 MSELoss
    • 不同任務必須選對損失。

超參設置

我們設置些最基礎的超參: epoch, batch size, device, lr

EPOCHS = 5
BATCH_SIZE = 512
LR = 0.001
LR_DECAY_STEP_NUM = 1
LR_DECAY_FACTOR = 0.5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

數據構建

直接用庫函數生成 dataset 和 dataloader, 前者其實只是拿來生成 dataloader

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_data = datasets.MNIST(root = './mnist',train=True,       # 設置True為訓練數據,False為測試數據transform = transform,# download=True  # 設置True后就自動下載,下載完成后改為False即可
)train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)test_data = datasets.MNIST(root = './mnist',train=False,       # 設置True為訓練數據,False為測試數據transform = transform,
)test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)# plot one exampleprint(f'dataset: input shape: {train_data.data.size()}, label shape: {train_data.targets.size()}')
print(f'dataloader iter: input shape: {next(iter(train_loader))[0].size()}, label shape: {next(iter(train_loader))[1].size()}')
plt.imshow(train_data.data[0].numpy(), cmap='gray')
plt.title(f'Label: {train_data.targets[0]}')
plt.show()

? dataset: input shape: torch.Size([60000, 28, 28]), label shape: torch.Size([60000])
? dataloader iter: input shape: torch.Size([512, 1, 28, 28]), label shape: torch.Size([512])

外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳

網絡

設計簡單的 ConvNet, 幾層 CNN + MLP。初始化新模型后,先將其放到 DEVICE 上

class ConvNet(nn.Module):"""A neural network model for MNIST digit classification.This model is designed to classify images from the MNIST dataset, which consists of grayscale images of handwritten digits (0-9). The network architecture includes convolutional layers for feature extraction, followed by fully connected layers for classification.Attributes:features (nn.Sequential): A sequential container of convolutional layers, activation functions, pooling, and dropout for feature extraction.classifier (nn.Sequential): A sequential container of fully connected layers, activation functions, and dropout for classification.Methods:forward(x):Defines the forward pass of the network. Takes an input tensor `x`, processes it through the feature extractor and classifier, and returns the log-softmax probabilities for each class."""def __init__(self):super(ConvNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(1, 32, 3, 1),nn.ReLU(),nn.Conv2d(32, 64, 3, 1),nn.ReLU(),nn.MaxPool2d(2),nn.Dropout(0.25))self.classifier = nn.Sequential(nn.Linear(9216, 128),nn.ReLU(),nn.Dropout(0.5),nn.Linear(128, 10))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)output = F.log_softmax(x, dim=1)return output

訓練和評估函數

將訓練和評估函數分別封裝為函數,使主循環更簡潔

def train(model, device, train_loader, optimizer):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if (batch_idx + 1) % 30 == 0: print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item() # 將一批的損失相加pred = output.max(1, keepdim=True)[1] # 找到概率最大的下標correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

主訓練循環

model = ConvNet().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP_NUM, gamma=LR_DECAY_FACTOR)start_time = time()  # Record the start time
for epoch in range(EPOCHS):epoch_start_time = time()  # Record the start time of the current epochprint(f'Epoch {epoch}/{EPOCHS}')print(f'Learning Rate: {scheduler.get_last_lr()[0]}')train(model, DEVICE, train_loader, optimizer)test(model, DEVICE, test_loader)scheduler.step()epoch_end_time = time()  # Record the end time of the current epochprint(f"Time for epoch {epoch}: {epoch_end_time - epoch_start_time:.2f} seconds")end_time = time()  # Record the end time
print(f"Total training time: {end_time - start_time:.2f} seconds")
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   1795609      C   ...st/anaconda3/envs/xprepo/bin/python        448MiB |
|    0   N/A  N/A   1814253      C   ...st/anaconda3/envs/xprepo/bin/python       1036MiB |
|    7   N/A  N/A   4167010      C   ...guest/anaconda3/envs/QDM/bin/python      19416MiB |
+-----------------------------------------------------------------------------------------+

0 卡的占用 1484 MB

完整代碼

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms
from time import time
import argparseclass ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(1, 32, 3, 1),nn.ReLU(),nn.Conv2d(32, 64, 3, 1),nn.ReLU(),nn.MaxPool2d(2),nn.Dropout(0.25))self.classifier = nn.Sequential(nn.Linear(9216, 128),nn.ReLU(),nn.Dropout(0.5),nn.Linear(128, 10))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)output = F.log_softmax(x, dim=1)return outputdef arg_parser():parser = argparse.ArgumentParser(description="MNIST Training Script")parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training")parser.add_argument("--lr", type=float, default=0.0005, help="Learning rate")parser.add_argument("--lr_decay_step_num", type=int, default=1, help="Step size for learning rate decay")parser.add_argument("--lr_decay_factor", type=float, default=0.5, help="Factor by which learning rate is decayed")parser.add_argument("--cuda_id", type=int, default=0, help="CUDA device ID to use")return parser.parse_args()def prepare_data(batch_size):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])train_data = datasets.MNIST(root = './mnist',train=True,       # 設置True為訓練數據,False為測試數據transform = transform,# download=True  # 設置True后就自動下載,下載完成后改為False即可)train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)test_data = datasets.MNIST(root = './mnist',train=False,       # 設置True為訓練數據,False為測試數據transform = transform,)test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)return train_loader, test_loaderdef train(model, device, train_loader, optimizer):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if (batch_idx + 1) % 30 == 0: print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item() # 將一批的損失相加pred = output.max(1, keepdim=True)[1] # 找到概率最大的下標correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))def train_mnist_classification():args = arg_parser()print(args)EPOCHS = args.epochsBATCH_SIZE = args.batch_sizeLR = args.lrLR_DECAY_STEP_NUM = args.lr_decay_step_numLR_DECAY_FACTOR = args.lr_decay_factorCUDA_ID = args.cuda_idDEVICE = torch.device(f"cuda:{CUDA_ID}")train_loader, test_loader = prepare_data(BATCH_SIZE)model = ConvNet().to(DEVICE)optimizer = optim.Adam(model.parameters(), lr=LR)scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP_NUM, gamma=LR_DECAY_FACTOR)start_time = time()  # Record the start timefor epoch in range(EPOCHS):epoch_start_time = time()  # Record the start time of the current epochprint(f'Epoch {epoch}/{EPOCHS}')print(f'Learning Rate: {scheduler.get_last_lr()[0]}')train(model, DEVICE, train_loader, optimizer)test(model, DEVICE, test_loader)scheduler.step()epoch_end_time = time()  # Record the end time of the current epochprint(f"Time for epoch {epoch}: {epoch_end_time - epoch_start_time:.2f} seconds")end_time = time()  # Record the end timeprint(f"Total training time: {end_time - start_time:.2f} seconds")if __name__ == "__main__":train_mnist_classification()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/82568.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/82568.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/82568.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

數據庫中的 Segment、Extent、Page、Row 詳解

在關系型數據庫的底層存儲架構中,數據并不是隨意寫入磁盤,而是按照一定的結構分層管理的。理解這些存儲單位對于優化數據庫性能、理解 SQL 執行過程以及排查性能問題都具有重要意義。 我將從宏觀到微觀,依次介紹數據庫存儲中的四個核心概念&…

DAMA車輪圖

DAMA車輪圖是國際數據管理協會(DAMA International)提出的數據管理知識體系(DMBOK)的圖形化表示,它以車輪(同心圓)的形式展示了數據管理的核心領域及其相互關系。以下是基于用戶提供的關鍵詞對D…

《QDebug 2025年4月》

一、Qt Widgets 問題交流 1. 二、Qt Quick 問題交流 1.QML單例動態創建的對象,訪問外部id提示undefined 先定義一個窗口組件,打印外部的id: // MyWindow.qml import QtQuick 2.15 import QtQuick.Window 2.15Window {id: controlwidth: …

JS | 正則 · 常用正則表達式速查表

以下是前端開發中常用的正則表達式速查表,包含驗證規則、用途說明與示例: 📌 常用正則表達式速查表 名稱正則表達式描述 / 用途示例手機號/^1[3-9]\d{9}$/中國大陸手機號13812345678 ?座機號/^0\d{2,3}-?\d{7,8}$/固定電話010-12345678 ?…

系統思考:個人與團隊成長

四年前,我交付的系統思考項目,今天學員的反饋依然深深觸動了我。 我常常感嘆,系統思考不僅僅是一場培訓,更像是一場持續的“修煉”。在這條修煉之路上,最珍貴的,便是有志同道合的伙伴們一路同行&#xff0…

寫屏障和讀屏障的區別是什么?

寫屏障(Write Barrier)與讀屏障(Read Barrier)的區別 在計算機科學中,寫屏障和讀屏障是兩種關鍵的內存同步機制,主要用于解決并發編程中的可見性、有序性問題,或在垃圾回收(GC&…

ssh -T git@github.com 測試失敗解決方案:修改hosts文件

問題描述 通過SSH方式測試,使用該方法測試連接可能會遇到連接超時、端口占用的情況,原因是因為DNS配置及其解析的問題 ssh -T gitgithub.com我們可以詳細看看建立 ssh 連接的過程中發生了什么,可以使用 ssh -v命令,-v表示 verbo…

大疆無人機搭載樹莓派進行目標旋轉檢測

環境部署 首先是環境創建,創建虛擬環境,名字叫 pengxiang python -m venv pengxiang隨后激活環境 source pengxiang/bin/activate接下來便是依賴包安裝過程了: pip install onnxruntime #推理框架 pip install fastapi uvicorn[standard] #網絡請求…

00 Ansible簡介和安裝

1. Ansible概述與基本概念 1.1. 什么是Ansible? Ansible 是一款用 Python 編寫的開源 IT 自動化工具,主要用于配置管理、軟件部署及高級工作流編排。它能夠簡化應用程序部署、系統更新等操作,并且支持自動化管理大規模的計算機系統。Ansibl…

Linxu實驗五——NFS服務器

一.NFS服務器介紹 NFS服務器(Network File System)是一種基于網絡的分布式文件系統協議,允許不同操作系統的主機通過網絡共享文件和目錄3。其核心作用在于實現跨平臺的資源透明訪問,例如在Linux和Unix系統之間共享靜態數據&#…

『 測試 』測試基礎

文章目錄 1. 調試與測試的區別2. 開發過程中的需求3. 開發模型3.1 軟件的生命周期3.2 瀑布模型3.2.1 瀑布模型的特點/缺點 3.3 螺旋模型3.3.1 螺旋模型的特點/缺點 3.4 增量模型與迭代模型3.5 敏捷模型3.5.1 Scrum模型3.5.2 敏捷模型中的測試 4 測試模型4.1 V模型4.2 W模型(雙V…

紅外遙控鍵

紅外 本章節旨在讓用戶自定義紅外遙控功能,需要有板載紅外接收的板卡。 12.1. 獲取紅外遙控鍵值 由于不同遙控器廠家定義的按鍵鍵值不一樣,所以配置不通用,需要獲取實際按鍵對應的鍵值。 1 2 3 4 5 6 #設置輸出等級 echo 7 4 1 7> /pr…

同一個虛擬環境中conda和pip安裝的文件存儲位置解析

文章目錄 存儲位置的基本區別conda安裝的包pip安裝的包 看似相同實則不同的機制實際路徑示例這種差異帶來的問題如何檢查包安裝來源最佳實踐建議 總結 存儲位置的基本區別 conda安裝的包 存儲在Anaconda(或Miniconda)目錄下的pkgs和envs子目錄中: ~/anaconda3/en…

機器學習極簡入門:從基礎概念到行業應用

有監督學習(supervised learning) 讓模型學習的數據包含正確答案(標簽)的方法,最終模型可以對無標簽的數據進行正確處理和預測,可以分為分類與回歸兩大類 分類問題主要是為了“盡可能分開整個數據而畫線”…

split和join的區別?

split和join是Python中用于處理字符串的兩種方法,它們的主要區別在于功能和使用場景。? split()方法 ?split()方法用于將字符串按照指定的分隔符分割成多個子串,并返回這些子串組成的列表?。如果不指定分隔符,則默認分割所有的空白字符&am…

MySQL從入門到精通(二):Windows和Mac版本MySQL安裝教程

目錄 MySQL安裝流程 (一)、進入MySQL官網 (二)、點擊下載(Download) (三)、Windows和Mac版本下載 下載Windows版本 下載Mac版本 (四)、驗證并啟動MySQL …

LeetCode 解題思路 45(分割等和子集、最長有效括號)

解題思路: dp 數組的含義: 在數組中是否存在一個子集,其和為 i。遞推公式: dp[i] | dp[i - num]。dp 數組初始化: dp[0] true。遍歷順序: 從大到小去遍歷,從 i target 開始,直到 …

電影感戶外啞光人像自拍攝影Lr調色預設,手機濾鏡PS+Lightroom預設下載!

調色詳情 電影感戶外啞光人像自拍攝影 Lr 調色,是借助 Lightroom 軟件,針對戶外環境下拍攝的人像自拍進行后期處理。旨在模擬電影畫面的氛圍與質感,通過調色賦予照片獨特的藝術氣息。強調打造啞光效果,使畫面色彩不過于濃烈刺眼&a…

使用 NV?Ingest、Unstructured 和 Elasticsearch 處理非結構化數據

作者:來自 Elastic Ajay Krishnan Gopalan 了解如何使用 NV-Ingest、Unstructured Platform 和 Elasticsearch 為 RAG 應用構建可擴展的非結構化文檔數據管道。 Elasticsearch 原生集成了行業領先的生成式 AI 工具和提供商。查看我們的網絡研討會,了解如…

Android 13 使能user版本進recovery

在 debug 版本上,可以在關機狀態下,同時按 電源鍵 和 音量加鍵 進 recovery 。 user 版本上不行。 參考 使用 build 變體 debug 版本和 user 版本的差別之一就是 ro.debuggable 屬性不同。 順著這個思路追蹤,找到 bootable/recovery/reco…