[oneAPI] 手寫數字識別-LSTM

[oneAPI] 手寫數字識別-LSTM

  • 手寫數字識別
    • 參數與包
    • 加載數據
    • 模型
    • 訓練過程
    • 結果
  • oneAPI

比賽:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel? DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手寫數字識別

使用了pytorch以及Intel? Optimization for PyTorch,通過優化擴展了 PyTorch,使英特爾硬件的性能進一步提升,讓手寫數字識別問題更加的快速高效
在這里插入圖片描述

使用MNIST數據集,該數據集包含了一系列以黑白圖像表示的手寫數字,每個圖像的大小為28x28像素,數據集組成如下:

  • 訓練集:包含60,000個圖像和標簽,用于訓練模型。
  • 測試集:包含10,000個圖像和標簽,用于測試模型的性能。

每個圖像都被標記為0到9之間的一個數字,表示圖像中顯示的手寫數字。這個數據集常常被用來驗證圖像分類模型的性能,特別是在計算機視覺領域。

參數與包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transformsimport intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01

加載數據

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',train=True,transform=transforms.ToTensor(),download=True)test_dataset = torchvision.datasets.MNIST(root='../../data/',train=False,transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)

模型

# Recurrent neural network (many-to-one)
class RNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, num_classes)def forward(self, x):# Set initial hidden and cell states h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward propagate LSTMout, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)# Decode the hidden state of the last time stepout = self.fc(out[:, -1, :])return out

訓練過程

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))# Test the model
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

結果

在這里插入圖片描述

oneAPI

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# 模型
model = ConvNet(num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/40387.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/40387.shtml
英文地址,請注明出處:http://en.pswp.cn/news/40387.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Curson 編輯器

Curson 漢化與vacode一樣 Curson 自帶chat功能 1、快捷鍵ctrlk(代碼中編輯) 2、快捷鍵ctrll 右側打開窗口

為什么hive會出現_HIVE_DEFAULT_PARTITION分區

問題: 為什么hive表中出現_HIVE_DEFAULT_PARTITION分區? 解答: 因為在業務sql中使用的是動態分區,并且hive啟用動態分區時,對于指定的分區鍵如果存在空值時,會對空值部分創建一個默認分區用于存儲該部分…

小程序項目組件的基本應用

宿主環境:程序運行必須依賴的環境 小程序的宿主環境 ---->手機微信(定位、掃碼、支付等) 小程序的通信模型: 渲染層和邏輯層之間的通信(微信客戶端轉發)邏輯層和第三方服務器之間的通信(微信客戶端轉發) 小程序的運行機制: 啟動&#xff1…

c#實現工廠模式

可以使用以下代碼實現C#中的工廠模式: 首先,定義一個接口作為產品的抽象: public interface IProduct {void Operation(); }然后,創建具體的產品類: public class ConcreteProductA : IProduct {public void Operat…

vue基礎知識五:請描述下你對vue生命周期的理解?在created和mounted這兩個生命周期中請求數據有什么區別呢?

一、生命周期是什么 生命周期(Life Cycle)的概念應用很廣泛,特別是在政治、經濟、環境、技術、社會等諸多領域經常出現,其基本涵義可以通俗地理解為“從搖籃到墳墓”(Cradle-to-Grave)的整個過程在Vue中實…

41 | 京東商家書籍評論數據分析

京東作為中國領先的電子商務平臺,積累了大量商品評論數據,這些數據蘊含了豐富的信息。通過文本數據分析,我們可以了解用戶對產品的態度、評價的關鍵詞、消費者的需求等,從而有助于商家優化產品和服務,以及消費者作出更明智的購買決策。 本文將詳細闡述如何獲取京東商家評…

Python opennsfw/opennsfw2 圖片/視頻 鑒黃 筆記

nsfw&#xff08; Not Suitable for Work&#xff09;直接翻譯就是 工作的時候不適合看&#xff0c;真文雅 nsfw效果&#xff0c;注意底部的分數 大體流程&#xff0c;輸入圖片/視頻&#xff0c;輸出0-1之間的數字&#xff0c;一般情況下&#xff0c;Scores < 0.2 認為是非…

7zip分卷壓縮

前言 有些項目上傳文件大小有限制 壓縮包大了之后傳輸也會比較慢 解決方案 我們可以利用7zip壓縮工具對文件進行分卷壓縮 利用7zip壓縮工具進行分卷壓縮 查看待壓縮文件大小 壓縮完成之后有300多M&#xff0c;我們用100M去進行分卷壓縮 選擇待壓縮的文件夾&#xff0c;右…

網絡安全 Day30-運維安全項目-容器架構上

容器架構上 1. 什么是容器2. 容器 vs 虛擬機(化) :star::star:3. Docker極速上手指南1&#xff09;使用rpm包安裝docker2) docker下載鏡像加速的配置3) 載入鏡像大禮包&#xff08;老師資料包中有&#xff09; 4. Docker使用案例1&#xff09; 案例01&#xff1a;:star::star::…

《內網穿透》無需公網IP,公網SSH遠程訪問家中的樹莓派

文章目錄 前言 如何通過 SSH 連接到樹莓派步驟1. 在 Raspberry Pi 上啟用 SSH步驟2. 查找樹莓派的 IP 地址步驟3. SSH 到你的樹莓派步驟 4. 在任何地點訪問家中的樹莓派4.1 安裝 Cpolar內網穿透4.2 cpolar進行token認證4.3 配置cpolar服務開機自啟動4.4 查看映射到公網的隧道地…

【JavaEE基礎學習打卡02】是時候了解Java EE了!

目錄 前言一、為什么要學習Java EE二、Java EE規范介紹1.什么是規范&#xff1f;2.什么是Java EE規范&#xff1f;3.Java EE版本 三、Java EE應用程序模型1.模型前置說明2.模型具體說明 總結 前言 &#x1f4dc; 本系列教程適用于 Java Web 初學者、愛好者&#xff0c;小白白。…

java接口導出csv

1、背景介紹 項目中需要導出數據質檢結果&#xff0c;本來使用Excel&#xff0c;但是質檢結果數據行數過多&#xff0c;導致用hutool報錯&#xff0c;因此轉為導出csv格式數據。 2、參考文檔 https://blog.csdn.net/ityqing/article/details/127879556 工程環境&#xff1a;…

Redis-分布式鎖!

分布式鎖&#xff0c;顧名思義&#xff0c;分布式鎖就是分布式場景下的鎖&#xff0c;比如多臺不同機器上的進程&#xff0c;去競爭同一項資源&#xff0c;就是分布式鎖。 分布式鎖特性 互斥性:鎖的目的是獲取資源的使用權&#xff0c;所以只讓一個競爭者持有鎖&#xff0c;這…

PyTorch: clamp函數與梯度的關系

本文主要以下探究這一點&#xff1a;梯度反向傳播過程中&#xff0c;測試強行修改后的預測結果是否還會傳遞loss&#xff1f; clamp應用場景&#xff1a;在深度學習計算損失函數的過程中&#xff0c;會有這樣一個問題&#xff0c;如果Label是1.0&#xff0c;而預測結果是0.0&a…

【算法】排序+雙指針——leetcode三數之和、四數之和

三數之和 &#xff08;1&#xff09;排序雙指針 算法思路&#xff1a; 和之前的兩數之和類似&#xff0c;我們對暴力枚舉進行了一些優化&#xff0c;利用了排序雙指針的思路&#xff1a; 我們先排序&#xff0c;然后固定?個數 a &#xff0c;接著我們就可以在這個數后面的區間…

Mybatis Plus Interceptor

Mybatis Plus Interceptor 1 獲取表名2 獲取SQL 1 獲取表名 Component public class MybatisInterceptor implements Interceptor {private static final List<String> EXCLUDE_TABLE new ArrayList<>();static {EXCLUDE_TABLE.add("test");}private s…

OpenCV實例(九)基于深度學習的運動目標檢測(一)YOLO運動目標檢測算法

基于深度學習的運動目標檢測&#xff08;一&#xff09; 1.YOLO算法檢測流程2.YOLO算法網絡架構3.網絡訓練模型3.1 訓練策略3.2 代價函數的設定 2012年&#xff0c;隨著深度學習技術的不斷突破&#xff0c;開始興起基于深度學習的目標檢測算法的研究浪潮。 2014年&#xff0c;…

電腦突然黑屏的解決辦法

記錄一次電腦使用問題 問題描述 基本情況&#xff1a;雷神游戲筆記本 windows10操作系統 64位 使用時間 4年 日期&#xff1a;2023年8月11日 當時 電腦充著電 打開了兩個瀏覽器&#xff1a;edge[頁面加載5個左右]&#xff0c;火狐[頁面加載1個左右] 兩個文件夾 一個百度網盤…

Davinci 報表工具 0.3.0-rc release 文本框模糊查詢不生效問題

背景: 在使用過程中發現davinci 的控制器配置中, 取值配置的對應關系設置 包含 或 不包含時 不生效, 不能實現模糊匹配效果, 只能精確查詢; 問題分析: 通過跟蹤接口及相應代碼, 發現在sql 拼接時沒有對 like 和 not like 類型的值兩側添加百分號, 導致模糊查詢失敗 調用過程…

CentOS系統環境搭建(七)——Centos7安裝MySQL

centos系統環境搭建專欄&#x1f517;點擊跳轉 坦誠地說&#xff0c;本文中百分之九十的內容都來自于該文章&#x1f517;Linux&#xff1a;CentOS7安裝MySQL8&#xff08;詳&#xff09;&#xff0c;十分佩服大佬文章結構合理&#xff0c;文筆清晰&#xff0c;我曾經在這篇文章…