【AI】Pytorch神經網絡分類初探

Pytorch神經網絡分類初探

1.數據準備

環境采用之前創建的Anaconda虛擬環境pytorch,為了方便查看每一步的返回值,可以使用Jupyter Notebook來進行開發。首先把需要的包導入進來

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

torch框架的數據輸入依賴兩個基類:torch.utils.data.DataLoader和torch.utils.data.Dataset,Dataset 存儲樣本及其相應的標簽,DataLoader 將 Dataset 封裝為迭代器。

為了方便使用數據,我們采用Mnist數據集

%matplotlib inline
from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)

等待數據下載完畢,然后將數據讀入進來。

import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

讀入進來的數據并不是tensor格式的,需要將其轉化成Tensor格式

import torchx_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)

最重要的一步,將其轉換成dataset和dataloader

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)

這樣就完成了數據準備的工作

2.定義模型

這邊直接引用官網教程的模型

# Get cpu, gpu or mps device for training.
device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu"
)
print(f"Using {device} device")# Define model
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()#self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):#x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device)
print(model)

將打印的結果放在下面,可以查看一下

Using cuda device
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(linear_relu_stack): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): ReLU()(2): Linear(in_features=512, out_features=512, bias=True)(3): ReLU()(4): Linear(in_features=512, out_features=10, bias=True))
)

3.定義模型損失函數和優化器

這里我們依舊使用官網教程中的直接來

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

這里的SGD是最基礎的優化器,采用的是梯度遞減的方式,其收斂的會比較慢,如果希望收斂快些,可以使用Adam方式。

4. 定義訓練和測試函數

訓練函數

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

測試函數

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

5.開始訓練

epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dl, model, loss_fn, optimizer)test(valid_dl, model, loss_fn)
print("Done!")

6.模型保存

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

7.模型加載和使用模型預測

model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth"))

模型預測

classes = ["0","1","2","3","4","5","6","7","8","9",
]model.eval()
x, y = train_ds[2][0], train_ds[2][1]
with torch.no_grad():x = x.to(device)pred = model(x)print(pred)predicted, actual = classes[pred.argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')

使用SGD優化器訓練,訓練5次的最高精度為76%,而使用Adam優化器第一個epoch的精度就已經達到了97%

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/206781.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/206781.shtml
英文地址,請注明出處:http://en.pswp.cn/news/206781.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【RHCE】openlab搭建web網站

網站需求: 1、基于域名 www.openlab.com 可以訪問網站內容為 welcome to openlab!!! 增加映射 [rootlocalhost ~]# vim /etc/hosts 創建網頁 [rootlocalhost ~]# mkdir -p /www/openlab [rootlocalhost ~]# echo welcome to openlab > /www/openlab/index.h…

利用法線貼圖渲染逼真的3D老虎模型

在線工具推薦: 3D數字孿生場景編輯器 - GLTF/GLB材質紋理編輯器 - 3D模型在線轉換 - Three.js AI自動紋理開發包 - YOLO 虛幻合成數據生成器 - 三維模型預覽圖生成器 - 3D模型語義搜索引擎 當談到游戲角色的3D模型風格時,有幾種不同的風格&#xf…

傅里葉變換在圖像中的應用

1.圖像增強與圖像去噪 絕大部分噪音都是圖像的高頻分量,通過低通濾波器來濾除高頻——噪聲; 邊緣也是圖像的高頻分量,可以通過添加高頻分量來增強原始圖像的邊緣; 2.圖像分割之邊緣檢測 提取圖像高頻分量 3.圖像特征提取: 形狀特…

3-Mybatis

文章目錄 Mybatis概述什么是Mybatis?Mybatis導入知識補充數據持久化持久層 第一個Mybatis程序:數據的增刪改查查創建環境編寫代碼1、目錄結構2、核心配置文件:resources/mybatis-config.xml3、mybatis工具類:Utils/MybatisUtils4、…

ALNS的MDP模型| 還沒整理完12-08

有好幾篇論文已經這樣做了,先擺出一篇,然后再慢慢更新 第一篇 該篇論文提出了一種稱為深增強ALNS(DR-ALNS)的方法,它利用DRL選擇最有效的破壞和修復運營商,配置破壞嚴重性參數施加在破壞算子上&#xff0c…

請簡要介紹一下HTML的發展史?

問題:什么是池化思想? 回答: 池化思想是一種資源管理的策略,通過事先創建并維護一組已經初始化好的資源對象池,以便在需要時快速獲取資源并在用完后歸還給池,以減少資源的創建和銷毀開銷,提高資…

第二十一章網絡通信總結

21.1 網絡程序設計基礎 Java網絡程序設計基礎涉及使用Java編程語言創建網絡應用程序。這通常涉及到使用Java的網絡API,如java.net包,以建立客戶端和服務器之間的通信。 基本步驟包括: 1.創建服務器: 使用ServerSocket類創建服務…

常見的中間件--消息隊列中間件測試點

最近刷題,看到了有問中間件的題目,于是整理了一些中間件的知識,大多是在小破站上的筆記,僅供大家參考~ 主要分為七個部分來分享: 一、常見的中間件 二、什么是隊列? 三、常見消息隊列MQ的比較 四、隊列…

用戶管理 --匯總

一、第一節課 1.1 本人寫的 前端: 魚皮 --> 用戶中心 第1節課-CSDN博客 中期: 一、用戶管理 第1節課中間-CSDN博客 后端: 一、用戶管理-CSDN博客 其他的鏈接 億圖腦圖MindMaster 1.2 優秀球友,推薦 Docs 另…

12_企業架構之Tomcat部署使用

Tomcat 學習目標和內容 1、能夠描述Tomcat的使用場景 2、能夠簡單描述Tomcat的工作原理 3、能夠實現部署安裝Tomcat 4、能夠實現配置Tomcat的service服務和自啟動 5、能夠實現Tomcat的Host的配置 6、能夠實現Nginx反向代理Tomcat 7、能夠實現Nginx負載均衡到Tomcat 一、Tomcat介…

Abaqus許可證配置文件問題

在使用Abaqus工程設計和仿真軟件時,您可能會遇到許可證配置文件問題。這些問題可能會影響軟件的正常運行和工作效率。為了幫助您解決這些問題,我們特別撰寫了這篇文章,以提供全面、有效的解決方案。 一、Abaqus許可證配置文件問題及原因 許…

力扣labuladong一刷day32天二叉樹

力扣labuladong一刷day32天二叉樹 一、297. 二叉樹的序列化與反序列化 題目鏈接:https://leetcode.cn/problems/serialize-and-deserialize-binary-tree/ 思路:關于序列化與反序列化,題目不要求序列化的方式,只要求樹經過序列化…

linux的定時任務Corntab

安裝crontab # yum安裝crontab yum install -y crontab# 開機自啟crond服務并現在啟動 systemctl enable --now crondcron系統任務調度 系統任務調度: 系統周期性所要執行的工作,比如寫緩存數據到硬盤、日志清理等。 在/etc/crontab文件,這…

機器學習之全面了解回歸學習器

我們將和大家一起探討機器學習與數據科學的主題。 本文主要討論大家針對回歸學習器提出的問題。我將概要介紹,然后探討以下五個問題: 1. 能否將回歸學習器用于時序數據? 2. 該如何縮短訓練時間? 3. 該如何解釋不同模型的結果和…

No suitable driver found for jdbc:mysql://localhost:3306(2023/12/7更新)

有兩種情況: 壓根沒安裝下載了但沒設為庫或方法不對 大多數為第一種情況: 一. 下載jdbc 打開網址選擇一個版本進行下載 https://nowjava.com/jar/version/mysql/mysql-connector-java.html 二.安裝jdbc 在項目里建一個lib文件夾 在把之前下載的jar文…

優化 SQL 日志記錄的方法

為什么 SQL 日志記錄是必不可少的 SQL 日志記錄在數據庫安全和審計中起著至關重要的作用,它涉及跟蹤在數據庫上執行的所有 SQL 語句,從而實現審計、故障排除和取證分析。SQL 日志記錄可以提供有關數據庫如何訪問和使用的寶貴見解,使其成為確…

JNPF低代碼平臺詳解 -- 系統架構

目錄 一、技術介紹 技術架構 二、設計原理 三、界面展示 1.代碼生成器 2.工作流程 3.門戶設計 4.大屏設計 5.報表設計 6.第三方登錄 7.多租戶實現 8.分布式調度 9.消息中心 四、功能框架 JNPF低代碼是一款新奇、實用、高效的企業級軟件開發工具,支持企…

Qt/C++音視頻開發58-逐幀播放/上一幀下一幀/切換播放進度/實時解碼

一、前言 逐幀播放是近期增加的功能,之前也一直思考過這個功能該如何實現,對于mdk/qtav等內核組件,可以直接用該組件提供的接口實現即可,而對于ffmpeg,需要自己處理,如果有緩存的數據的話,可以…

Rust的eBFP框架Aya(一) - Linux內核網絡基礎

前言 在我的Rust入門及實戰系列文章中已經說明, Rust是一門內存安全的高性能編程語言,從它的這些優秀特性來看,就是一門專為系統開發而誕生的語言。至于很多使用Rust來進行web開發的行為,不能說它們不好,只能說是殺雞…

2017下半年軟工(橋接模式)

題目——橋接模式(抽象調用實現部分) package org.example.橋接模式;/*** 橋接模式的核心思想是將抽象部分與它的實現部分分離,使它們可以獨立變化,就是說你在實現部分:WinImp、LinuxImp基礎上還能加上RedHatImp&#…