一、任務描述
????????從手寫數字圖像中自動識別出對應的數字(0-9)” 的問題,屬于單標簽圖像分類任務(每張圖像僅對應一個類別,即 0-9 中的一個數字)
? ? ? ? 1、任務的核心定義:輸入與輸出
- 輸入:28×28 像素的灰度圖像(像素值范圍 0-255,0 代表黑色背景,255 代表白色前景),圖像內容是人類手寫的 0-9 中的某一個數字,例如:一張 28×28 的圖像,像素分布呈現 “3” 的形狀,就是模型的輸入。
- 輸出:一個 “類別標簽”,即從 10 個可能的類別(0、1、2、…、9)中選擇一個,作為輸入圖像對應的數字,例如:輸入 “3” 的圖像,模型輸出 “類別 3”,即完成一次正確識別。
- 目標:讓模型在 “未見的手寫數字圖像” 上,盡可能準確地輸出正確類別(通常用 “準確率” 衡量,即正確識別的圖像數 / 總圖像數)
? ? ? ? 2、任務的核心挑戰
- 不同人書寫習慣差異極大:有人寫的 “4” 帶彎鉤,有人寫的 “7” 帶橫線,有人字體粗大,有人字體纖細;甚至同一個人不同時間寫的同一數字,筆畫粗細、傾斜角度也會不同。例如:同樣是 “5”,可能是 “直筆 5”“圓筆 5”,也可能是傾斜 10° 或 20° 的 “5”—— 模型需要忽略這些 “風格差異”,抓住 “數字的本質特征”(如 “5 有一個上半圓 + 一個豎線”)。
- 圖像噪聲與干擾:手寫數字圖像可能存在噪聲,比如紙張上的污漬、書寫時的斷筆、掃描時的光線不均,這些都會影響像素分布。例如:一張 “0” 的圖像,邊緣有一小塊污漬,模型需要判斷 “這是噪聲” 而不是 “0 的一部分”,避免誤判為 “6” 或 “8”。
二、模型訓練
? ? ? ?1、MNIST數據集
????????MNIST(Modified National Institute of Standards and Technology database)是由美國國家標準與技術研究院(NIST)整理的手寫數字數據集,后經修改(調整圖像大小、居中對齊)成為機器學習領域的 “基準數據集”,MNIST手寫數字識別的核心是 “讓計算機從標準化的手寫數字灰度圖中,自動識別出對應的 0-9 數字”,它看似基礎,卻濃縮了圖像分類的核心挑戰(風格多樣性、噪聲魯棒性、特征自動提取),同時是實際 OCR 場景的技術基礎和機器學習入門的經典案例。
- 數據量適中:包含 70000 張圖像,其中 60000 張用于訓練(讓模型學習特征),10000 張用于測試(驗證模型泛化能力);
- 圖像規格統一:所有圖像都是 28×28 灰度圖,無需復雜的預處理(如尺寸縮放、顏色通道處理),降低入門門檻;
- 標注準確:每張圖像都有明確的 “正確數字標簽”(人工標注),無需額外標注成本。
? ? ? ? 2、代碼
- 數據準備:使用torchvision.datasets加載 MNIST 數據集,對數據進行轉換(轉為 Tensor 并標準化),使用DataLoader創建可迭代的數據加載器;
- 模型定義:定義了一個簡單的兩層神經網絡SimpleNN,第一層將 28x28 的圖像展平后映射到 128 維,第二層將 128 維特征映射到 10 個類別(對應數字 0-9);
- 訓練設置:使用交叉熵損失函數(CrossEntropyLoss),使用 Adam 優化器,設置批量大小為64,訓練輪次為5;
- 訓練過程:循環多個訓練輪次(epoch),每個輪次中迭代所有批次數據,執行前向傳播、計算損失、反向傳播和參數更新。
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 設置隨機種子,確保結果可復現
torch.manual_seed(42)# 1. 數據準備
# 定義數據變換
transform = transforms.Compose([transforms.ToTensor(), # 轉換為Tensortransforms.Normalize((0.1307,), (0.3081,)) # 標準化,MNIST數據集的均值和標準差
])# 加載MNIST數據集
train_dataset = datasets.MNIST(root='./data', # 數據保存路徑train=True, # 訓練集download=True, # 如果數據不存在則下載transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False, # 測試集download=True,transform=transform
)# 創建數據加載器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 2. 定義模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()# 輸入層到隱藏層self.fc1 = nn.Linear(28 * 28, 128) # MNIST圖像大小為28x28# 隱藏層到輸出層self.fc2 = nn.Linear(128, 10) # 10個類別(0-9)def forward(self, x):# 將圖像展平為一維向量x = x.view(-1, 28 * 28)# 隱藏層,使用ReLU激活函數x = torch.relu(self.fc1(x))# 輸出層,不使用激活函數(因為后面會用CrossEntropyLoss)x = self.fc2(x)return x# 3. 初始化模型、損失函數和優化器
model = SimpleNN()
criterion = nn.CrossEntropyLoss() # 交叉熵損失,適用于分類問題
optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam優化器# 4. 訓練模型
def train(model, train_loader, criterion, optimizer, epochs=5):model.train() # 設置為訓練模式train_losses = []for epoch in range(epochs):running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):# 清零梯度optimizer.zero_grad()# 前向傳播outputs = model(data)loss = criterion(outputs, target)# 反向傳播和優化loss.backward()optimizer.step()running_loss += loss.item()# 每100個批次打印一次信息if batch_idx % 100 == 99:print(f'Epoch [{epoch + 1}/{epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')running_loss = 0.0train_losses.append(running_loss / len(train_loader))return train_losses# 6. 運行訓練和測試
if __name__ == '__main__':# 訓練模型print("開始訓練模型...")train_losses = train(model, train_loader, criterion, optimizer, epochs=5)print("模型訓練完成...")# 保存模型torch.save(model.state_dict(), 'mnist_model.pth')print("模型已保存為 mnist_model.pth")
三、模型使用測試
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms # 修正transforms的導入方式# 定義與訓練時相同的模型結構
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28*28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28*28)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 加載模型
def load_model(model_path='mnist_model.pth'):model = SimpleNN()# 加載模型時添加參數以避免潛在的Python 3兼容性問題model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))model.eval() # 設置為評估模式return model# 圖像預處理(與訓練時保持一致)
def preprocess_image(image_path):# 打開圖像并轉換為灰度圖img = Image.open(image_path).convert('L') # 'L'表示灰度模式# 調整大小為28x28img = img.resize((28, 28))# 轉換為numpy數組并歸一化img_array = np.array(img) / 255.0# 定義圖像轉換(使用torchvision的transforms)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])# 注意:這里需要先將numpy數組轉換為PIL圖像再應用transformimg_pil = Image.fromarray((img_array * 255).astype(np.uint8))img_tensor = transform(img_pil).unsqueeze(0) # 增加批次維度return img_tensor# 預測函數
def predict_digit(model, image_path):# 預處理圖像img_tensor = preprocess_image(image_path)# 預測with torch.no_grad(): # 不計算梯度outputs = model(img_tensor)_, predicted = torch.max(outputs.data, 1)return predicted.item() # 返回預測的數字# 示例使用
if __name__ == '__main__':# 加載模型model = load_model('mnist_model.pth')# 預測示例圖像test_image_path = 'test_digit.png' # 用戶需要提供的測試圖像路徑try:predicted_digit = predict_digit(model, test_image_path)print(f"預測的數字是: {predicted_digit}")except Exception as e:print(f"預測出錯: {str(e)}")
使用gpu0(第一塊gpu)進行訓練/推理:
????????torch.cuda.set_device(0) ???
????????model = model.cuda(0)
使用cpu記性訓練/推理:
????????model = model.cpu()
怎么用pytorch訓練一個模型-手寫數字識別
手把手教你如何跑通一個手寫中文漢字識別模型-OCR識別【pytorch】
手把手教你用PyTorch從零訓練自己的大模型(非常詳細)零基礎入門到精通,收藏這一篇就夠了
揭秘大模型的訓練方法:使用PyTorch進行超大規模深度學習模型訓練
全套解決方案:基于pytorch、transformers的中文NLP訓練框架,支持大模型訓練和文本生成,快速上手,海量訓練數據!
用 pytorch 從零開始創建大語言模型(三):編碼注意力機制
YOLOv5源碼逐行超詳細注釋與解讀(1)——項目目錄結構解析