Pytorch-04 搭建神經網絡架構工作流

搭建神經網絡架構

在pytorch中,神經網絡被抽象成由一系列對數據執行特定操作的層或者模塊組成,比如下面的Attention實現,每個塊都是一個模塊或者層。
在這里插入圖片描述

如果你想快速搭建網絡架構,torch.nn這個命名空間提供了所有很多開箱即用的層/模塊/算子:
在這里插入圖片描述
如果你想自定義一個模塊也是完全可以的。每個模塊都是nn.Module的子類,你只需要繼承然后復寫即可,這個后面有例子。

這種簡潔的架構抽象可以讓使用pytorch的人們快速搭建并管理精妙的模型架構。

接下來,我們將搭建一個神經網絡來分類FashionMNIST數據集,來過一遍搭建網絡的工作流。

import os
import torch
from torch import nn
from torch.utils.data import Dataloader
from torchvision import datasets, transforms

1. 獲取可能的加速設備

為了在 加速器(accelerator) 上訓練我們的模型,例如 CUDAMPSMTIAXPU,我們將遵循以下邏輯:

如果當前設備有可用的加速器,我們就使用它;否則,我們將使用 CPU

device = torch.accelerator.current_accelerator().type if  torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

2. 搭建網絡結構

2.1 定義網絡類

通過繼承nn.Module,我們可以定義我們的神經網絡類,并且在__init__里面定義我們要用到的模塊或者層。然后實現forward方法來定義對輸入模型的數據的實際操作以及操作順序,并且返回推理結果。

class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.faltten = nn.Faltten() # 展平層self.linear_relu_stack = nn.Sequential( # 定義一個序列模塊,被調用時會依次執行所含模塊nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x):x = self.flatten(x)logit = self.linearr_relu_stack(x)return logits

注意,__init__只負責把需要的塊給初始化出來,具體數據是怎么在塊間流動由forward實現。

2.2 實例化網絡并查看結構

現在我們實例化網絡,并且把它搬到device側,然后打印出他的結構:

model = NeuralNetwork().to(device)
print(model)

在這里插入圖片描述

2.3 進行網絡“冒煙測試”

搭建好網絡結構之后,強烈建議進行一次“冒煙測試”,用一個符合輸入shape的tensor看看整個網絡能不能跑通。

要給模型傳入數據進行推理,直接給模型傳入數據即可,千萬別直接調用forward方法,因為model(x)還會做一些forward沒做的一些其他必要操作。

X = torch.rand(1, 28, 28, device=device)
logits = model(X)
print(logits.shape)
pred_probab = nn.Softmax(dim=1)(logits)
print(pred_probab)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

在這里插入圖片描述

給模型輸入數據之后,模型返回一個2維的tensor,dim=0的數據是batch中的具體樣本idx,dim=1的數據則是輸出的這個樣本的所屬10個不同類別的預測值。最后我們套一層nn.Softmax, 就可以獲得每個類別的概率pred_probab了。最后對其使用argmax(1)找到該張量在dim=1維度上的最大值索引,就獲得了這一次推理的分類結果。

3. 進階操作:獲取模型當前的參數

如果你想要一點可解釋性,你可能得用到這個

神經網絡中的許多層都是參數化的,也就是說,它們有相關的權重(weights)偏差(biases),這些值會在訓練過程中進行優化。

當你的模型繼承自 nn.Module 時,PyTorch 會自動追蹤模型對象中定義的所有字段。因此,你可以通過模型的 parameters()named_parameters() 方法來訪問所有這些參數。

print(model)for name, param in model.named_parameters():pritn(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n") # 矩陣獲取前兩行,bias獲取前兩個

在這個例子中,我們遍歷了每一個參數,并打印出它的尺寸(size)和部分值預覽。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/91916.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/91916.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/91916.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

從“碎片化”到“完美重組”:IP報文的分片藝術

前言 在網絡通信中,當IP層需要傳輸的數據包大小超過數據鏈路層的MTU限制時,就必須進行分片處理。本文將完整解析IP分片的工作機制,包括分片字段的作用、如何減少分片,以及分片報文的組裝原理。 IP報頭解析請參考&#xff…

[GESP202306 四級] 2023年6月GESP C++四級上機題超詳細題解,附帶講解視頻!

本文為2023年6月GESP C四級的上機題目的詳細題解!覺得寫的不錯或者有幫助可以點個贊啦! (第一次講解視頻,有問題可以指出,不足之處也可以指出) 目錄 題目一講解視頻: 題目二講解視頻: 題目一: 幸運數 題目大意: …

內網穿透 FRP 配置指南

關鍵詞:內網穿透、FRP配置、frps、frpc、遠程訪問、自建服務器、反向代理、TCP轉發、HTTP轉發 在開發或部署項目時,我們經常遇到內網設備無法被公網訪問的問題,例如你想從外網訪問你家里的 NAS、遠程調試開發板,或是訪問本地測試環…

SpringBoot 信用卡檢測、OpenAI gym、OCR結合、DICOM圖形處理、知識圖譜、農業害蟲識別實戰

信用卡欺詐檢測通常使用公開數據集 數據準備與預處理 信用卡欺詐檢測通常使用公開數據集如Kaggle的信用卡交易數據集。數據預處理包括處理缺失值、標準化數值特征、處理類別特征。在Spring Boot中,可以使用pandas或sklearn進行數據預處理。 // 示例:使用Spring Boot讀取CS…

使用 Docker 部署 Golang 程序

Docker 是部署 Golang 應用程序的絕佳方式,它可以確保環境一致性并簡化部署流程。以下是完整的指南: 1. 準備 Golang 應用程序 首先確保你的 Go 應用程序可以正常構建和運行。一個簡單的示例 main.go: package mainimport ("fmt""net/http" )func ha…

從零開始的CAD|CAE開發: LBM源碼實現分享

起因:上期我們寫了流體仿真的經典案例: 通過LBM,模擬計算渦流的形成,當時承諾: 只要驗證通過,就把代碼開源出來;ok.驗證通過了,那么我也就將代碼全都貼出來代碼開源并貼出:public class LidDrivenCavityFlow : IDisposable{public LidDrivenCavityFlow(int width 200, int hei…

倉庫管理系統-17-前端之物品類型管理

文章目錄 1 表設計(goodstype) 2 后端代碼 2.1 Goodstype.java 2.2 GoodstypeMapper.java 2.3 GoodstypeService.java 2.4 GoodstypeServiceImpl.java 2.5 GoodstypeController.java 3 前端代碼 3.1 goodstype/GoodstypeManage.vue 3.2 添加菜單 3.3 頁面顯示 1、goodstype表設…

共識算法深度解析:PoS/DPoS/PBFT對比與Python實現

目錄 共識算法深度解析:PoS/DPoS/PBFT對比與Python實現 1. 引言:區塊鏈共識的核心挑戰 2. 共識算法基礎 2.1 核心設計維度 2.2 關鍵評估指標 3. PoS(權益證明)原理與實現 3.1 核心機制 3.2 Python實現 4. DPoS(委托權益證明)原理與實現 4.1 核心機制 4.2 Python實現 5. P…

3.JVM,JRE和JDK的關系是什么

3.JVM,JRE和JDK的關系是什么 1.JDK(Java Development Kit),是功能齊全的Java SDK,包含JRE和一些開發工具(比如java.exe,運行工具javac.exe編譯工具,生成.class文件,javaw.exe,大多用…

深度學習技術發展思考筆記 || 一項新技術的出現,往往是為了解決先前范式中所暴露出的特定局限

深度學習領域的技術演進,遵循著一個以問題為導向的迭代規律。一項新技術的出現,往往是為了解決先前范式中所暴露出的特定局限。若將這些新技術看作是針對某個問題的“解決方案”,便能勾勒出一條清晰的技術發展脈絡。 例如,傳統的前…

Promise的reject處理: then的第二個回調 與 catch回調 筆記250804

Promise的reject處理: then的第二個回調 與 catch回調 筆記250804 Promise 錯誤處理深度解析:then 的第二個回調 vs catch 在 JavaScript 的 Promise 鏈式調用中,錯誤處理有兩種主要方式:.then() 的第二個回調函數和 .catch() 方法。這兩種方…

Maven模塊化開發與設計筆記

1. 模塊化開發模塊化開發是將大型應用程序拆分成多個小模塊的過程,每個模塊負責不同的功能。這有助于降低系統復雜性,提高代碼的可維護性和可擴展性。2. 聚合模塊聚合模塊(父模塊)用于組織和管理多個子模塊。它定義了項目的全局配…

sqli-labs:Less-21關卡詳細解析

1. 思路🚀 本關的SQL語句為: $sql"SELECT * FROM users WHERE username($cookee) LIMIT 0,1";注入類型:字符串型(單引號、括號包裹)、GET操作提示:參數需以)閉合關鍵參數:cookee p…

大模型+垂直場景:技術縱深、場景適配與合規治理全景圖

大模型垂直場景:技術縱深、場景適配與合規治理全景圖??核心結論?:2025年大模型落地已進入“深水區”,技術價值需通過 ?領域縱深(Domain-Deep)?、數據閉環(Data-Driven)?、部署友好&#x…

Kotlin Daemon 簡介

Kotlin Daemon 是 Kotlin 編譯器的一個后臺進程,旨在提高編譯性能。它通過保持編譯環境的狀態來減少每次編譯所需的啟動時間,從而加快增量編譯的速度。 Kotlin Daemon 的主要功能增量編譯: 只編譯自上次編譯以來發生更改的文件,節…

鴻蒙南向開發 編寫一個簡單子系統

文章目錄 前言給設備,編寫一個簡單子系統總結 一、前言 對于應用層的開發,搞了十幾年,其實已經有點開發膩的感覺了,翻來覆去,就是調用api,頁面實現,最多就再加個性能優化,但對底層…

超詳細:2026年博士申請時間線

博士申請是一場持久戰,需要提前規劃。那么,如何科學安排2026年博士申請時間線?SCI論文發表的最佳時間節點是什么?今天給所有打算申博的同學們,詳細解析下,每個時間節點的重點內容。2025年4月:是…

Python爬蟲實戰:研究tproxy代理工具,構建電商數據采集系統

1. 引言 1.1 研究背景 在大數據與人工智能技術快速發展的背景下,網絡數據已成為企業決策、學術研究、輿情監控的核心資源。據 Statista 統計,2024 年全球互聯網數據總量突破 180ZB,其中 80% 為非結構化數據,需通過爬蟲技術提取與轉化。Python 憑借其簡潔語法與豐富的爬蟲…

HighgoDB查詢慢SQL和阻塞SQL

文章目錄環境文檔用途詳細信息環境 系統平臺:N/A 版本:6.0,5.6.5,5.6.4,5.6.3,5.6.1,4.5.2,4.5,4.3.4.9,4.3.4.8,4.3.4.7,4.3.4.6,4.3.4.5,4.3.4.4,4.3.4.3,4.3.4.2,4.3.4,4.7.8,4.7.7,4.7.6,4.7.5,4.3.2 文檔用途 本文介紹了如何對數據庫日志進行分析…

day15 SPI

1串行外設接口概述1.1基本概念SPI(串行外設接口)是一種高速、全雙工、同步的串行通信協議。串行外設接口一般是需要4根線來進行通信(NSS、MISO、MOSI、SCK),但是如果打算實現單向通信(最少3根線&#xff09…