PyTorch多層感知機(MLP)模型構建與MNIST分類訓練

沖沖沖😊
here😊

文章目錄

  • PyTorch多層感知機模型構建與MNIST分類訓練筆記
    • 🎯 1. 任務概述
    • ?? 2. 環境設置
      • 2.1 導入必要庫
      • 2.2 GPU配置
    • 🧠 3. 模型構建
      • 3.1 模型定義關鍵點
      • 3.2 損失函數選擇
      • 3.3 模型初始化與設備選擇
    • 🔧 4. 優化器配置
      • 4.1 隨機梯度下降優化器
    • 🔄 5. 訓練循環實現
      • 5.1 訓練函數設計
      • 5.2 測試函數設計
    • 📦 6. 數據準備
      • 6.1 加載MNIST數據集
    • 🚀 7. 訓練執行
      • 7.1 訓練循環主體
      • 7.2 訓練過程輸出(部分)
    • 📊 8. 結果可視化
      • 8.1 損失曲線繪制
      • 8.2 準確率曲線繪制

PyTorch多層感知機模型構建與MNIST分類訓練筆記

🎯 1. 任務概述

解決MNIST手寫數字分類問題,創建一個簡單的多層感知機(MLP)模型

  • 使用torch.nn.Linear層構建模型
  • 使用ReLU作為激活函數
  • 包含兩個全連接隱藏層(120和84個神經元)和輸出層(10個神經元對應10個數字類別)
  • 模型輸入為展平后的28×28=784像素圖像

?? 2. 環境設置

2.1 導入必要庫

import torch
from torch import nn
import os

2.2 GPU配置

# os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,6"  # 只使用空閑的GPU

🧠 3. 模型構建

3.1 模型定義關鍵點

class Model(nn.Module):def __init__(self):super().__init__()# 第一層輸入展平后的特征長度28乘28,創建120個神經元self.liner_1 = nn.Linear(28*28, 120)# 第二層輸入的是前一層的輸出,創建84個神經元self.liner_2 = nn.Linear(120, 84)# 輸出層接受第二層的輸入84,輸出分類個數10self.liner_3 = nn.Linear(84, 10)def forward(self, input):x = input.view(-1, 28*28)  # 將輸入展平為二維(1,28,28)->(28*28)x = torch.relu(self.liner_1(x))x = torch.relu(self.liner_2(x))x = self.liner_3(x)return x

📝 模型結構說明

  1. 輸入層:將28×28圖像展平為784維向量
  2. 隱藏層1:120個神經元,使用ReLU激活
  3. 隱藏層2:84個神經元,使用ReLU激活
  4. 輸出層:10個神經元對應10個數字類別

3.2 損失函數選擇

loss_fn = nn.CrossEntropyLoss()  # 交叉熵損失函數
'''
注意兩個參數
1. weight: 各類別的權重(處理不平衡數據集)
2. ignore_index: 忽略特定類別的索引
另外,它要求實際類別為數值編碼,而不是獨熱編碼
'''

🔍 為什么選擇交叉熵損失?

  • 適用于多分類問題
  • 內部集成了Softmax計算,簡化實現流程
  • 對錯誤分類有較強的懲罰

3.3 模型初始化與設備選擇

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)
# print(device)  # 可選:打印使用的設備

💡 GPU加速提示
使用.to(device)將模型移動到GPU可顯著加快訓練速度,特別是對于大模型和大數據集

🔧 4. 優化器配置

4.1 隨機梯度下降優化器

optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

🔧 關鍵參數解析

  • params: 需要優化的模型參數(通常為model.parameters()
  • lr=0.005: 學習率,控制參數更新步長的超參數
  • 其他可選參數:momentum(動量),weight_decay(L2正則化)

🔄 5. 訓練循環實現

5.1 訓練函數設計

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 獲取當前數據集樣本總數量num_batches = len(dataloader)   # 獲取當前data loader總批次數# train_loss用于累計所有批次的損失之和, correct用于累計預測正確的樣本總數train_loss, correct = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)# 進行預測,并計算當前批次的損失pred = model(X)loss = loss_fn(pred, y)# 利用反向傳播算法,根據損失優化模型參數optimizer.zero_grad()   # 先將梯度清零loss.backward()          # 損失反向傳播,計算模型參數梯度optimizer.step()         # 根據梯度優化參數with torch.no_grad():# correct用于累計預測正確的樣本總數correct += (pred.argmax(1) == y).type(torch.float).sum().item()# train_loss用于累計所有批次的損失之和train_loss += loss.item()# train_loss 是所有批次的損失之和,所以計算全部樣本的平均損失時需要除以總的批次數train_loss /= num_batches# correct 是預測正確的樣本總數,若計算整個epoch總體正確率,需要除以樣本總數量correct /= sizereturn train_loss, correct

5.2 測試函數設計

def test(dataloader, model):size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizereturn test_loss, correct

📊 數據加載器相關方法區別

方法返回內容適用場景
len(dataset)數據集總樣本數(如100)數據統計、劃分
len(dataloader)總批次數(如4)訓練循環控制
len(dataloader.dataset)等同于 len(dataset)需要訪問原始數據時

📦 6. 數據準備

6.1 加載MNIST數據集

import torchvision
from torchvision.transforms import ToTensortrain_ds = torchvision.datasets.MNIST("data/", train=True, transform=ToTensor(), download=True)
test_ds = torchvision.datasets.MNIST("data/", train=False, transform=ToTensor(), download=True)train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=64)

🚀 7. 訓練執行

7.1 訓練循環主體

# 對全部的數據集訓練50個epoch(一個epoch表示對全部數據訓練一遍)
epochs = 50 
train_loss, train_acc = [], []
test_loss, test_acc = [], []for epoch in range(epochs):# 調用train()函數訓練epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)# 調用test()函數測試epoch_test_loss, epoch_test_acc = test(test_dl, model)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)# 定義一個打印模板template = ("epoch:{:2d},train_loss:{:.6f},train_acc:{:.1f}%,""test_loss:{:.5f},test_acc:{:.1f}%")print(template.format(epoch, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))print("Done")

7.2 訓練過程輸出(部分)

epoch: 0,train_loss:2.157364,train_acc:46.7%,test_loss:1.83506,test_acc:63.7%
epoch: 1,train_loss:1.222660,train_acc:74.3%,test_loss:0.74291,test_acc:81.8%
epoch: 2,train_loss:0.612381,train_acc:84.0%,test_loss:0.49773,test_acc:86.3%
...
epoch:48,train_loss:0.110716,train_acc:96.9%,test_loss:0.12003,test_acc:96.4%
epoch:49,train_loss:0.108877,train_acc:97.0%,test_loss:0.11783,test_acc:96.5%
Done

📈 訓練趨勢分析

  • 初始準確率:46.7%(訓練集),63.7%(測試集)
  • 最終準確率:97.0%(訓練集),96.5%(測試集)
  • 過擬合現象輕微:訓練集和測試集性能差距僅0.5%

📊 8. 結果可視化

8.1 損失曲線繪制

import matplotlib.pyplot as pltplt.plot(range(1, epochs+1), train_loss, label="train_loss")
plt.plot(range(1, epochs+1), test_loss, label="test_loss", ls="--")
plt.xlabel("epoch")
plt.legend()
plt.show()

注釋:損失曲線顯示訓練初期損失快速下降,后期趨于平穩

8.2 準確率曲線繪制

plt.plot(range(1, epochs+1), train_acc, label="train_acc")
plt.plot(range(1, epochs+1), test_acc, label="test_acc")
plt.xlabel("epoch")
plt.legend()
plt.show()

注釋:準確率曲線穩步上升,最終達到96.5%的測試準確率

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/91058.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/91058.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/91058.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

android tabLayout 切換fragment fragment生命周期

1、TabLayout 與 Fragment 結合使用的常見方式 通常會使用 FragmentPagerAdapter 或 FragmentStatePagerAdapter 與 ViewPager 配合,再將 TabLayout 與 ViewPager 關聯,實現通過 TabLayout 切換 Fragment。 以下是布局文件示例 activity_main.xml: <LinearLayout xmln…

馬蹄集 BD202401補給

可怕的戰爭發生了&#xff0c;小度作為后勤保障工作人員&#xff0c;也要為了保衛國家而努力。現在有 N(1≤N≤)個堡壘需要補給&#xff0c;然而總的預算 B(1≤B≤)是有限的。現在已知第 i 個堡壘需要價值 P(i) 的補給&#xff0c;并且需要 S(i) 的運費。 鑒于小度與供應商之間…

《Llava:Visual Instruction Tuning》論文精讀筆記

論文鏈接&#xff1a;arxiv.org/pdf/2304.08485 參考視頻&#xff1a;LLAVA講解_嗶哩嗶哩_bilibili [論文速覽]LLaVA: Visual Instruction Tuning[2304.08485]_嗶哩嗶哩_bilibili 標題&#xff1a;Visual Instruction Tuning 視覺指令微調 背景引言 大模型的Instruction…

【DataWhale】快樂學習大模型 | 202507,Task01筆記

引言 我從2016年開始接觸matlab看別人做語音識別&#xff0c;再接觸tensorflow的神經網絡&#xff0c;2017年接觸語音合成&#xff0c;2020年做落地的醫院手寫數字識別。到2020年接觸pytorch做了計算機視覺圖像分類&#xff0c;到2021年做了目標檢測&#xff0c;2022年做了文本…

機器學習中的樸素貝葉斯(Naive Bayes)模型

1. 用實例來理解樸素貝葉斯 下面用具體的數據來演示垃圾郵件 vs 正常郵件的概率計算假設我們有一個小型郵件數據集郵件內容類別&#xff08;垃圾/正常&#xff09;“免費 贏取 大獎”垃圾“免費 參加會議”正常“中獎 點擊 鏈接”垃圾“明天 開會”正常“贏取 免費 禮品”垃圾 …

document.documentElement詳解

核心概念定義 它始終指向當前文檔的根元素&#xff0c;在 HTML 文檔中對應 <html> 標簽。與 document.body&#xff08;對應 <body>&#xff09;和 document.head&#xff08;對應 <head>&#xff09;形成層級關系。與 document.body 的區別 <html> &l…

c#進階之數據結構(動態數組篇)----Queue

1、簡介這個是c#封裝的隊列類型&#xff0c;同棧相反&#xff0c;這個是先進先出&#xff0c;一般用于事件注冊&#xff0c;或者數據的按順序處理&#xff0c;理解為需要排隊處理的可以用隊列來處理。注意&#xff0c;隊列一定是有順序的&#xff0c;先進確實是會先出&#xff…

使用 keytool 在服務器上導入證書操作指南(SSL 證書驗證錯誤處理)

使用 keytool 在服務器上導入證書操作指南(SSL 證書驗證錯誤處理) 一、概述 本文檔用于指導如何在運行 Java 應用程序的服務器上,通過keytool工具將證書導入 Java 信任庫,解決因證書未被信任導致的 SSL/TLS 通信問題(如PKIX path building failed錯誤)。 二、操作步驟…

VUE export import

目錄 命名導出 導出變量 導出函數 總結 默認導出 導出變量 導出函數 總結 因為總是搞不懂export和Import什么時候需要加{}&#xff0c;什么時候不用&#xff0c;所以自己測試了一下&#xff0c;以下是總結。 需不需要加{}取決于命名導出還是默認導出&#xff0c;命名導…

端側寵物識別+拍攝控制智能化:解決設備識別頻次識別率雙低問題

隨著寵物成為家庭重要成員&#xff0c;寵物影像創作需求激增&#xff0c;傳統相機系統 “人臉優先” 的調度邏輯已難以應對寵物拍攝的復雜場景。毛發邊緣模糊、動態姿態多變、光照反差劇烈等問題&#xff0c;推動著智能拍攝技術向 “寵物優先” 范式轉型。本文基于端側 AI 部署…

Popover API 實戰指南:前端彈層體驗的原生重構

&#x1fa84; Popover API 實戰指南&#xff1a;前端彈層體驗的原生重構 還在用 position: absolute JS 定位做 tooltip&#xff1f;還在引入大型 UI 庫只為做個浮層&#xff1f;現在瀏覽器已經支持了真正原生的「彈出層 API」&#xff0c;一行 HTMLCSS 就能構建可交互、無障…

CCS-MSPM0G3507-6-模塊篇-OLED的移植

前言基礎篇結束&#xff0c;接下來我們來開始進行模塊驅動如果懂把江科大的OLED移植成HAL庫&#xff0c;那其實也沒什么難首先配置OLED的引腳這里我配置PA16和17為推挽輸出&#xff0c;PA0和1不要用&#xff0c;因為只有那兩個引腳能使用MPU6050 根據配置出來的引腳&#xff0c…

意識邊界的算法戰爭—腦機接口技術重構人類認知的顛覆性挑戰

一、神經解碼的技術奇點當癱瘓患者通過腦電波操控機械臂飲水&#xff0c;當失語者借由皮層電極合成語音&#xff0c;腦機接口&#xff08;BCI&#xff09;正從醫療輔助工具演變為認知增強的潘多拉魔盒。這場革命的核心突破在于神經信號解析精度的指數躍遷&#xff1a;傳統腦電圖…

詳解彩信 SMIL規范

以下內容將系統地講解彩信 MMS&#xff08;Multimedia Messaging Service&#xff09;中使用的 SMIL&#xff08;Synchronized Multimedia Integration Language&#xff09;規范&#xff0c;涵蓋歷史、語法結構、在彩信中的裁剪與擴展、常見實現細節以及最佳實踐。末尾附示例代…

《紅藍攻防:構建實戰化網絡安全防御體系》

《紅藍攻防&#xff1a;構建實戰化網絡安全防御體系》文章目錄第一部分&#xff1a;網絡安全的攻防全景 1、攻防演練的基礎——紅隊、藍隊、紫隊 1.1 紅隊&#xff08;攻擊方&#xff09; 1.2 藍隊&#xff08;防守方&#xff09; 1.3 紫隊&#xff08;協調方&#xff09; 2、5…

MFC UI大小改變與自適應

文章目錄窗口最大化庫EasySize控件自適應大小窗口最大化 資源視圖中開放最大化按鈕&#xff0c;添加窗口樣式WS_MAXIMIZEBOX。發送大小改變消息ON_WM_SIZE()。響應大小改變。 void CDlg::OnSize(UINT nType, int cx, int cy) {CDialog::OnSize(nType, cx, cy);//獲取改變后窗…

【Linux網絡】:HTTP(應用層協議)

目錄 一、HTTP 1、URL 2、協議格式 3、請求方法 4、狀態碼 5、Header信息 6、會話保持Cookie 7、長連接 8、簡易版HTTP服務器代碼 一、HTTP 我們在編寫網絡通信代碼時&#xff0c;我們可以自己進行協議的定制&#xff0c;但實際有很多優秀的工程師早就寫出了許多非常…

C++-linux 7.文件IO(三)文件元數據與 C 標準庫文件操作

文件 IO 進階&#xff1a;文件元數據與 C 標準庫文件操作 在 Linux 系統中&#xff0c;文件操作不僅涉及數據的讀寫&#xff0c;還包括對文件元數據的管理和高層庫函數的使用。本文將從文件系統的底層存儲機制&#xff08;inode 與 dentry&#xff09;講起&#xff0c;詳細解析…

WordPress Ads Pro Plugin本地文件包含漏洞(CVE-2025-4380)

免責聲明 本文檔所述漏洞詳情及復現方法僅限用于合法授權的安全研究和學術教育用途。任何個人或組織不得利用本文內容從事未經許可的滲透測試、網絡攻擊或其他違法行為。 前言:我們建立了一個更多,更全的知識庫。每日追蹤最新的安全漏洞,追中25HW情報。 更多詳情: http…