1. pytorch手寫數字預測

1. pytorch手寫數字預測

  • 1.背景
  • 2.準備數據集
  • 2.定義模型
  • 3.dataloader和訓練
  • 4.訓練模型
  • 5.測試模型
  • 6.保存模型

1.背景

因為自身的研究方向是多模態目標跟蹤,突然對其他的視覺方向產生了興趣,所以心血來潮的回到最經典的視覺任務手寫數字預測上來,所以這份教程并不是一份非常詳盡的教程,是在一部分pytorch,深度學習基礎上的教程,如果需要的是非常保姆級的教程建議看別的文章

2.準備數據集

這里我才用了直接導torchvision中的dataset包來下載Mnist數據集,也算是一個非常經典的數據集了

# 導入數據集
from torchvision.datasets import MNIST
import torch# 設置隨機種子
torch.manual_seed(3306)# 數據預處理
from torchvision import transforms
# 定義數據轉換
transform = transforms.Compose([transforms.ToTensor(),  # 轉換為 Tensortransforms.Normalize((0.1307,), (0.3081,))  # 標準化
])# 下載 MNIST 數據集
mnist_train = MNIST(root='./dataset_file/mnist_raw', train=True, download=True,transform=transform)
mnist_test = MNIST(root='./dataset_file/mnist_raw', train=False, download=True,transform=transform)
# 查看數據集大小
print(f"MNIST train dataset size: {len(mnist_train)}")
print(f"MNIST test dataset size: {len(mnist_test)}")

其中,MNIST()中的root代表的是數據集存放的位置,download代表是如果當前位置沒有數據集是否需要下載。
transformer則是對數據的處理方式,我這里采用了簡單地轉成tensor和簡單地標準化。

不過這樣子下載下來的數據集是二進制格式的,無法直接查看圖片,當然,如果你需要查看圖片,也有辦法。

# 查看圖片
import matplotlib.pyplot as pltdef show_image(id):img, label = mnist_train[id]img = img.squeeze().numpy()  # 去掉通道維度print(img.shape)# print(img)plt.imshow(img, cmap='gray')plt.title(f"Label: {label}")plt.axis('off')plt.show()show_image(1)

效果
在這里插入圖片描述

又或者你想要下載的數據集是圖片格式,我這里也準備了代碼

代碼是在別人的基礎上改的,其中數據集存放路徑是dataset_dir,如果需要修改自行打印然后修改位置就好了。

#!/usr/bin/env python3
# -*- encoding utf-8 -*-'''
@File: save_mnist_to_jpg.py
@Date: 2024-08-23
@Author: KRISNAT
@Version: 0.0.0
@Email: ****
@Copyright: (C)Copyright 2024, KRISNAT
@Desc:1. 通過 torchvision.datasets.MNIST 下載、解壓和讀取 MNIST 數據集;2. 使用 PIL.Image.save 將 MNIST 數據集中的灰度圖片以 JPEG 格式保存。
'''import sys, os
sys.path.insert(0, os.getcwd())from torchvision.datasets import MNIST
import PIL
from tqdm import tqdmif __name__ == "__main__":home_dir = os.path.abspath('.')root = os.path.abspath(os.path.join(home_dir, '../dataset_file'))print(root)# exit(0)# 圖片保存路徑dataset_dir = os.path.join(root, 'mnist_jpg')if not os.path.exists(dataset_dir):os.makedirs(dataset_dir)# 從網絡上下載或從本地加載MNIST數據集# 訓練集60K、測試集10K# torchvision.datasets.MNIST接口下載的數據一組元組# 每個元組的結構是: (PIL.Image.Image image model=L size=28x28, 標簽數字 int)training_dataset = MNIST(root='mnist',train=True,download=True,)test_dataset = MNIST(root='mnist',train=False,download=True,)# 保存訓練集圖片with tqdm(total=len(training_dataset), ncols=150) as pro_bar:for idx, (X, y) in enumerate(training_dataset):f = dataset_dir + "/" + "training_" + str(idx) + \"_" + str(training_dataset[idx][1] ) + ".jpg"  # 文件路徑training_dataset[idx][0].save(f)pro_bar.update(n=1)# 保存測試集圖片with tqdm(total=len(test_dataset), ncols=150) as pro_bar:for idx, (X, y) in enumerate(test_dataset):f = dataset_dir + "/" + "test_" + str(idx) + \"_" + str(test_dataset[idx][1] ) + ".jpg"  # 文件路徑test_dataset[idx][0].save(f)pro_bar.update(n=1)

2.定義模型

這里我準備了兩個模型,一個MLP模型和一個簡單地CNN模型,其中MLP模型參數量1M,CNN模型參數量大概8M,當然這倆模型也沒有很仔細的規劃

import torch
import torch.nn as nnclass DigitLinear(nn.Module):def __init__(self):super(DigitLinear, self).__init__()self.fc1 = nn.Linear(28 * 28, 1000)self.fc2 = nn.Linear(1000, 500)self.dropout = nn.Dropout(0.3)self.fc3 = nn.Linear(500, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = self.fc1(x)x = torch.relu(x)x = self.dropout(x)x = self.fc2(x)x = torch.relu(x)x = self.fc3(x)return xclass DigitCNN(nn.Module):def __init__(self):super(DigitCNN,self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64*28*28, 128)self.dropout = nn.Dropout(0.1)self.fc2 = nn.Linear(128, 10)def forward(self, x):# print("x.shape:", x.shape)B,N,H,W = x.shapex = self.conv1(x)x = torch.relu(x)x = self.conv2(x)x = torch.relu(x)x = x.view(B, -1)  # 展平x = self.fc1(x)x = torch.relu(x)x = self.dropout(x)x = self.fc2(x)return x

3.dataloader和訓練

這里的代碼就很簡單了,就是一些參數的選擇,例如epoch,batchsize。其中的訓練函數我寫的買有很全面,只是勉強滿足了訓練功能,還有好多可以優化的點,比如打印fps,斷點續訓練啥的,不過這個任務提不起勁去干這事,大家可以自行優化。

# 數據加載器
from torch.utils.data import DataLoader
from lib.model.DigitModel import DigitLinear,DigitCNN
# 定義數據加載器
batch_size = 256
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)epoch = 50
# 訓練模型
net = DigitLinear() # 參數量1M 97.50%
# net = DigitCNN() # 參數量8M 98.81%
net.cuda()# 定義損失函數和優化器
import torch.optim as optim
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 訓練函數def train_model(model, train_loader, criterion, optimizer, num_epochs=10):model.train()  # 設置模型為訓練模式for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, (inputs, labels) in enumerate(train_loader):inputs= inputs.cuda()y = torch.tensor(torch.zeros((inputs.shape[0],10), dtype=torch.float)).cuda()y[torch.arange(inputs.shape[0]), labels] = 1optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, y)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels.cuda()).sum().item()epoch_loss = running_loss / len(train_loader)epoch_acc = 100. * correct / totalprint(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')# 訓練模型
train_model(net, train_loader, criterion, optimizer, num_epochs=epoch)

4.訓練模型

有了上面的代碼就可以開始訓練了,我這里訓練的截圖是我的MLP模型,效果不是很好,CNN的效果稍微好一點,比MLP高1%,但是圖忘記截了。反正夠用了,因為本身MNIST的數據就不是很完美,有很多類似于噪聲的數據例如:
在這里插入圖片描述
這些數字我人眼都分不出是什么玩意。

訓練效果如下
在這里插入圖片描述

5.測試模型

訓練完當然是測試了
最后我的MLP模型跑了97.50%的準確率

代碼如下

# 測試模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.eval()
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device).float(), labels.to(device).float()outputs = net(inputs)_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels.cuda()).sum().item()# print(f"Predicted: {predicted}, Ground Truth: {targets}")print(f"Accuracy: {correct / total * 100:.4f} %")

在這里插入圖片描述

6.保存模型

保存模型代碼就更簡單了

# 保存模型
torch.save(net.state_dict(), './digit_model.pth')

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/82992.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/82992.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/82992.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

AWS WebRTC:獲取ICE服務地址(part 2): ICE Agent的作用

上一篇,已經獲取到了ICE服務地址,從返回結果中看,是兩組TURN服務地址。 拿到這些地址有什么用呢?接下來就要說到WebRTC中ICE Agent的作用了,返回的服務地址會傳給WebRTC最終給到ICE Agent。 ICE Agent的作用&#xf…

大數據時代的利劍:Bright Data網頁抓取與自動化工具共建高效數據采集新生態

目錄 一、為何要選用Bright Data網頁自動化抓取——幫助我們高效高質解決以下問題! 二、Bright Data網頁抓取工具 - 網頁爬蟲工具實測 2.1 首先注冊用戶 2.2 首先點擊 Proxies & Scraping ,再點擊瀏覽器API的開始使用 2.3 填寫通道名稱&#xff…

指紋識別+精準化POC攻擊

開發目的 解決漏洞掃描器的痛點 第一就是掃描量太大,對一個站點掃描了大量的無用 POC,浪費時間 指紋識別后還需要根據對應的指紋去進行 payload 掃描,非常的麻煩 開發思路 我們的思路分為大體分為指紋POC掃描 所以思路大概從這幾個方面…

【Day40】

DAY 40 訓練和測試的規范寫法 知識點回顧: 彩色和灰度圖片測試和訓練的規范寫法:封裝在函數中展平操作:除第一個維度batchsize外全部展平dropout操作:訓練階段隨機丟棄神經元,測試階段eval模式關閉dropout 作業&#x…

【HTML-13】HTML表格合并技術詳解:打造專業數據展示

表格是HTML中展示結構化數據的重要元素,而表格合并則是提升表格表現力的關鍵技術。本文將全面介紹HTML中的表格合并方法,幫助您創建更專業、更靈活的數據展示界面。 1. 表格合并基礎概念 在HTML中,表格合并主要通過兩個屬性實現&#xff1a…

<uniapp><threejs>在uniapp中,怎么使用threejs來顯示3D圖形?

前言 本專欄是基于uniapp實現手機端各種小功能的程序,并且基于各種通訊協議如http、websocekt等,實現手機端作為客戶端(或者是手持機、PDA等),與服務端進行數據通訊的實例開發。 發文平臺 CSDN 環境配置 系統:windows 平臺:visual studio code、HBuilderX(uniapp開…

如何制作全景VR圖?

全景VR圖,特別是720度全景VR,為觀眾提供一種沉浸式體驗。 全景VR圖能夠捕捉場景的全貌,還能將多個角度的圖片或視頻無縫拼接成一個完整的全景視角,讓觀眾在虛擬環境中自由探索。隨著虛擬現實(VR)技術的飛速…

前端使用qrcode來生成二維碼的時候中間添加logo圖標

這個開源倉庫可以讓你在前端頁面中生成二維碼圖片,并且支持調整前景色和背景色,但是有個問題,就是不能添加logo圖片。issue: GitHub Where software is built 但是已經有解決方案了: add a logo picture Issue #21…

【C語言】函數指針及其應用

目錄 1.1 函數指針的概念和應用 1.2 賦值與內存模型 1.3 調用方式與注意事項 二、函數指針的使用 2.1 函數指針的定義和訪問 2.2 動態調度:用戶輸入驅動函數執行 2.3 函數指針數組進階應用 2.4 函數作為參數的高階抽象 三、回調函數 3.1 指針函數…

安裝flash-attention失敗的終極解決方案(WINDOWS環境)

想要看linux版本下安裝問題的請走這里:安裝flash-attention失敗的終極解決方案(LINUX環境) 其實,現在的flash-attention不像 v2.3.2之前的版本,基本上不兼容WINDOWS環境。但是在WINDOWS環境安裝總還是有那么一點不順暢…

[C]基礎16.數據在內存中的存儲

博客主頁:向不悔本篇專欄:[C]您的支持,是我的創作動力。 文章目錄 0、總結1、整數在內存中的存儲1.1 整數的二進制表示方法1.2 不同整數的表示方法1.3 內存中存儲的是補碼 2、大小端字節序和字節序判斷2.1 什么是大小端2.2 為什么有大小端2.3…

Python 基于卷積神經網絡手寫數字識別

Ubuntu系統:22.04 python版本:3.9 安裝依賴庫: pip install tensorflow2.13 matplotlib numpy -i https://mirrors.aliyun.com/pypi/simple 代碼實現: import tensorflow as tf from tensorflow.keras.models import Sequent…

ElectronBot復刻-電路測試篇

typec-16p 接口部分 USB1(Type - C 接口):這是通用的 USB Type - C 接口,具備供電和數據傳輸功能。 GND 引腳(如 A1、A12、B1、B12 等):接地引腳,用于提供電路的參考電位&#xff0…

ESP8266+STM32 AT驅動程序,心知天氣API 記錄時間: 2025年5月26日13:24:11

接線為 串口2 接入ESP8266 esp8266.c #include "stm32f10x.h"//8266預處理文件 #include "esp8266.h"//硬件驅動 #include "delay.h" #include "usart.h"//用得到的庫 #include <string.h> #include <stdio.h> #include …

CDN安全加速:HTTPS加密最佳配置方案

CDN安全加速的HTTPS加密最佳配置方案需從證書管理、協議優化、安全策略到性能調優進行全鏈路設計&#xff0c;以下是核心實施步驟與注意事項&#xff1a; ??一、證書配置與管理?? ??證書選擇與格式?? ??證書類型??&#xff1a;優先使用受信任CA機構頒發的DV/OV/EV證…

【前端】Twemoji(Twitter Emoji)

目錄 注意使用Vue / React 項目 驗證 Twemoji 的作用&#xff1a; Twemoji 會把你網頁/應用中的 Emoji 字符&#xff08;如 &#x1f604;&#xff09;自動替換為 Twitter 風格的圖片&#xff08;SVG/PNG&#xff09;&#xff1b; 它不依賴系統字體&#xff0c;因此在 Android、…

GCN圖神經網絡的光伏功率預測

一、GCN圖神經網絡的核心優勢 圖結構建模能力 GCN通過鄰接矩陣&#xff08;表示節點間關系&#xff09;和節點特征矩陣&#xff08;如氣象數據、歷史功率&#xff09;進行特征傳播&#xff0c;能夠有效捕捉光伏電站間的空間相關性。其核心公式為&#xff1a; H ( l 1 ) σ (…

按照狀態實現自定義排序的方法

方法一&#xff1a;使用 MyBatis-Plus 的 QueryWrapper 自定義排序 在查詢時動態構建排序規則&#xff0c;通過 CASE WHEN 語句實現優先級排序&#xff1a; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import org.springframework.stereotype.Ser…

【計算機網絡】IPv6和NAT網絡地址轉換

IPv6 IPv6協議使用由單/雙冒號分隔一組數字和字母&#xff0c;例如2001:0db8:85a3:0000:0000:8a2e:0370:7334&#xff0c;分成8段。IPv6 使用 128 位互聯網地址&#xff0c;有 2 128 2^{128} 2128個IP地址無狀態地址自動配置&#xff0c;主機可以通過接口標識和網絡前綴生成全…

【Redis】string

String 字符串 字符串類型是 Redis 最基礎的數據類型&#xff0c;關于字符串需要特別注意&#xff1a; 首先 Redis 中所有的鍵的類型都是字符串類型&#xff0c;而且其他幾種數據結構也都是在字符串的基礎上構建的。字符串類型的值實際可以是字符串&#xff0c;包含一般格式的…