怎么用pytorch訓練一個模型,并跑起來

MNIST 手寫數字識別

任務描述

MNIST 手寫數字識別是機器學習和計算機視覺領域的經典任務,其本質是解決 “從手寫數字圖像中自動識別出對應的數字(0-9)” 的問題,屬于單標簽圖像分類任務(每張圖像僅對應一個類別,即 0-9 中的一個數字)。

任務的核心定義:輸入與輸出

MNIST 任務的本質是建立 “手寫數字圖像” 到 “數字類別” 的映射關系,具體如下:
維度
| 具體 | 內容 |
|輸入|28×28 像素的灰度圖像(像素值范圍 0-255,0 代表黑色背景,255 代表白色前景),圖像內容是人類手寫的 0-9 中的某一個數字。
例如:一張 28×28 的圖像,像素分布呈現 “3” 的形狀,就是模型的輸入。|
|輸出 |一個 “類別標簽”,即從 10 個可能的類別(0、1、2、…、9)中選擇一個,作為輸入圖像對應的數字。
例如:輸入 “3” 的圖像,模型輸出 “類別 3”,即完成一次正確識別。 |
|目標|讓模型在 “未見的手寫數字圖像” 上,盡可能準確地輸出正確類別(通常用 “準確率” 衡量,即正確識別的圖像數 / 總圖像數)|

任務的核心挑戰

為什么需要 “機器學習模型”?如果只是簡單的 “看圖像認數字”,人類可以輕松完成,但讓計算機自動識別,需要解決多個關鍵挑戰 —— 這些挑戰也是 MNIST 成為經典任務的原因(它濃縮了計算機視覺的核心難題):
不同人書寫習慣差異極大:有人寫的 “4” 帶彎鉤,有人寫的 “7” 帶橫線,有人字體粗大,有人字體纖細;甚至同一個人不同時間寫的同一數字,筆畫粗細、傾斜角度也會不同。
例如:同樣是 “5”,可能是 “直筆 5”“圓筆 5”,也可能是傾斜 10° 或 20° 的 “5”—— 模型需要忽略這些 “風格差異”,抓住 “數字的本質特征”(如 “5 有一個上半圓 + 一個豎線”)。
圖像噪聲與干擾
手寫數字圖像可能存在噪聲:比如紙張上的污漬、書寫時的斷筆、掃描時的光線不均,這些都會影響像素分布。
例如:一張 “0” 的圖像,邊緣有一小塊污漬,模型需要判斷 “這是噪聲” 而不是 “0 的一部分”,避免誤判為 “6” 或 “8”。

特征的自動提取

人類認數字時,會自動關注 “關鍵特征”(如 “0 是圓形、1 是豎線、8 是兩個圓形疊加”),但計算機只能處理像素矩陣 —— 模型需要從 28×28=784 個像素值中,自動學習到這些抽象的 “數字特征”,而不是依賴人工定義(這也是深度學習優于傳統方法的核心)。

MNIST 數據集的背景

MNIST(Modified National Institute of Standards and Technology database)是由美國國家標準與技術研究院(NIST)整理的手寫數字數據集,后經修改(調整圖像大小、居中對齊)成為機器學習領域的 “基準數據集”,其規模和特點非常適合入門:
數據量適中:包含 70000 張圖像,其中 60000 張用于訓練(讓模型學習特征),10000 張用于測試(驗證模型泛化能力);
圖像規格統一:所有圖像都是 28×28 灰度圖,無需復雜的預處理(如尺寸縮放、顏色通道處理),降低入門門檻;
標注準確:每張圖像都有明確的 “正確數字標簽”(人工標注),無需額外標注成本。

任務的實際價值:解決這個問題有什么用?

MNIST 看似簡單,但它是很多實際場景的 “簡化版任務”,其解決思路可以遷移到更復雜的場景:
光學字符識別(OCR)的基礎
例如:銀行支票上的手寫數字識別(識別金額)、快遞單上的手寫郵編識別、試卷批改中的選擇題填涂識別 —— 這些場景本質都是 “手寫字符分類”,MNIST 的技術思路(如卷積神經網絡、全連接網絡)可以直接復用或擴展。
機器學習模型的 “基準測試”
新提出的模型(如早期的 LeNet-5、后來的 ResNet 輕量版)會先在 MNIST 上測試性能:如果在簡單的 MNIST 上都表現差,說明模型設計有問題;如果在 MNIST 上表現好,再遷移到更復雜的任務(如人臉識別、醫學圖像分類),這能大幅降低研發成本。
入門教學的 “絕佳案例”
MNIST 任務足夠簡單(數據量小、目標明確),但又能覆蓋機器學習的完整流程(數據準備、模型定義、訓練、評估、泛化性驗證),因此成為初學者理解 “如何用代碼實現一個完整機器學習任務” 的最佳載體(就像編程入門的 “Hello World”)。
總結
MNIST 手寫數字識別的核心是 “讓計算機從標準化的手寫數字灰度圖中,自動識別出對應的 0-9 數字”,它看似基礎,卻濃縮了圖像分類的核心挑戰(風格多樣性、噪聲魯棒性、特征自動提取),同時是實際 OCR 場景的技術基礎和機器學習入門的經典案例。

代碼

下面我將為你提供一個使用 PyTorch 訓練模型的完整示例,包括數據準備、模型定義、訓練和測試的全過程。我們將使用一個簡單的神經網絡來解決 MNIST 手寫數字識別問題。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 設置隨機種子,確保結果可復現
torch.manual_seed(42)# 1. 數據準備
# 定義數據變換
transform = transforms.Compose([transforms.ToTensor(),  # 轉換為Tensortransforms.Normalize((0.1307,), (0.3081,))  # 標準化,MNIST數據集的均值和標準差
])# 加載MNIST數據集
train_dataset = datasets.MNIST(root='./data',  # 數據保存路徑train=True,     # 訓練集download=True,  # 如果數據不存在則下載transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,    # 測試集download=True,transform=transform
)# 創建數據加載器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 2. 定義模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()# 輸入層到隱藏層self.fc1 = nn.Linear(28*28, 128)  # MNIST圖像大小為28x28# 隱藏層到輸出層self.fc2 = nn.Linear(128, 10)     # 10個類別(0-9)def forward(self, x):# 將圖像展平為一維向量x = x.view(-1, 28*28)# 隱藏層,使用ReLU激活函數x = torch.relu(self.fc1(x))# 輸出層,不使用激活函數(因為后面會用CrossEntropyLoss)x = self.fc2(x)return x# 3. 初始化模型、損失函數和優化器
model = SimpleNN()
criterion = nn.CrossEntropyLoss()  # 交叉熵損失,適用于分類問題
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam優化器# 4. 訓練模型
def train(model, train_loader, criterion, optimizer, epochs=5):model.train()  # 設置為訓練模式train_losses = []for epoch in range(epochs):running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):# 清零梯度optimizer.zero_grad()# 前向傳播outputs = model(data)loss = criterion(outputs, target)# 反向傳播和優化loss.backward()optimizer.step()running_loss += loss.item()# 每100個批次打印一次信息if batch_idx % 100 == 99:print(f'Epoch [{epoch+1}/{epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}')running_loss = 0.0train_losses.append(running_loss / len(train_loader))return train_losses# 5. 測試模型
def test(model, test_loader):model.eval()  # 設置為評估模式correct = 0total = 0# 不計算梯度,節省內存和計算時間with torch.no_grad():for data, target in test_loader:outputs = model(data)_, predicted = torch.max(outputs.data, 1)total += target.size(0)correct += (predicted == target).sum().item()accuracy = 100 * correct / totalprint(f'Test Accuracy: {accuracy:.2f}%')return accuracy# 6. 運行訓練和測試
if __name__ == '__main__':# 訓練模型print("開始訓練模型...")train_losses = train(model, train_loader, criterion, optimizer, epochs=5)# 測試模型print("開始測試模型...")test_accuracy = test(model, test_loader)# 保存模型torch.save(model.state_dict(), 'mnist_model.pth')print("模型已保存為 mnist_model.pth")# 繪制訓練損失曲線plt.plot(train_losses)plt.title('Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.show()# 可視化一些預測結果model.eval()with torch.no_grad():# 獲取一些測試數據data, target = next(iter(test_loader))outputs = model(data)_, predicted = torch.max(outputs, 1)# 顯示前5個樣本fig, axes = plt.subplots(1, 5, figsize=(15, 3))for i in range(5):axes[i].imshow(data[i].numpy().squeeze(), cmap='gray')axes[i].set_title(f'預測: {predicted[i]}, 實際: {target[i]}')axes[i].axis('off')plt.show()

代碼解釋

上面的代碼實現了一個完整的 PyTorch 模型訓練流程,主要包含以下幾個部分:

  1. 數據準備:
    ? 使用torchvision.datasets加載 MNIST 數據集
    ? 對數據進行轉換(轉為 Tensor 并標準化)
    ? 使用DataLoader創建可迭代的數據加載器
  2. 模型定義:
    ? 定義了一個簡單的兩層神經網絡SimpleNN
    ? 第一層將 28x28 的圖像展平后映射到 128 維
    ? 第二層將 128 維特征映射到 10 個類別(對應數字 0-9)
  3. 訓練設置:
    ? 使用交叉熵損失函數(CrossEntropyLoss)
    ? 使用 Adam 優化器
    ? 設置批量大小為 64,訓練輪次為 5
  4. 訓練過程:
    ? 循環多個訓練輪次(epoch)
    ? 每個輪次中迭代所有批次數據
    ? 執行前向傳播、計算損失、反向傳播和參數更新
  5. 測試評估:
    ? 在測試集上評估模型性能
    ? 計算并打印準確率
  6. 結果可視化:
    ? 繪制訓練損失曲線
    ? 展示部分測試樣本的預測結果
    運行后,程序會自動下載 MNIST 數據集(首次運行),然后開始訓練模型。訓練完成后,會打印測試準確率,保存模型,并顯示損失曲線和部分預測結果。
    這個示例比較基礎,你可以根據需要調整模型結構、超參數(如學習率、批量大小、訓練輪次等)來獲得更好的性能。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/96494.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/96494.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/96494.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Qt應用程序發布方式

解決的問題:在自己電腦上用QT Creator編譯的exe文件放到其他電腦上不能正常打開的問題。1、拷貝已經編譯好的exe應用程序到桌面文件夾。桌面新建文件夾WindowsTest,并且將編譯好的軟件WindowTest.exe放入此文件夾中。2、在此文件夾空白處按住Shift再點擊…

Linux 軟件編程(九)網絡編程:IP、端口與 UDP 套接字

1. 學習目的實現 不同主機之間的進程間通信。在 Linux 下,進程間通信(IPC)不僅可以發生在同一臺主機上,也可以通過網絡實現不同主機之間的通信。要做到這一點,必須同時滿足以下兩個條件:物理層面&#xff1…

5.Kotlin作用于函數let、run、with、apply、also

選擇建議 需要返回值:使用 let、run 或 with配置對象:使用 apply附加操作:使用 also非空檢查:使用 let鏈式調用:使用 let 或 run Kotlin作用域函數詳解 概述 Kotlin提供了5個作用域函數:let、run、with、ap…

嵌入式學習日記(32)Linux下的網絡編程

1. 目的不同主機,進程間通信。2. 解決的問題1). 主機與主機之間物理層面必須互聯互通。2.) 進程與進程在軟件層面必須互聯互通。IP地址:計算機的軟件地址,用來標識計算機設備 MAC地址:計算機的硬件地址&…

C#_接口設計:角色與契約的分離

2.3 接口設計:角色與契約的分離 在軟件架構中,接口(Interface)遠不止是一種語言結構。它是一份契約(Contract),明確規定了實現者必須提供的能力,以及使用者可以依賴的服務。優秀的接…

vsCode或Cursor 使用remote-ssh插件鏈接遠程終端

一、Remote-SSH介紹Remote-SSH 是 VS Code 官方提供的一個擴展插件,允許開發者通過 SSH 協議連接到遠程服務器,并在本地編輯器中直接操作遠程文件,實現遠程開發。它將本地編輯器的功能(如語法高亮、智能提示、調試等)與…

C語言實戰:從零開始編寫一個通用配置文件解析器

資料合集下載鏈接: ?https://pan.quark.cn/s/472bbdfcd014? 在軟件開發中,我們經常需要將一些可變的參數(如數據庫地址、端口號、游戲角色屬性等)與代碼本身分離,方便日后修改而無需重新編譯整個程序。這種存儲配置信息的文件,我們稱之為配置文件。 一、 什么是配置…

車機兩分屏運行Unity制作的效果

目錄 效果概述 實現原理 完整實現代碼 實際車機集成注意事項 1. 顯示系統集成 多屏顯示API調用 代碼示例(AAOS副駕屏顯示) 2. 性能優化 GPU Instancing 其他優化技術 3. 輸入處理 觸控處理 物理按鍵處理 4. 安全規范 駕駛員側限制 乘客側…

vivo“空間計算-機器人”生態落下關鍵一子

出品 | 何璽排版 | 葉媛不出所料,vivo Vision熱度很高。從21號下午發布到今天(22號),大眾圍繞vivo Vision探索版展開了多方面的討論,十分熱烈。從討論來看,大家現在的共識是,MR行業目前還處于起…

Azure TTS Importer:一鍵導入,將微軟TTS語音接入你的閱讀軟件!

Azure TTS Importer:一鍵導入,將微軟TTS語音接入你的閱讀軟件! 文章來源:Poixe AI 厭倦了機械、生硬的文本朗讀?想讓你的閱讀軟件擁有自然流暢的AI語音?今天,我們將為您介紹一款強大且安全的開…

用過redis哪些數據類型?Redis String 類型的底層實現是什么?

Redis 數據類型有哪些? 詳細可以查看:數據類型及其應用場景 基本數據類型: String:最常用的一種數據類型,String類型的值可以是字符串、數字或者二進制,但值最大不能超過512MB。一般用于 緩存和計數器 Ha…

大視協作碼垛機:顛覆傳統制造,開啟智能工廠新紀元

在東三省某食品廠的深夜生產線上,碼垛作業正有序進行,卻不見人影——這不是魔法,而是大視協作碼垛機器人帶來的現實變革。在工業4.0浪潮席卷全球的今天,智能制造已成為企業生存與發展的必由之路。智能碼垛環節作為產線的關鍵步驟&…

c# 保姆級分析繼承詳見問題 父類有一個列表對象,子類繼承這個列表對象并對其進行修改后,將子類對象賦值給父類對象,父類對象是否能包含子類新增的內容?

文章目錄 深入解析:父類與子類列表繼承關系的終極指南 一、問題背景:從實際開發困惑說起 二、基礎知識回顧:必備概念理解 2.1 繼承的本質 2.2 引用類型 vs 值類型 2.3 多態的實現方式 三、核心問題分析:列表繼承場景 3.1 基礎代碼示例 3.2 關鍵問題分解 3.3 結論驗證 四、深…

tensorflow-gpu 2.7下的tensorboard與profiler插件版本問題

可行版本: python3.9.23cuda12.0tensorflow-gpu2.7.0tensorboard2.20.0 tensorboard-plugin-profile 2.4.0 問題描述: 1. 安裝tensorboard后運行tensorboard --logdirlogs在網頁中打開,發現profile模塊無法顯示,報錯如下&#x…

數據結構青銅到王者第一話---數據結構基本常識(1)

目錄 一、集合框架 1、什么是集合框架 2、集合框架的重要性 2.1開發中的使用 2.2筆試及面試題 3、背后涉及的數據結構以及算法 3.1什么是數據結構 3.2容器背后對應的數據結構 3.3相關java知識 3.4什么是算法 3.5如何學好數據結構以及算法 二、時間和空間復雜度 1、…

【Verilog】延時和時序檢查

Verilog中延時和時序檢查1. 延時模型1.1 分布延遲1.2 集總延遲1.3 路徑延遲2. specify 語法2.1 指定路徑延時基本路徑延時邊沿敏感路徑延時狀態依賴路徑延時2.2 時序檢查$setup, $hold, $setuphold$recovery, $removal, $recrem$width, $periodnotifier1. 延時模型 真實的邏輯元…

DigitalOcean Gradient AI平臺現已支持OpenAI gpt-oss

OpenAI 的首批開源 GPT 模型(200 億和 1200 億參數)現已登陸 Gradient AI 平臺。此次發布讓開發者在構建 AI 應用時擁有更高的靈活度和更多選擇,無論是快速原型還是大規模生產級智能體,都能輕松上手。新特性開源 GPT 模型&#xf…

藏在 K8s 幕后的記憶中樞(etcd)

目錄1)etcd 基本架構2)etcd 的讀寫流程總覽a)一個讀流程b)一個寫流程3)k8s存儲數據過程源碼解讀4)watch 機制Informer 機制etcd watch機制etcd的watchableStore源碼解讀5) k8s大規模集群時會存在…

騰訊云EdgeOne安全防護:快速上手,全面抵御Web攻擊

為什么需要專業的安全防護? 在當今數字化時代,網站面臨的安全威脅日益增多。據統計,2023年全球Web應用程序攻擊超7千億次,持續快速增長。 其中最常見的包括: DDoS攻擊:通過海量請求使服務器癱瘓Web應用攻…

SpringBoot中的條件注解

文章目錄前言什么是條件注解核心原理常用條件注解詳解1. ConditionalOnClass和ConditionalOnMissingClass2. ConditionalOnBean和ConditionalOnMissingBean3. ConditionalOnProperty應用場景:多數據源配置在SpringBoot自動配置中的核心作用自動配置的工作原理經典自…