pytorch-01

加載mnist數據集

one-hot編碼實現

import numpy as np
import torch
x_train = np.load("../dataset/mnist/x_train.npy") # 從網站提前下載數據集,并解壓縮
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x = torch.tensor(y_train_label[:5],dtype=torch.int64)  # 獲取前5個樣本的標簽數據
# 定義一個張量輸入,因為此時有 5 個數值,且最大值為9,類別數為10
# 所以我們可以得到 y 的輸出結果的形狀為 shape=(5,10),即5行12列
y = torch.nn.functional.one_hot(x, 10)  # 一個參數張量x,10為類別數
print(y)

對于擁有6000個樣本的MNIST數據集來說,標簽就是一個6000\times 10大小的矩陣張量。

多層感知機模型

#設定的多層感知機網絡模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()  # 拉平圖像矩陣self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),   # 輸入大小為28*28,輸出大小為312維的線性變換層torch.nn.ReLU(),   # 激活函數層torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10)  # 最終輸出大小為10,對應one-hot標簽維度)def forward(self, input):   # 構建網絡x = self.flatten(input)  #拉平矩陣為1維logits = self.linear_relu_stack(x) # 多層感知機return logits

損失函數

優化函數

model = NeuralNetwork()
loss_fu = torch.nn.CrossEntropyLoss() # 交叉熵損失函數,內置了softmax函數,
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #設定優化函數loss = loss_fu(pred,label_batch)  # 計算損失

完整模型

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU編
import torch
import numpy as npbatch_size = 320                        #設定每次訓練的批次數
epochs = 1024                           #設定訓練次數#device = "cpu"                         #Pytorch的特性,需要指定計算的硬件,如果沒有GPU的存在,就使用CPU進行計算
device = "cuda"                         #在這里讀者默認使用GPU,如果讀者出現運行問題可以將其改成cpu模式#設定的多層感知機網絡模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),torch.nn.ReLU(),torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10))def forward(self, input):x = self.flatten(input)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork()
model = model.to(device)                #將計算模型傳入GPU硬件等待計算
torch.save(model, './model.pth')
#model = torch.compile(model)            #Pytorch2.0的特性,加速計算速度
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #設定優化函數#載入數據
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")train_num = len(x_train)//batch_size#開始計算
for epoch in range(20):train_loss = 0for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizetrain_batch = torch.tensor(x_train[start:end]).to(device)label_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(train_batch)loss = loss_fu(pred,label_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()  # 記錄每個批次的損失值# 計算并打印損失值train_loss /= train_numaccuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_sizeprint("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

可視化模型結構和參數

model = NeuralNetwork()
print(model)

是對模型具體使用的函數及其對應的參數進行打印。

格式化顯示:

param = list(model.parameters())
k=0
for i in param:l = 1print('該層結構:'+str(list(i.size())))for j in i.size():l*=jprint('該層參數和:'+str(l))k = k+l
print("總參數量:"+str(k))

模型保存

model = NeuralNetwork()
torch.save(model, './model.pth')

netron可視化

安裝:pip install netron

運行:命令行輸入netron

打開:通過網址http://localhost:8080打開

打開保存的模型文件model.pth:

?

?點擊顏色塊,可以顯示詳細信息:

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/38251.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/38251.shtml
英文地址,請注明出處:http://en.pswp.cn/web/38251.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Vue 全局狀態管理新寵:Pinia實戰指南

文章目錄 前言全局狀態管理基本步驟:pinia 前言 隨著Vue.js項目的日益復雜,高效的狀態管理變得至關重要。Pinia作為Vue.js官方推薦的新一代狀態管理庫,以其簡潔的API和強大的功能脫穎而出。本文將帶您快速上手Pinia,從安裝到應用&…

uniapp如何根據不同角色自定義不同的tabbar

思路: 1.第一種是根據登錄時獲取的不同角色信息,來進行 跳轉到不同的頁面,在這些頁面中使用自定義tabbar 2.第二種思路是封裝一個自定義tabbar組件,然后在所有要展示tabbar的頁面中引入使用 1.根據手機號碼一鍵登錄&#xff0c…

SpringMVC的基本使用

SpringMVC簡介 SpringMVC是Spring提供的一套建立在Servlet基礎上,基于MVC模式的web解決方案 SpringMVC核心組件 DispatcherServlet:前置控制器,來自客戶端的所有請求都經由DispatcherServlet進行處理和分發Handler:處理器&…

三個方法教大家學會RAR文件轉換為ZIP格式

在日常工作當中,RAR和ZIP是兩種常見的壓縮文件格式。有時候,大家可能會遇到將RAR文件轉換為ZIP格式的情況,這通常是為了方便在特定情況下打開或使用文件。下面給大家分享幾個RAR文件轉換為ZIP格式的方法,下面隨小編一起來看看吧~ …

在mfc程序中,如何用c++找到exe文件所在的路徑

在 MFC&#xff08;Microsoft Foundation Class&#xff09;程序中&#xff0c;你可以使用 GetModuleFileName 函數來獲取當前運行的可執行文件&#xff08;.exe&#xff09;的路徑。 以下是一個示例代碼&#xff1a; #include <afxwin.h> #include <iostream>in…

KVM性能優化之CPU優化

1、查看kvm虛擬機vCPU的QEMU線程 ps -eLo ruser,pid,ppid,lwp,psr,args |awk /^qemu/{print $1,$2,$3,$4,$5,$6,$8} 注:vcpu是不同的線程&#xff0c;而不同的線程是跑在不同的cpu上&#xff0c;一般情況&#xff0c;虛擬機在運行時自身會點用3個cpus&#xff0c;為保證生產環…

通過MATLAB控制TI毫米波雷達的工作狀態

前言 前一章博主介紹了MATLAB上位機軟件“設計視圖”的制作流程,這一章節博主將介紹如何基于這些組件結合MATLAB代碼來發送CFG指令控制毫米波雷達的工作狀態 串口配置 首先,在我們選擇的端口號輸入框和端口波特率設置框內是可以手動填入數值(字符)的,也可以在點擊運行后…

匯凱金業:投資交易如何才能不虧損

投資交易中永不虧損是一個理想化的目標&#xff0c;現實中無法完全避免虧損。然而&#xff0c;通過科學的方法、合理的策略和嚴格的風險管理&#xff0c;投資者可以大幅減少虧損&#xff0c;并提高長期盈利的概率。以下是一些關鍵策略和方法&#xff0c;幫助投資者在交易中盡量…

【CSRF】

CSRF 原理&#xff1a;誘導用戶在訪問第三方site時&#xff0c;訪問攻擊者構造的site,攻擊者site會對原site進行惡意操作。 burp模擬攻擊&#xff1a; 對一個博客系統點擊發布文章時&#xff0c;Burp Suite抓包&#xff0c;右鍵CSRF PoC功能 -> Engagament tools -> Gen…

洛谷 P3954 [NOIP2017 普及組] 成績

本文由Jzwalliser原創&#xff0c;發布在CSDN平臺上&#xff0c;遵循CC 4.0 BY-SA協議。 因此&#xff0c;若需轉載/引用本文&#xff0c;請注明作者并附原文鏈接&#xff0c;且禁止刪除/修改本段文字。 違者必究&#xff0c;謝謝配合。 個人主頁&#xff1a;blog.csdn.net/jzw…

太陽能輻射系統加速材料老化的關鍵設備光照老化實驗箱

光照老化實驗箱概述 光照老化實驗箱是一種模擬太陽光照射對材料影響的實驗設備&#xff0c;主要用于加速材料的自然老化過程&#xff0c;以此來評估材料在實際使用環境中的耐久性和穩定性。該設備廣泛應用于汽車、航空、建筑、塑料制品等行業&#xff0c;尤其在汽車領域&#…

多商戶b2b2c商城系統怎么運營

B2B2C多用戶商城系統支持多種運營模式&#xff0c;以滿足不同類型和發展階段的企業需求。以下是五大主要的運營模式&#xff1a; **1. 自營模式&#xff1a;**平臺企業通過建立自營線上商城&#xff0c;整合自身多渠道業務。通過會員、商品、訂單、財務和倉儲等多用戶商城管理系…

OK527N-C開發板-簡單的性能測試

OK527N-C CoreMark 獲取CoreMark源碼 首先使用Git克隆倉庫&#xff1a; git clone https://github.com/eembc/coremark.git cd coremark修改Makefile 首先復制文件夾 cp -rf posix ok527之后修改ok527文件夾下的core_portme.mak文件&#xff0c;將CC修改如下 CC aarch6…

CPU占用率飆升至100%:是攻擊還是正常現象?

在運維和開發的日常工作中&#xff0c;CPU占用率突然飆升至100%往往是一個令人緊張的信號。這可能意味著服務器正在遭受攻擊&#xff0c;但也可能是由于某些正常的、但資源密集型的任務或進程造成的。本文將探討如何識別和應對服務器的異常CPU占用情況&#xff0c;并通過Python…

魔行觀察-探魚·鮮青椒爽麻烤魚-開關店監測-時間段:2013年1月 至 2024年6月

今日監測對象&#xff1a;探魚鮮青椒爽麻烤魚&#xff0c;監測時間段&#xff1a;2011年1月 至 2024年6月 本文用到數據源免費獲取地址 魔行觀察http://www.wmomo.com/ 品牌介紹&#xff1a; 探魚建立了產、供、銷一體全鏈條式供應鏈體系&#xff0c;并在低緯珠江口特設潮汐…

大公司圖紙管理的未來趨勢

隨著科技的不斷發展&#xff0c;大公司圖紙管理正朝著更加智能化、自動化和協同化的方向發展。以下是大公司圖紙管理的未來趨勢預測。 1. 智能化管理 利用人工智能和機器學習技術&#xff0c;實現圖紙的自動分類、標注和檢索。通過智能分析算法&#xff0c;預測圖紙的使用趨勢…

NSSCTF-Web題目19(數據庫注入、文件上傳、php非法傳參)

目錄 [LitCTF 2023]這是什么&#xff1f;SQL &#xff01;注一下 &#xff01; 1、題目 2、知識點 3、思路 [SWPUCTF 2023 秋季新生賽]Pingpingping 4、題目 5、知識點 6、思路 [LitCTF 2023]這是什么&#xff1f;SQL &#xff01;注一下 &#xff01; 1、題目 2、知識…

基于Vue的MOBA類游戲攻略分享平臺

你好呀&#xff0c;我是計算機學姐碼農小野&#xff01;如果有相關需求&#xff0c;可以私信聯系我。 開發語言&#xff1a;Java 數據庫&#xff1a;MySQL 技術&#xff1a;Java技術、SpringBoot框架、B/S模式、Vue.js 工具&#xff1a;MyEclipse、MySQL 系統展示 首頁 用…

在 Windows 上,使用 icacls 命令讓apache 用戶有權訪問

調試免費云服務器&#xff0c;三豐云&#xff0c;用戶權限過程。 在 Windows 上&#xff0c;icacls 命令是一個非常強大的工具&#xff0c;用于修改文件和目錄的權限。然而&#xff0c;需要注意的是&#xff0c;Windows 默認的 Web 服務器&#xff08;如 IIS&#xff09;通常運…

lstrip()方法——截掉字符串左邊的空格或指定的字符

自學python如何成為大佬(目錄):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 語法參考 lstrip()方法用于截掉字符串左邊的空格或指定的字符。lstrip()方法的語法格式如下&#xff1a; str.lstrip([chars]) 參數說明&#xff…