13.深度學習——Minst手寫數字識別

第一部分——起手式

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimuse_cuda = torch.cuda.is_available()if use_cuda:device = torch.device("cuda")
else: device = torch.device("cpu")print(f"Using device {device}")

第二部分——計算均值、方差

transform = transforms.Compose([#將數據轉換成Tensor張量transforms.ToTensor()
]
)#讀取數據
datasets1 = datasets.MNIST('./data',train=True,download = True, transform =transform)
datasets1_len = len(datasets1)#設置數據加載器、批次大小全部圖片
train_loader = torch.utils.data.DataLoader(datasets1, batch_size=datasets1_len, shuffle = True)#循環訓練集 DataLoader,0是起始索引
for batch_idx, data in enumerate(train_loader,0):inputs, targets = data #將訓練集圖(60000,1,28,28)像轉換為(60000*1,28*28)的二維數組,-1 是占位符用于自動計算維度大小x = inputs.view(-1,28*28)#計算均值-0.3081x_mean =x.mean().item()#計算標準差-0.1307x_std =x.std().item()print(f"mean: {x_mean}, std: {x_std}")
#mean: 0.13066047430038452, std: 0.30810782313346863

第三部分——網絡模型

#自定義類構建模型、繼承torch.nn.module初始化網絡模型
class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()self.fc1 = torch.nn.Linear(784, 128)#Liner線性加權求和,784是input,128是當前層神經元個數self.dropout = torch.nn.Dropout(p = 0.2)self.fc2 = torch.nn.Linear(128, 10)#input=上一層的神經元個數,輸出是10,做一個0-9的10分類def forward(self, x):#把x的每條數據展成一維數組28*28=784x = torch.flatten(x,1)x = self.fc1(x)x = F.relu(x)x = self.dropout(x)x = self.fc2(x)output = F.log_softmax(x, dim=1)#做完softmax然后取log,便于后續計算損失函數(損失函數需要取log)return output       

第四部分——訓練策略、測試策略

#創建實例
model = Net().to(device)#每個批次如何訓練
def train_step(data, target, model, optimizer):optimizer.zero_grad()#梯度歸零output = model(data)loss = F.nll_loss(output,target)#nll是負對數似然,output是y_head,target是y_trueloss.backward()#反向傳播求梯度optimizer.step()#根據梯度更新網絡return loss#每個批次如何測試
def test(data, target, model, test_loss, correct):output = model(data)#累積計算每個批次的損失test_loss += F.nll_loss(output,target,reduction='sum').item()#獲取對數概率最大對應的索引,dim=1:表示選取每一行概率最大的索引,keepdim = True 表示維度保持不變pred = output.argmax(dim=1, keepdim=True)#統計預測值與正確值相同的數量,eq在做比較,返回True/Fasle,sum是求和,item是將數據取出來(原來是tensor)correct += pred.eq(target.view_as(pred)).sum().item()return test_loss, correct

第五部分——開始訓練

#真正分輪次訓練
EPOCHS = 5#調參優化器,lr是學習率
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)for epoch in range(EPOCHS):model.train()#設置為訓練模式:BN層計算的是均值方差for batch_index, (data, target) in enumerate(train_loader):data, target = data.to(device),target.to(device)loss = train_step(data, target, model, optimizer)#每隔10個批次打印一次信息if batch_index%10 ==0:print('Train Epoch:{epoch} [{batch}/{total_batch} {percent}%] train_loss:{loss:.3f}'.format(epoch=epoch+1,#第幾個批次batch = batch_index*len(data),#已跑多少數據total_batch = len(train_loader.dataset),#當前輪總數據條數percent = 100.0*batch_index/len(train_loader),#當前輪數已占訓練集百分比loss = loss.item()#損失是tensor,轉為數值))       #設置為測試模式:BN層計算的是滑動平均,Droput層不進行預測model.eval()test_loss = 0correct = 0with torch.no_grad():#不求梯度for data, target in test_loader:data, target = data.to(device), target.to(device)test_loss, correct = test_step(data, target, model, test_loss, correct)    test_loss = test_loss/len(test_loader.dataset)print('\n Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(test_loss,correct,len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

完整代碼

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimuse_cuda = torch.cuda.is_available()if use_cuda:device = torch.device("cuda")
else: device = torch.device("cpu")print(f"Using device {device}")#數據預處理
transform = transforms.Compose([#將數據轉換成Tensor張量transforms.ToTensor(),#圖片數據歸一化:0.1307是均值,0.3081是方差。數值和數據集有關系transforms.Normalize((0.1307),(0.3081))
]
)#讀取數據
datasets1 =datasets.MNIST('./data',train=True,download = True, transform =transform)
datasets2 =datasets.MNIST('./data',train=False,download = True, transform =transform)#設置數據加載器、批次大小128、是否打亂順序-是
train_loader = torch.utils.data.DataLoader(datasets1, batch_size=128, shuffle = True)
#測試批次可以大,測試集不需要打亂順序-False
test_loader = torch.utils.data.DataLoader(datasets2, batch_size =1000,shuffle = False)#自定義類構建模型、繼承torch.nn.module初始化網絡模型
class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()self.fc1 = torch.nn.Linear(784, 128)#Liner線性加權求和,784是input,128是當前層神經元個數self.dropout = torch.nn.Dropout(p = 0.2)self.fc2 = torch.nn.Linear(128, 10)#input=上一層的神經元個數,輸出是10,做一個0-9的10分類def forward(self, x):#把x的每條數據展成一維數組28*28=784x = torch.flatten(x,1)x = self.fc1(x)x = F.relu(x)x = self.dropout(x)x = self.fc2(x)output = F.log_softmax(x, dim=1)#做完softmax然后取log,便于后續計算損失函數(損失函數需要取log)return output       #創建實例
model = Net().to(device)#每個批次如何訓練
def train_step(data, target, model, optimizer):optimizer.zero_grad()#梯度歸零output = model(data)loss = F.nll_loss(output,target)#nll是負對數似然,output是y_head,target是y_trueloss.backward()#反向傳播求梯度optimizer.step()#根據梯度更新網絡return loss#每個批次如何測試
def test_step(data, target, model, test_loss, correct):output = model(data)#累積計算每個批次的損失test_loss += F.nll_loss(output,target,reduction='sum').item()#獲取對數概率最大對應的索引,dim=1:表示選取每一行概率最大的索引,keepdim = True 表示維度保持不變pred = output.argmax(dim=1, keepdim=True)#統計預測值與正確值相同的數量,eq在做比較,返回True/Fasle,sum是求和,item是將數據取出來(原來是tensor)correct += pred.eq(target.view_as(pred)).sum().item()return test_loss, correct#真正分輪次訓練
EPOCHS = 5#調參優化器,lr是學習率
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)for epoch in range(EPOCHS):model.train()#設置為訓練模式:BN層計算的是均值方差for batch_index, (data, target) in enumerate(train_loader):data, target = data.to(device),target.to(device)loss = train_step(data, target, model, optimizer)#每隔10個批次打印一次信息if batch_index%10 ==0:print('Train Epoch:{epoch} [{batch}/{total_batch} {percent}%] train_loss:{loss:.3f}'.format(epoch=epoch+1,#第幾個批次batch = batch_index*len(data),#已跑多少數據total_batch = len(train_loader.dataset),#當前輪總數據條數percent = 100.0*batch_index/len(train_loader),#當前輪數已占訓練集百分比loss = loss.item()#損失是tensor,轉為數值))       #設置為測試模式:BN層計算的是滑動平均,Droput層不進行預測model.eval()test_loss = 0correct = 0with torch.no_grad():#不求梯度for data, target in test_loader:data, target = data.to(device), target.to(device)test_loss, correct = test_step(data, target, model, test_loss, correct)    test_loss = test_loss/len(test_loader.dataset)print('\n Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(test_loss,correct,len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/95774.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/95774.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/95774.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【JAVA高級】實現word轉pdf 實現,源碼概述。深坑總結

之前的需求做好后,需求,客戶突發奇想。要將生成的word轉為pdf! 因為不想讓下載文檔的人改動文檔。 【JAVA】實現word添加標簽實現系統自動填入字段-CSDN博客 事實上這個需求難度較高,并不是直接轉換就行的 word文檔當中的很多東西都需要處理 public static byte[] gener…

數據驅動測試提升自動化效率

測試工程師老王盯著滿屏重復代碼嘆氣:“改個搜索條件要重寫20個腳本,這班加到啥時候是個頭?” 隔壁組的小李探過頭:“試試數據驅動唄,一套腳本吃遍所有數據,我們組上周測了300個組合都沒加班!”…

模板引用(Template Refs)全解析2

三、v-for 中的模板引用 當在 v-for 中使用模板引用時,引用的 value 會自動變為一個數組,包含列表中所有元素/組件的引用(需 Vue 3.5+ 版本,舊版需手動處理且順序不保證)。 1. 基本用法(Vue 3.5+) <script setup> import { ref, useTemplateRef, onMounted } f…

【Linux系統】進程間通信:System V IPC——共享內存

前文中我們介紹了管道——匿名管道和命名管道來實現進程間通信&#xff0c;在介紹怎么進行通信時&#xff0c;我們有提到過不止管道的方式進行通信&#xff0c;還有System V IPC&#xff0c;今天這篇文章我們就來學習一下System V IPC中的共享內存1. 為何引入共享內存&#xff…

[優選算法專題二滑動窗口——最大連續1的個數 III]

題目鏈接 最大連續1的個數 III 題目描述 題目解析 問題本質 輸入&#xff1a;二進制數組nums&#xff08;只包含 0 和 1&#xff09;和整數k操作&#xff1a;最多可以將k個 0 翻轉成 1目標&#xff1a;找到翻轉后能得到的最長連續 1 的子數組長度 這個問題的核心是要找到一…

C#單元測試(xUnit + Moq + coverlet.collector)

C#單元測試 xUnit Moq coverlet.collector 1.添加庫 MlyMathLib 2.編寫庫函數內容 using System;namespace MlyMathLib {public interface IUserRepo{string GetName(int id);}public class UserService{private readonly IUserRepo _repo;public UserService(IUserRepo repo…

【數據庫】Oracle學習筆記整理之五:ORACLE體系結構 - 參數文件與控制文件(Parameter Files Control Files)

Oracle體系結構 - 參數文件與控制文件&#xff08;Parameter Files & Control Files&#xff09; 參數文件與控制文件是Oracle數據庫的“雙核基石”&#xff1a;參數文件是實例的“啟動配置中心”&#xff0c;定義運行環境與規則&#xff1b;控制文件是數據庫的“物理元數據…

GDB典型開發場景深度解析

GDB典型開發場景深度解析 以下是開發過程中最常見的GDB使用場景&#xff0c;結合具體實例和調試技巧&#xff0c;幫助開發者高效解決實際問題&#xff1a;一、崩潰分析&#xff08;Core Dump調試&#xff09; 場景&#xff1a;程序突然崩潰&#xff0c;生成了core文件 # 啟動調…

存儲、硬盤、文件系統、 IO相關常識總結

目錄 &#xff08;一&#xff09;存儲 &#xff08;1&#xff09;定義 &#xff08;2&#xff09;分類 &#xff08;二&#xff09;硬盤 &#xff08;1&#xff09;容量&#xff08;最主要的參數&#xff09; &#xff08;2&#xff09;轉速 &#xff08;3&#xff09;訪…

docker安裝mongodb及java連接實戰

1.docker部署mongodb docker run --name mongodb -d -p 27017:27017 -v /data/mongodbdata:/data/db -e MONGO_INITDB_ROOT_USERNAMEtestmongo -e MONGO_INITDB_ROOT_PASSWORDtest123456 mongodb:4.0.112.項目實戰 <dependencies><dependency><groupId>org.m…

Java設計模式之《工廠模式》

目錄 1、介紹 1.1、定義 1.2、優缺點 1.3、使用場景 2、實現 2.1、簡單工廠模式 2.2、工廠方法模式 2.3、抽象工廠模式 3、小結 前言 在面向對象編程中&#xff0c;創建對象實例最常用的方式就是通過 new 操作符構造一個對象實例&#xff0c;但在某些情況下&#xff0…

【異步】js中異步的實現方式 async await /Promise / Generator

JS的異步相關知識 js里面一共有以下異步的解決方案 傳統的回調 省略 。。。。 生成器 Generator 函數是 ES6 提供的一種異步編程解決方案, 語法上&#xff0c;首先可以把它理解成&#xff0c;Generator 函數是一個狀態機&#xff0c;封裝了多個內部狀態。執行 Generator 函數…

JVM字節碼文件結構

Class文件結構class文件是二進制文件&#xff0c;這里要介紹的是這個二級制文件的結構。思考&#xff1a;一個java文件編譯成class文件&#xff0c;如果要描述一個java文件&#xff0c;需要哪些信息呢&#xff1f;基本信息&#xff1a;類名、父類、實現哪些接口、方法個數、每個…

11.web api 2

5. 操作元素屬性 5.1操作元素常用屬性 &#xff1a;通過 JS 設置/修改標簽元素屬性&#xff0c;比如通過 src更換 圖片最常見的屬性比如&#xff1a; href、title、src 等5.2 操作元素樣式屬性 &#xff1a;通過 JS 設置/修改標簽元素的樣式屬性。使用 className 有什么好處&a…

java中數組和list的區別是什么?

在Java中&#xff0c;數組&#xff08;Array&#xff09;和List&#xff08;通常指java.util.List接口的實現類&#xff0c;如ArrayList、LinkedList&#xff09;是兩種常用的容器&#xff0c;但它們在設計、功能和使用場景上有顯著區別。以下從核心特性、使用方式等方面詳細對…

Python爬取推特(X)的各種數據

&#x1f31f; Hello&#xff0c;我是蔣星熠Jaxonic&#xff01; &#x1f308; 在浩瀚無垠的技術宇宙中&#xff0c;我是一名執著的星際旅人&#xff0c;用代碼繪制探索的軌跡。 &#x1f680; 每一個算法都是我點燃的推進器&#xff0c;每一行代碼都是我航行的星圖。 &#x…

Oracle數據庫文件管理與空間問題解決指南

在Oracle數據庫運維中&#xff0c;表空間、數據文件及相關日志文件的管理是保障數據庫穩定運行的核心環節。本文將系統梳理表空間與數據文件的調整、關鍵文件的移動、自動擴展配置&#xff0c;以及常見空間不足錯誤的排查解決方法&#xff0c;為數據庫管理員提供全面參考。 一、…

華為實驗綜合小練習

描述&#xff1a; 1 內網有A、B、C 三個部門。所在網段如圖所示。 2 內網服務器配置靜態IP,網關192.168.100.1。 3 sw1和R1之間使用vlan200 192.168.200.0/30 互聯。 4 R1向運營商申請企業寬帶并申請了5個公網IP&#xff1a;200.1.1.1-.5子網掩碼 255.255.255.248&#xff0c;網…

Flink面試題及詳細答案100道(1-20)- 基礎概念與架構

《前后端面試題》專欄集合了前后端各個知識模塊的面試題&#xff0c;包括html&#xff0c;javascript&#xff0c;css&#xff0c;vue&#xff0c;react&#xff0c;java&#xff0c;Openlayers&#xff0c;leaflet&#xff0c;cesium&#xff0c;mapboxGL&#xff0c;threejs&…

爬蟲逆向之滑塊驗證碼加密分析(軌跡和坐標)

本文章中所有內容僅供學習交流使用&#xff0c;不用于其他任何目的。否則由此產生的一切后果均與作者無關&#xff01;在爬蟲開發過程中&#xff0c;滑塊驗證碼常常成為我們獲取數據的一大阻礙。而滑塊驗證碼的加密方式多種多樣&#xff0c;其中軌跡加密和坐標加密是比較常見的…