PyTorch 實現 MNIST 手寫數字識別

PyTorch 實現 MNIST 手寫數字識別

MNIST 是一個經典的手寫數字數據集,包含 60000 張訓練圖像和 10000 張測試圖像。使用 PyTorch 實現 MNIST 分類通常包括數據加載、模型構建、訓練和評估幾個部分。

數據加載與預處理

使用 torchvision 加載 MNIST 數據集,并進行歸一化和數據增強(可選)。以下是數據加載的示例代碼:

import torch
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, transform=transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

構建模型

定義一個簡單的卷積神經網絡(CNN)模型:

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.fc1 = nn.Linear(1024, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2)x = torch.flatten(x, 1)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

訓練模型

定義優化器和損失函數,并進行訓練:

model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()def train(model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch: {epoch}, Loss: {loss.item():.4f}')

測試模型

在測試集上評估模型性能:

def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print(f'Test Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)')

完整訓練循環

將訓練和測試整合到一個完整的循環中:進行10次訓練

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)for epoch in range(1, 10):train(model, device, train_loader, optimizer, epoch)test(model, device, test_loader)

模型保存與加載

訓練完成后,可以保存模型:

torch.save(model.state_dict(), 'mnist_cnn.pt')

加載模型:

加載模塊

model = Net()
model.load_state_dict(torch.load('mnist_cnn.pt'))
model.eval()

或者是Mnist其他源代碼

Mnist main.py

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLRclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout(0.25)self.dropout2 = nn.Dropout(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.conv2(x)x = F.relu(x)x = F.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = F.relu(x)x = self.dropout2(x)x = self.fc2(x)output = F.log_softmax(x, dim=1)return output#訓練模型
def train(args, model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if batch_idx % args.log_interval == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))if args.dry_run:break#測試模型
def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch losspred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))#主函數 MNIST Example的主函數
def main():# Training settingsparser = argparse.ArgumentParser(description='PyTorch MNIST Example')parser.add_argument('--batch-size', type=int, default=64, metavar='N',help='input batch size for training (default: 64)')parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',help='input batch size for testing (default: 1000)')parser.add_argument('--epochs', type=int, default=14, metavar='N',help='number of epochs to train (default: 14)')parser.add_argument('--lr', type=float, default=1.0, metavar='LR',help='learning rate (default: 1.0)')parser.add_argument('--gamma', type=float, default=0.7, metavar='M',help='Learning rate step gamma (default: 0.7)')parser.add_argument('--no-accel', action='store_true',help='disables accelerator')parser.add_argument('--dry-run', action='store_true',help='quickly check a single pass')parser.add_argument('--seed', type=int, default=1, metavar='S',help='random seed (default: 1)')parser.add_argument('--log-interval', type=int, default=10, metavar='N',help='how many batches to wait before logging training status')parser.add_argument('--save-model', action='store_true', help='For Saving the current Model')args = parser.parse_args()use_accel = not args.no_accel and torch.accelerator.is_available()torch.manual_seed(args.seed)if use_accel:device = torch.accelerator.current_accelerator()else:device = torch.device("cpu")train_kwargs = {'batch_size': args.batch_size}test_kwargs = {'batch_size': args.test_batch_size}if use_accel:accel_kwargs = {'num_workers': 1,'pin_memory': True,'shuffle': True}train_kwargs.update(accel_kwargs)test_kwargs.update(accel_kwargs)transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])dataset1 = datasets.MNIST('../data', train=True, download=True,transform=transform)dataset2 = datasets.MNIST('../data', train=False,transform=transform)train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)model = Net().to(device)optimizer = optim.Adadelta(model.parameters(), lr=args.lr)scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)for epoch in range(1, args.epochs + 1):train(args, model, device, train_loader, optimizer, epoch)test(model, device, test_loader)scheduler.step()if args.save_model:torch.save(model.state_dict(), "mnist_cnn.pt")if __name__ == '__main__':main()

帶命令的Mnist函數

python main.py --help
usage: main.py [-h] [--batch-size N] [--test-batch-size N] [--epochs N] [--lr LR] [--gamma M] [--no-accel][--dry-run] [--seed S] [--log-interval N] [--save-model]PyTorch MNIST Exampleoptional arguments:-h, --help           show this help message and exit--batch-size N       input batch size for training (default: 64)--test-batch-size N  input batch size for testing (default: 1000)--epochs N           number of epochs to train (default: 14)--lr LR              learning rate (default: 1.0)--gamma M            Learning rate step gamma (default: 0.7)--no-accel           disables accelerator--dry-run            quickly check a single pass--seed S             random seed (default: 1)--log-interval N     how many batches to wait before logging training status--save-model         For Saving the current Model

數據目錄

注意事項

  • 確保安裝了 PyTorch 和 torchvision 庫。
  • 可以根據硬件條件調整 batch_size
  • 模型結構和超參數(如學習率)可以根據需求調整。
  • 使用 GPU 可以顯著加速訓練。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/83857.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/83857.shtml
英文地址,請注明出處:http://en.pswp.cn/web/83857.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Python內存互斥與共享深度探索:從GIL到分布式內存的實戰之旅

引言:并發編程的內存困局 在開發高性能Python應用時,我遭遇了這樣的困境:多進程間需要共享百萬級數據,而多線程間又需保證數據一致性。傳統解決方案要么性能低下,要么引發競態條件。本文將深入探討Python內存互斥與共…

【Unity】使用 C# SerialPort 進行串口通信

索引 一、SerialPort串口通信二、使用SerialPort1.創建SerialPort對象,進行基本配置2.寫入串口數據①.寫入串口數據的方法②.封裝數據 3.讀取串口數據①.讀取串口數據的方法②.解析數據 4.讀取串口數據的時機①.DataReceived事件②.多線程接收數據 5.粘包問題處理 一…

如何寫好單元測試:Mock 脫離數據庫,告別 @SpringBootTest 的重型啟動

如何寫好單元測試:Mock 脫離數據庫,告別 SpringBootTest 的重型啟動 作者:Killian(重慶) — 歡迎各位架構獵頭、技術布道者聯系我,項目實戰豐富,代碼穩健,Mock測試愛好者。 技術棧&a…

【DNS】在 Windows 下修改 `hosts` 文件

在 Windows 下修改 hosts 文件,一般用于本地 DNS 覆蓋。操作步驟如下(以 Windows 10/11 為例): 1. 以管理員權限打開記事本 點擊 開始 → 輸入 “記事本”在“記事本”圖標上右鍵 → 選擇 以管理員身份運行 如果提示“是否允許此…

共享內存實現進程通信

目錄 system V共享內存 共享內存示意圖 共享內存函數 shmget函數 shmat函數 shmdt函數 shmctl函數 代碼示例 shm頭文件 構造函數 獲取key值 創建者的構造方式 GetShmHelper 函數 GetShmUseCreate 函數 使用者的構造方式 GetShmForUse 函數 分離附加操作 DetachShm 函數 AttachS…

6月15日星期日早報簡報微語報早讀

6月15日星期日,農歷五月二十,早報#微語早讀。 1、證監會擬修訂期貨公司分類評價:明確扣分標準,優化加分標準; 2、國家考古遺址公園再添10家,全國已評定65家; 3、北京多所高校禁用羅馬仕充電寶…

破解關鍵領域軟件測試“三重難題”:安全、復雜性、保密性

在國家關鍵領域,軟件系統正成為核心戰斗力的一部分。相比通用軟件,關鍵領域軟件在 安全性、復雜性、實時性、保密性 等方面要求極高。如何保障安全合規前提下提升測試效率,確保系統穩定,已成為軟件質量保障的核心挑戰。 關鍵領域…

記錄一次 Oracle DG 異常停庫問題解決過程

記錄一次 Oracle DG 異常停庫問題解決過程 某醫院有以下架構的雙節點 Oracle 集群: 節點1:172.16.20.2 節點2:172.16.20.3 SCAN IP:172.16.20.1 DG:172.16.20.1206月12日,醫院信息科用戶反映無法連接 DG 服務器。 登錄 DG 服務…

MySQL使用EXPLAIN命令查看SQL的執行計劃

1?、EXPLAIN 的語法 MySQL 中的 EXPLAIN 命令是用于分析 SQL 查詢執行計劃的關鍵工具,它能幫助開發者理解查詢的執行方式并找出性能瓶頸??。 語法格式: EXPLAIN <sql語句> 【示例】查詢學生表關聯班級表的執行計劃。 (1)創建班級信息表和學生信息表,并創建索…

Go語言2個協程交替打印

WaitGroup 無緩沖channel waitgroup 用來控制2個協程 Add() 、Done()、Wait() channel用來實現信號的傳遞和信號的打印 ch1: 用來記錄打印的信號 ch2:用來實現信號的傳遞&#xff0c;實現2個協程的順序打印 package mainimport ("fmt""sync" )func ma…

微信小程序 路由跳轉

路由方式 官方參考文檔 wx.switchTab 實現底部導航欄 1.配置信息 app.json"tabBar": {"custom": true,"list": [{"pagePath": "pages/home/index","text": "首頁"},{"pagePath": "p…

[Java 基礎]正則表達式

正則表達式是一種強大的文本模式匹配工具&#xff0c;它使用一種特殊的語法來描述要搜索或操作的字符串模式。在 Java 中&#xff0c;我們可以使用 java.util.regex包提供的類來處理正則表達式。 :::color3 正則表達式不止 Java 語言提供了相應的功能&#xff0c;很多其他語言…

ArcGIS安裝出現1606錯誤解決辦法

問題背景&#xff1a; 由于最近Arcgis10.2打是有些功能不正常退出&#xff0c;比如arctoolbox中的&#xff0c;table to excel 功能&#xff0c;只要一點擊&#xff0c;arcgis就報錯退出&#xff0c;平常在使用過程中&#xff0c;也經常出現一些莫名其妙的崩潰現象&#xff0c…

wpf 解決DataGridTemplateColumn中width綁定失效問題

感謝酪酪烤奶 提供的Solution 文章目錄 感謝酪酪烤奶 提供的Solution使用示例示例代碼分析各類交互流程 WPF DataGrid 列寬綁定機制分析整體架構數據流分析1. ViewModel到Slider的綁定2. ViewModel到DataGrid列的綁定a. 綁定代理(BindingProxy)b. 列寬綁定c. 數據流 關鍵機制詳…

語音轉文本ASR、文本轉語音TTS

ASR Automatic Speech Recognition&#xff0c;語音轉文本。 技術難點&#xff1a; 聲學多樣性 口音、方言、語速、背景噪聲會影響識別準確性&#xff1b;多人對話場景&#xff08;如會議錄音&#xff09;需要區分說話人并分離語音。 語言模型適配 專業術語或網絡新詞需要動…

通用embedding模型和通用reranker模型,觀測調研

調研Qwen3-Embedding和Qwen3-Reranker 現在有一個的問答庫&#xff0c;包括150個QA-pair&#xff0c;用10個query去同時檢索問答庫的300個questionanswer Embedding模型&#xff0c;query-question的匹配分數 普遍高于 query-answer的匹配分數。比如對于10個query&#xff0c…

基于YOLOv8+Deepface的人臉檢測與識別系統

摘要 人臉檢測與識別系統是一個集成了先進計算機視覺技術的應用&#xff0c;通過深度學習模型實現人臉檢測、識別和管理功能。系統采用雙模式架構&#xff1a; ??注冊模式??&#xff1a;檢測新人臉并添加到數據庫??刪除模式??&#xff1a;識別數據庫中的人臉并移除匹…

Grdle版本與Android Gradle Plugin版本, Android Studio對應關系

Grdle版本與Android Gradle Plugin版本&#xff0c; Android Studio對應關系 各個 Android Gradle 插件版本所需的 Gradle 版本&#xff1a; https://developer.android.com/build/releases/gradle-plugin?hlzh-cn Maven上發布的Android Gradle Plugin&#xff08;AGP&#x…

用c語言實現簡易c語言掃雷游戲

void test(void) {int input 0;do{menu();printf("請選擇&#xff1a; >");scanf("%d", &input);switch (input){menu();case 1:printf("掃雷\n");game();break;case 2:printf("退出游戲\n");break;default:printf("輸入…

系統辨識的研究生水平讀書報告期末作業參考

這是一份關于系統辨識的研究生水平讀書報告&#xff0c;內容系統完整、邏輯性強&#xff0c;并深入探討了理論、方法與實際應用。報告字數超過6000字 從理論到實踐&#xff1a;系統辨識的核心思想、方法論與前沿挑戰 摘要 系統辨識作為連接理論模型與客觀世界的橋梁&#xff…