【Pytorch】16.使用ImageFolder加載自定義MNIST數據集訓練手寫數字識別網絡(包含數據集下載)

數據集下載

MINST_PNG_Training在github的項目目錄中的datasets中有MNIST的png格式數據集的壓縮包

用于訓練的神經網絡模型

在這里插入圖片描述

自定義數據集訓練

在前文【Pytorch】13.搭建完整的CIFAR10模型我們已經知道了基本搭建神經網絡的框架了,但是其中的數據集使用的torchvision中的CIFAR10官方數據集進行訓練的

train_dataset = torchvision.datasets.CIFAR10('../datasets', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10('../datasets', train=False, download=True,transform=torchvision.transforms.ToTensor())

在這里插入圖片描述

本文將用圖片格式的數據集進行訓練
在這里插入圖片描述
我們通過

# Dataset CIFAR10
#     Number of datapoints: 60000
#     Root location: ../datasets
#     Split: Train
#     StandardTransform
# Transform: ToTensor()
print(train_dataset)

可以看到我們下載的數據集是這種格式的,所以我們的主要問題就是如何將自定義的數據集獲取,并且轉化為這種形式,剩下的步驟就和上文相同了

數據類型進行轉化

我們的首要目的是,根據數據集的地址,分別將數據轉化為train_datasettest_dataset
我們需要調用ImageFolder方法來進行操作

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from model import *# 訓練集地址
train_root = "../datasets/mnist_png/training"
# 測試集地址
test_root = '../datasets/mnist_png/testing'# 進行數據的處理,定義數據轉換
data_transform = transforms.Compose([transforms.Resize((28, 28)),transforms.Grayscale(),transforms.ToTensor()])# 加載數據集
train_dataset = ImageFolder(train_root, transform=data_transform)
test_dataset = ImageFolder(test_root, transform=data_transform)

首先我們需要將數據進行處理,通過transforms.Compose獲取對象data_transform
其中進行了三步操作

  • 將圖片大小變為28*28像素便于輸入網絡模型
  • 將圖片轉化為灰度格式,因為手寫數字識別不需要三通道的圖片,只需要灰度圖像就可以識別,而png格式的圖片是四通道
  • 將圖片轉化為tensor數據類型

然后通過ImageFolder給出圖片的地址與轉化類型,就可以實現與我們在官方下載數據集相同的格式

# Dataset ImageFolder
#     Number of datapoints: 60000
#     Root location: ../datasets/mnist_png/training
#     StandardTransform
# Transform: Compose(
#                Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
#                ToTensor()
#            )
print(train_dataset)

其他與前文【Pytorch】13.搭建完整的CIFAR10模型基本相同

完整代碼

網絡模型

import torch
from torch import nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(2, stride=2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(2, stride=2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(3136, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = self.relu1(x)x = self.pool1(x)x = self.conv2(x)x = self.relu2(x)x = self.pool2(x)x = self.flatten(x)x = self.fc1(x)x = self.fc2(x)return xif __name__ == "__main__":model = Net()input = torch.ones((1, 1, 28, 28))output = model(input)print(output.shape)

訓練過程

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from model import *# 訓練集地址
train_root = "../datasets/mnist_png/training"
# 測試集地址
test_root = '../datasets/mnist_png/testing'# 進行數據的處理,定義數據轉換
data_transform = transforms.Compose([transforms.Resize((28, 28)),transforms.Grayscale(),transforms.ToTensor()])# 加載數據集
train_dataset = ImageFolder(train_root, transform=data_transform)
test_dataset = ImageFolder(test_root, transform=data_transform)# Dataset ImageFolder
#     Number of datapoints: 60000
#     Root location: ../datasets/mnist_png/training
#     StandardTransform
# Transform: Compose(
#                Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
#                ToTensor()
#            )
# print(train_dataset)# print(train_dataset[0])train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")model = Net().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)epoch = 10writer = SummaryWriter('../logs')
total_step = 0for i in range(epoch):model.train()pre_step = 0pre_loss = 0for data in train_loader:images, labels = dataimages = images.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(images)loss = loss_fn(outputs, labels)loss.backward()optimizer.step()pre_loss = pre_loss + loss.item()pre_step += 1total_step += 1if pre_step % 100 == 0:print(f"Epoch: {i+1} ,pre_loss = {pre_loss/pre_step}")writer.add_scalar('train_loss', pre_loss / pre_step, total_step)model.eval()pre_accuracy = 0with torch.no_grad():for data in test_loader:images, labels = dataimages = images.to(device)labels = labels.to(device)outputs = model(images)pre_accuracy += outputs.argmax(1).eq(labels).sum().item()print(f"Test_accuracy: {pre_accuracy/len(test_dataset)}")writer.add_scalar('test_accuracy', pre_accuracy / len(test_dataset), i)torch.save(model, f'../models/model{i}.pth')writer.close()

參考文章

【CNN】搭建AlexNet網絡——并處理自定義的數據集(貓狗分類)
How to download MNIST images as PNGs

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/14251.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/14251.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/14251.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Flutter 中的 WidgetInspector 小部件:全面指南

Flutter 中的 WidgetInspector 小部件:全面指南 Flutter 的 WidgetInspector 是一個強大的工具,它允許開發者在運行時檢查和操作他們的 widget 樹。這在調試復雜的布局和 widget 結構時尤其有用。本文將為您提供一個全面的指南,幫助您了解如…

Excel 按順序去重再編號

Excel的A有重復數據: A1Cow2Chicken3Horse4Butterfly5Cow 現在要去除重復,用自然數按順序進行編號,結果寫在相鄰列: AB1Cow12Chicken23Horse34Butterfly45Cow1 使用 SPL XLL,輸入公式并向下拖: spl(&q…

RISC-V壓縮指令擴展測試

概述 RISC-V定義了壓縮指令擴展(compressed instruction-set extension ),命名為“C”擴展。壓縮指令使用16位寬指令替換32位寬指令,從而減少代碼量。這個C擴展可運用在RV32、RV64和RV128指令集上,通常使用“RVC”來表…

Double 4 VR情景實訓教學系統在商務洽談課堂上的應用

隨著科技的不斷發展,VR(虛擬現實)技術已經逐漸滲透到各個領域。在商務洽談課堂上,Double 4 VR情景實訓教學系統不僅可以為學生提供身臨其境的模擬環境,還可以通過互動和交互式學習方式,增強學生的學習體驗和…

貝銳向日葵打造農機設備遠程運維支持方案

當物聯網“萬物互聯”的概念向第一產業賦能,農機設備的智能化程度也越來越高。 所謂農業物聯網,即在應用層將大量的傳感器節點構成監控網絡,通過各種傳感器采集信息,以幫助農民及時發現問題,并準確地判定發生問題的位…

QT 使用QZipReader 進行文件解壓縮

目錄 1、QZipReader 概述 2、解壓示例 3、說明 1、QZipReader 概述 QZipReader 是一個方便的工具,用于在 Qt 應用程序中解壓 ZIP 壓縮包。它提供了讀取 ZIP 文件的接口,并能提取其中的內容。以下是如何使用 QZipReader 解壓 ZIP 文件的示例代碼&#…

List、IList、ArrayList 和 Dictionary

List 類型: 泛型類命名空間: System.Collections.Generic作用: List<T> 表示一個強類型的對象列表&#xff0c;可以通過索引訪問。提供了搜索、排序和操作列表的方法。特點: 類型安全&#xff0c;性能較好&#xff0c;適用于需要強類型和高效操作的場景。例子: List<…

每日一練 - BGP Keepalive 報文詳解

01 真題題目 關于 BGP 的 Keepalive 報文消息的描述,錯誤的是&#xff1a; A.Keepalive 周期性的在兩個 BGP 鄰居之間發送 B.缺省情況下,Keepalive 的時間間隔是 180s C.Keepalive 報文主要用于對等路由器間的運行狀態和鏈路的可用性確認 D.Keepalive 報文的組成只包含一個…

Web安全:SQL注入之時間盲注原理+步驟+實戰操作

「作者簡介」&#xff1a;2022年北京冬奧會網絡安全中國代表隊&#xff0c;CSDN Top100&#xff0c;就職奇安信多年&#xff0c;以實戰工作為基礎對安全知識體系進行總結與歸納&#xff0c;著作適用于快速入門的 《網絡安全自學教程》&#xff0c;內容涵蓋系統安全、信息收集等…

ICML2024高分論文!大模型計算效率暴漲至200%,來自中國AI公司

前段時間&#xff0c;KAN突然爆火&#xff0c;成為可以替代MLP的一種全新神經網絡架構&#xff0c;200個參數頂30萬參數&#xff1b;而且&#xff0c;GPT-4o的生成速度也是驚艷了一眾大模型愛好者。 大家開始意識到—— 大模型的計算效率很重要&#xff0c;提升大模型的token…

前端加載excel文件數據 XLSX插件的使用

npm i xlsx import axios from axios; axios //這里用自己封裝的http是不行的&#xff0c;踩過坑.get(url,{ responseType: "arraybuffer" }).then((re) > {console.log(re)let res re.datavar XLSX require("xlsx");let wb XLSX.read(r…

黑龍江大學文學院古代文學教研室安家琪副教授

女&#xff0c;生于1990年。蘭州大學文學學士、碩士&#xff0c;上海交通大學文學博士&#xff0c;曾赴臺灣東華大學交流&#xff0c;研究方向為明清詩文與唐代文學。 在《文藝理論研究》、《蘇州大學學報》、《唐史論叢》、《中國社會科學報》等期刊發表論文20余篇&#xff0…

2024年 電工杯 (A題)大學生數學建模挑戰賽 | 園區微電網風光儲協調優化配置 | 數學建模完整代碼解析

DeepVisionary 每日深度學習前沿科技推送&頂會論文&數學建模與科技信息前沿資訊分享&#xff0c;與你一起了解前沿科技知識&#xff01; 本次DeepVisionary帶來的是電工杯的詳細解讀&#xff1a; 完整內容可以在文章末尾全文免費領取&閱讀&#xff01; 問題重述…

干就對了!

成年人的世界哪有那么容易&#xff0c;不過都在負重前行&#xff0c;誰不是一邊抱怨著&#xff0c;一邊咬牙堅持&#xff0c;一邊崩潰&#xff0c;一邊還要自我安慰。 想改變&#xff0c;想更好&#xff0c;我們都有很多想法。 想再多不如動手做一次。一旦開始做了&#xff0…

前端手寫文件上傳;使用input實現文件拖動上傳

使用input實現文件拖動上傳 vue2代碼&#xff1a; <template><div><div class"drop-area" dragenter"highlight" dragover"highlight" dragleave"unhighlight" drop"handleDrop"click"handleClick&quo…

聽說京東618裁員沒?上午還在趕需求,下午就開會通知被裁了~

文末還有最新面經共享群&#xff0c;沒準能讓你刷到意向公司的面試真題呢。 京東也要向市場輸送人才了? 在群里看到不少群友轉發京東裁員相關的內容&#xff1a; 我特地去網上搜索了相關資料&#xff0c;看看網友的分享&#xff1a; 想不到馬上就618了&#xff0c;東哥竟然搶…

Python 機器學習 基礎 之 模型評估與改進 【模型評估與改進 / 交叉驗證】的簡單說明

Python 機器學習 基礎 之 模型評估與改進 【模型評估與改進 / 交叉驗證】的簡單說明 目錄 Python 機器學習 基礎 之 模型評估與改進 【模型評估與改進 / 交叉驗證】的簡單說明 一、簡單介紹 二、模型評估與改進 三、交叉驗證 1、scikit-learn 中的交叉驗證 2、交叉驗證的…

stm32工程綜合實驗_延時及中斷優先級

待下載綜合實驗 ![在這里插入圖片描述](https://img-blog.csdnimg.cn/161fa4e200bb4022bf384e80a3af8797.jpg 很好的編程思想模式及資料(富萊xx電子)

【repo系列】repo常用命令的使用

前言 repo是一種代碼版本管理工具&#xff0c;它是由一系列的Python腳本組成&#xff0c;封裝了一系列的Git命令&#xff0c;用來統一管理多個Git倉庫。 本文章描述repo常用命令的使用。 常用命令 初始化 repo init 初始化代碼倉 repo init [options]常用options: -u URL…

JDBC——API詳解

一、DriverManager 1、用于注冊驅動程序&#xff1a;registerDriver(Driver driver)。 更常用的是Class.forName("com.mysql.jdbc.Driver")是由于Driver中包含了registerDriver(Driver driver)&#xff0c;值得注意的是&#xff0c;是mysql5之后的版本中&#xff0…