Pytroch搭建全連接神經網絡識別MNIST手寫數字數據集

編寫步驟

之前已經記錄國多次的編寫步驟了,無需多言。
(1)準備數據集
這里我們使用MNIST數據集,有官方下載渠道。我們直接使用torchvison里面提供的數據讀取功能包就行。如果不使用這個,自己像這樣子構建也一樣。

# 自己構建數據讀取模塊
#(1) 數據讀取模塊
class Mydataset(Dataset):def __init__(self,filepath):xy=np.loadtxt(filepath,delimiter=',',dtype=np.float32)self.len=xy.shape[0]self.x_data=torch.from_numpy(xy[:,:-1])self.y_data=torch.from_numpy(xy[:,[-1]])#魔法方法,容許用戶通過索引index得到值def __getitem__(self,index):return self.x_data[index],self.y_data[index]def __len__(self):return self.len

這里直接使用torchvison里面的工具

#準備數據集
batch_size = 64
transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])trainset = torchvision.datasets.MNIST(root=r'../data/mnist',train=True,download=True,transform=transforms)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)testset = torchvision.datasets.MNIST(root=r'../data/mnist',train=False,download=True,transform=transforms)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

(2) 構建模型
這次我們使用不帶dropout的全連接模型

# 定義模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear1 = nn.Linear(784, 100)self.linear2 = nn.Linear(100, 20)self.linear3 = nn.Linear(20, 10)def forward(self, x):x=x.view(x.size(0), -1)x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))x = self.linear3(x)return x

(3) 選擇損失和優化器

# 構建模型和損失
model=Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

(4)訓練模型

def train(epoch):running_loss = 0.0for batch_idx, (inputs, targets) in enumerate(trainloader):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()#需要將張量轉換為浮點數運算running_loss += loss.item()if batch_idx % 100 == 0:print('Train Epoch: {}, Loss: {:.6f}'.format(epoch, loss.item()))running_loss = 0

(5)測試模型

def test(epoch):correct = 0total = 0with torch.no_grad():for batch_idx, (inputs, targets) in enumerate(testloader):outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += targets.size(0)correct=correct+(predicted.eq(targets).sum()*1.0)print('Accuracy of the network on the 10000 test images: %d %%' % (100*correct/total))

全部代碼

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
#準備數據集
batch_size = 64
transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])trainset = torchvision.datasets.MNIST(root=r'../data/mnist',train=True,download=True,transform=transforms)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)testset = torchvision.datasets.MNIST(root=r'../data/mnist',train=False,download=True,transform=transforms)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)# 定義模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear1 = nn.Linear(784, 100)self.linear2 = nn.Linear(100, 20)self.linear3 = nn.Linear(20, 10)def forward(self, x):x=x.view(x.size(0), -1)x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))x = self.linear3(x)return x
# 構建模型和損失
model=Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)def train(epoch):running_loss = 0.0for batch_idx, (inputs, targets) in enumerate(trainloader):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()#需要將張量轉換為浮點數運算running_loss += loss.item()if batch_idx % 100 == 0:print('Train Epoch: {}, Loss: {:.6f}'.format(epoch, loss.item()))running_loss = 0
def test(epoch):correct = 0total = 0with torch.no_grad():for batch_idx, (inputs, targets) in enumerate(testloader):outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += targets.size(0)correct=correct+(predicted.eq(targets).sum()*1.0)print('Accuracy of the network on the 10000 test images: %d %%' % (100*correct/total))
if __name__ == '__main__':for epoch in range(10):train(epoch)test(epoch)

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/74805.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/74805.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/74805.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Java 基本數據類型 vs 包裝類(引用數據類型)

一、核心概念對比(以 int vs Integer 為例) 特性基本數據類型(int)包裝類(Integer)數據類型原始值(Primitive Value)對象(Object)默認值0null內存位置棧&…

什么是 強化學習(RL):以DQN、PPO等經典模型

什么是 強化學習(RL):以DQN、PPO等經典模型 DQN(深度 Q 網絡)和 PPO(近端策略優化)共同屬于強化學習(Reinforcement Learning,RL)這一領域。強化學習是機器學習中的一個重要分支,其核心在于智能體(Agent)通過與環境進行交互,根據環境反饋的獎勵信號來學習最優的…

【Sql Server】在SQL Server中生成雪花ID(Snowflake ID)

大家好,我是全棧小5,歡迎來到《小5講堂》。 這是《Sql Server》系列文章,每篇文章將以博主理解的角度展開講解。 溫馨提示:博主能力有限,理解水平有限,若有不對之處望指正! 目錄 前言認識雪花ID…

HTML 表單處理進階:驗證與提交機制的學習心得與進度(一)

引言 在前端開發的廣袤領域中,HTML 表單處理堪稱基石般的存在,是構建交互性 Web 應用不可或缺的關鍵環節。從日常頻繁使用的登錄注冊表單,到功能多樣的搜索欄、反饋表單,HTML 表單如同橋梁,緊密連接著用戶與 Web 應用…

C# CancellationTokenSource CancellationToken Task.Run傳入token 取消令牌

基本使用方法創建 CancellationTokenSource獲取 CancellationToken將 CancellationToken 傳遞給任務***注意*** 在任務中檢查取消狀態請求取消處理取消異常 高級用法設置超時自動取消或者使用 CancelAfter 方法關聯多個取消令牌注冊回調 注意事項 CancellationTokenSource 是 …

Git 之配置ssh

1、打開 Git Bash 終端 2、設置用戶名 git config --global user.name tom3、生成公鑰 ssh-keygen -t rsa4、查看公鑰 cat ~/.ssh/id_rsa.pub5、將查看到的公鑰添加到不同Git平臺 6、驗證ssh遠程連接git倉庫 ssh -T gitgitee.com ssh -T gitcodeup.aliyun.com

cli命令編寫

新建文件夾 template-cli template-cli下運行 npm init生成package.json 新建bin文件夾和index.js文件 編寫index.js #! /usr/bin/env node console.log(hello cli)package.json增加 bin 字段注冊命令template-cli template-cli命令對應執行的內容文件 bin/index.js 運行 n…

vue3自定義動態錨點列表,實現本頁面錨點跳轉效果

需求&#xff1a;當前頁面存在多個模塊且內容很長時&#xff0c;需要提供一個錨點列表&#xff0c;可以快速查看對應模塊內容 實現步驟&#xff1a; 1.每個模塊添加唯一id&#xff0c;添加錨點列表div <template><!-- 模塊A --><div id"modalA">…

L2TP實驗

一、實驗拓撲 二、實驗內容 手工部署IPec VPN 三、實驗步驟 1、配置接口IP和安全區域 [PPPoE Client]firewall zone trust [PPPoE Client-zone-trust]add int g 1/0/0[NAS]firewall zone untrust [NAS-zone-untrust]add int g 1/0/1 [NAS]firewall zone trust [NAS-zon…

青少年編程與數學 02-012 SQLite 數據庫簡介 01課題、數據庫概要

青少年編程與數學 02-012 SQLite 數據庫簡介 01課題、數據庫概要&#xff09; 一、特點二、功能 課題摘要:SQLite 是一種輕量級的嵌入式關系型數據庫管理系統。 一、特點 輕量級 它不需要單獨的服務器進程來運行。不像 MySQL 或 PostgreSQL 這樣的數據庫系統需要一個專門的服務…

分布式系統面試總結:3、分布式鎖(和本地鎖的區別、特點、常見實現方案)

僅供自學回顧使用&#xff0c;請支持javaGuide原版書籍。 本篇文章涉及到的分布式鎖&#xff0c;在本人其他文章中也有涉及。 《JUC&#xff1a;三、兩階段終止模式、死鎖的jconsole檢測、樂觀鎖&#xff08;版本號機制CAS實現&#xff09;悲觀鎖》&#xff1a;https://blog.…

Ubuntu 系統上完全卸載 Docker

以下是在 Ubuntu 系統上完全卸載 Docker 的分步指南 一.卸載驗證 二.卸載步驟 1.停止 Docker 服務 sudo systemctl stop docker.socket sudo systemctl stop docker.service2.卸載 Docker 軟件包 # 移除 Docker 核心組件 sudo apt-get purge -y \docker-ce \docker-ce-cli …

Postman 版本信息速查:快速定位版本號

保持 Postman 更新至最新版本是非常重要的&#xff0c;因為這能讓我們享受到最新的功能&#xff0c;同時也保證了軟件的安全性。所以&#xff0c;如何快速查看你的 Postman 版本信息呢&#xff1f; 如何查看 Postman 的版本信息教程

EF Core 異步方法

文章目錄 前言一、為什么使用異步方法二、核心異步方法1&#xff09;查詢數據2&#xff09;保存數據3&#xff09;事務處理 三、異步查詢最佳實踐1&#xff09;始終使用 await2&#xff09;組合異步操作3&#xff09;并行查詢&#xff08;謹慎使用&#xff09; 四、異常處理五、…

裝飾器模式介紹和典型實現

裝飾器模式&#xff08;Decorator Pattern&#xff09;是一種結構型設計模式&#xff0c;它允許你通過將對象放入包含行為的特殊封裝對象中來為原對象添加新的功能。裝飾器模式的主要優點是可以在運行時動態地添加功能&#xff0c;而不需要修改原對象的代碼。這使得代碼更加靈活…

【 <二> 丹方改良:Spring 時代的 JavaWeb】之 Spring Boot 中的日志管理:Logback 的集成

<前文回顧> 點擊此處查看 合集 https://blog.csdn.net/foyodesigner/category_12907601.html?fromshareblogcolumn&sharetypeblogcolumn&sharerId12907601&sharereferPC&sharesourceFoyoDesigner&sharefromfrom_link <今日更新> 一、開篇整…

神經網絡知識點整理

目錄 ?一、深度學習基礎與流程 二、神經網絡基礎組件 三、卷積神經網絡&#xff08;CNN&#xff09;?編輯 四、循環神經網絡&#xff08;RNN&#xff09;與LSTM 五、優化技巧與調參 六、應用場景與前沿?編輯 七、總結與展望?編輯 一、深度學習基礎與流程 機器學習流…

【sql優化】where 1=1

文章目錄 where 11問題描述錯誤實現正確實現性能對比測試 where 11 問題描述 在動態 SQL 拼接場景中&#xff0c;開發者常使用 WHERE 11 簡化條件拼接邏輯&#xff08;避免處理首個條件的 AND&#xff09;。理論上&#xff0c;數據庫優化器會忽略 11&#xff0c;但字符串拼接…

車載以太網網絡測試 -24【SOME/IP概述】

目錄 1 摘要2 車載SOME/IP 概述2.1發展背景以及應用2.1.1車載 SOME/IP 背景2.1.2 車載 SOME/IP 應用場景 2.3 什么是SOME/IP2.3.1 SOME/IP定義2.3.2 SOME/IP在協議棧中的位置 3 SOA是什么4 SOME/IP主要功能5 SOME/IP標準 1 摘要 本文主要介紹SOME/IP的背景以及在車載行業的發展…

vue3中,route4,獲取當前頁面路由的問題

首先應用場景如下&#xff1a; 在main.js里面&#xff0c;引入的是路由的配置文件&#xff0c;如下&#xff1a; import {router} from /router; app.use(router); 路由配置文件router.js如下&#xff1a; import { createRouter, createWebHistory } from vue-router; imp…