第6節 torch.nn介紹

6.1 torch.nn.Module介紹

????????torch.nn.Module是 PyTorch 中構建神經網絡的基礎類,所有的神經網絡模塊都應該繼承這個類。它提供了一種便捷的方式來組織和管理網絡中的各個組件,包括層、參數等,同時還內置了許多用于模型訓練和推理的功能。

官網:torch.nn — PyTorch 1.8.1 documentation

核心功能

(1)、網絡構建:通過繼承torch.nn.Module類,我們可以自定義自己的神經網絡結構。在__init__方法中定義網絡的各個層,在forward方法中定義數據的前向傳播過程。

(2)、參數管理:torch.nn.Module會自動跟蹤和管理網絡中的參數(如權重和偏置)。我們可以通過parameters()方法獲取網絡的所有參數,方便進行優化器的配置和參數的更新。

(3)、設備轉換:可以使用to()方法將模型轉移到指定的設備(如 CPU 或 GPU)上,以利用不同設備的計算能力。?

(4)、狀態切換:提供了train()和eval()方法來切換模型的訓練和評估狀態。在訓練狀態下,一些具有隨機性的層(如 Dropout、BatchNorm)會正常工作;在評估狀態下,這些層會采用確定性的行為。

6.2 torch.nn.Module常用方法

????????__init__(self):構造函數,用于初始化網絡的各個層和參數。在自定義網絡時,需要在該方法中調用super().__init__()來初始化父類。?

????????forward(self, x):前向傳播方法,定義了數據在網絡中的流動過程。當對模型進行調用時(如model(x)),實際上是調用了該方法。?

????????parameters(self):返回一個迭代器,包含網絡中的所有可學習參數。?

????????named_parameters(self):返回一個迭代器,包含網絡中參數的名稱和對應的參數值。?

????????to(self, device):將模型轉移到指定的設備上。例如,model.to('cuda')將模型轉移到 GPU 上。?

????????train(self, mode=True):將模型設置為訓練模式。?

????????eval(self):將模型設置為評估模式,相當于train(mode=False)。?

????????save_state_dict(self, path):保存模型的參數狀態字典到指定路徑。?

????????load_state_dict(self, state_dict):從參數狀態字典中加載模型的參數。

6.3 程序演示

6.3.1 官網提供的例子

import torch.nn as nn
import torch.nn.functional as Fclass Model(nn.Module):   #搭建的神經網絡 Model繼承了 Module類(父類)def __init__(self):   #初始化函數super(Model, self).__init__()   #必須要這一步,調用父類的初始化函數self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):   #前向傳播(為輸入和輸出中間的處理過程),x為輸入x = F.relu(self.conv1(x))   #conv為卷積,relu為非線性處理return F.relu(self.conv2(x))

注意:前向傳播 forward(在所有子類中進行重寫)

6.3.2 自定義Model

import torch
from torch import nn# 定義一個自定義模型類Custom_Model,繼承自nn.Module
# 所有的神經網絡模型都應該繼承nn.Module,以利用其提供的參數管理、設備轉換等功能
class Custom_Model(nn.Module):# 構造函數,用于初始化模型的層和參數def __init__(self):# 調用父類nn.Module的構造函數,確保模型能夠正確初始化super().__init__()# 前向傳播方法,定義數據在模型中的流動和計算過程# 當對模型實例傳入輸入數據時,會自動調用該方法def forward(self, input):# 定義模型的計算邏輯:輸入數據加1output = input + 1# 返回計算結果return outputCustom_Model = Custom_Model()
# 創建一個張量x,值為1.0,作為模型的輸入數據
x = torch.tensor(1.0)
# 將輸入數據x傳入模型,模型會自動調用forward方法進行計算,得到輸出結果
output = Custom_Model(x)
# 打印輸出結果,此時輸出應為2.0(1.0 + 1)
print(output)

6.4?torch.nn.functional.conv2d介紹

????????torch.nn.functional.conv2d是 PyTorch 中用于執行二維卷積操作的函數,在卷積神經網絡(CNN)中扮演著至關重要的角色,用于提取圖像等二維數據的特征。以下是對它的詳細介紹:

參數說明:

  • input?(Tensor):輸入張量,形狀為(N, C_in, H_in, W_in)。其中,N是批量大小(batch size),表示一次處理的樣本數量;C_in是輸入通道數,例如對于灰度圖像C_in=1,對于彩色圖像(RGB 格式)C_in=3;H_in和W_in分別是輸入特征圖的高度和寬度。
  • weight?(Tensor):卷積核(過濾器)張量,形狀為(C_out, C_in, H_k, W_k)?。C_out是輸出通道數,決定了經過卷積操作后生成的特征圖數量;C_in必須與輸入張量的通道數一致;H_k和W_k分別是卷積核的高度和寬度。
  • bias?(Tensor,可選):偏置張量,形狀為(C_out)?,為每個輸出通道添加一個可學習的偏置值,默認值為None。
  • stride?(int或tuple,默認值:1?):卷積核在輸入特征圖上滑動的步長。如果是一個整數,表示在高度和寬度方向上的步長相同;如果是一個元組(stride_h, stride_w),則分別指定高度和寬度方向上的步長。
  • padding?(int或tuple,默認值:0?):在輸入特征圖的邊緣添加填充(padding)像素。同樣,整數表示在高度和寬度方向上添加相同數量的填充;元組(padding_h, padding_w)分別指定高度和寬度方向上的填充數量。填充可以用來控制輸出特征圖的大小,使其與輸入大小相同或滿足特定的尺寸要求。
  • dilation?(int或tuple,默認值:1?):卷積核元素之間的間距。dilation=1表示正常的卷積核;dilation=2時,卷積核元素之間會間隔一個位置,相當于擴大了卷積核的感受野。
  • groups?(int,默認值:1?):分組卷積的組數。當groups=1時,就是普通的卷積操作;當groups > 1時,輸入通道會被分成groups組,卷積核也會相應分組,每組卷積核只與對應的一組輸入通道進行卷積操作,常用于減少計算量或實現特定的網絡結構,比如 AlexNet 中的分組卷積。

應用場景:

????????torch.nn.functional.conv2d廣泛應用于各類基于卷積神經網絡的任務,如:

  • 圖像分類:從輸入圖像中提取各種層次的特征,用于判斷圖像所屬的類別。
  • 目標檢測:提取圖像特征來定位和識別目標物體。
  • 語義分割:對圖像中的每個像素進行分類,以實現對圖像內容的精細分割。

????????總的來說,torch.nn.functional.conv2d是構建深度學習視覺模型的基礎組件之一,通過合理設置其參數,可以靈活地調整卷積操作,以適應不同的任務需求。

6.4.1 卷積操作原理

6.4.2 實戰演示

import torch
import torch.nn.functional as F
# 將二維矩陣轉化為tensor數據類型
input = torch.tensor([[1, 2, 0, 3, 1],[0, 1, 2, 3, 1],[1, 2, 1, 0, 0],[5, 2, 3, 1, 1],[2, 1, 0, 1, 1]])
# 卷積核
kernel = torch.tensor([[1, 2, 1],[0, 1, 0],[2, 1, 0]])
# 尺寸只有高寬,不符合要求
print(input.shape)    # 5*5
print(kernel.shape)   # 3*#
input = torch.reshape(input, (1, 1, 5, 5))
kernel = torch.reshape(kernel, (1, 1, 3, 3))
print(input.shape)
print(kernel.shape)output = F.conv2d(input, kernel, stride=1)
print(output)

運行結果:

參數修改:

1)、將stride修改

????????stride?(int或tuple,默認值:1?):卷積核在輸入特征圖上滑動的步長。如果是一個整數,表示在高度和寬度方向上的步長相同;如果是一個元組(stride_h, stride_w),則分別指定高度和寬度方向上的步長。

????????????????output = F.conv2d(input, kernel, stride=2)

(2)、修改Padding

????????padding?(int或tuple,默認值:0?):在輸入特征圖的邊緣添加填充(padding)像素。同樣,整數表示在高度和寬度方向上添加相同數量的填充;元組(padding_h, padding_w)分別指定高度和寬度方向上的填充數量。填充可以用來控制輸出特征圖的大小,使其與輸入大小相同或滿足特定的尺寸要求。

????????padding=1:將輸入圖像左右上下兩邊都拓展一個像素,空的地方默認為0

????????????????????????output = F.conv2d(input, kernel, stride=1, padding=1)

運行結果:

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/95379.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/95379.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/95379.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

python自學筆記7 可視化初步

圖像的組成工具庫 Matplotlib:繪制靜態圖 Plotly: 可以繪制交互式圖片 圖像的繪制(Matplotlib) 創建圖形,軸對象 創造等差數列 # 包含后端點 arr np.linspace(0, 1, num11) # 不包含后端點 arr_no_endpoint np.linspace(0, 1, n…

GIS 常用的矢量與柵格分析工具

矢量處理工具作用典型應用緩沖區分析Buffer環境影響區域,空間鄰近度分析等,例如道路周圍一公里內的學校,噪音污染影響的范圍裁剪Clip例如使用A市圖層裁剪全國道路數據,獲取A市道路數據交集Intersect識別與LUCC、分區洪水區、基礎設…

http與https協議區別;vue3本地連接https地址接口報500

文章目錄問題解決方案一、問題原因分析二、解決方案詳解1. 保持當前配置(推薦臨時方案)2. 更安全的方案(推薦)3. 環境區分配置(最佳實踐)三、為什么開發環境不用配置?問題 問題:本地…

C語言——深入理解指針(三)

C語言——深入理解指針(三) 1.回調函數是什么? 首先我們來回顧一下函數的直接調用:而回調函數就是通過函數指針調用的函數。我們將函數的指針(地址)作為參數傳遞給另一個函數,當這個指針被用來調…

kettle 8.2 ETL項目【四、加載數據】

一、dim_store表結構,數據來源于業務表,且隨時間會有增加,屬于緩慢變化維(SCD)類型二 轉換步驟如下 詳細步驟如下

【測試報告】SoundWave(Java+Selenium+Jmeter自動化測試)

一、項目背景 隨著數字音樂內容的爆炸式增長,用戶對于便捷、高效的音樂管理與播放需求日益增強。傳統的本地音樂管理方式已無法滿足多設備同步、在線分享與個性化推薦等現代需求。為此,我們設計并開發了一款基于Spring Boot框架的SoundWave,旨…

C++ 類和對象詳解(1)

類和對象是 C 面向對象編程的核心概念,它們為代碼提供了更好的封裝性、可讀性和可維護性。本文將從類的定義開始,逐步講解訪問限定符、類域、實例化、對象大小計算、this 指針等關鍵知識,并對比 C 語言與 C 在實現數據結構時的差異&#xff0…

奈飛工廠:算法優化實戰

推薦系統的算法邏輯與優化技巧在流媒體行業的 “用戶注意力爭奪戰” 中,推薦系統是決定成敗的核心武器。對于擁有2.3 億全球付費用戶的奈飛(Netflix)而言,其推薦系統每天處理數十億次用戶交互,最終實現了一個驚人數據&…

【人工智能99問】BERT的訓練過程和推理過程是怎么樣的?(24/99)

文章目錄BERT的訓練過程與推理過程一、預訓練過程:學習通用語言表示1. 數據準備2. MLM任務訓練(核心)3. NSP任務訓練4. 預訓練優化二、微調過程:適配下游任務1. 任務定義與數據2. 輸入處理3. 模型結構調整4. 微調訓練三、推理過程…

[TryHackMe]Challenges---Game Zone游戲區

這個房間將涵蓋 SQLi(手動利用此漏洞和通過 SQLMap),破解用戶的哈希密碼,使用 SSH 隧道揭示隱藏服務,以及使用 metasploit payload 獲取 root 權限。 1.通過SQL注入獲得訪問權限 手工注入 輸入用戶名 嘗試使用SQL注入…

北京JAVA基礎面試30天打卡09

1.MySQL存儲引擎及區別特性MyISAMMemoryInnoDBB 樹索引? Yes? Yes? Yes備份 / 按時間點恢復? Yes? Yes? Yes集群數據庫支持? No? No? No聚簇索引? No? No? Yes壓縮數據? Yes? No? Yes數據緩存? NoN/A? Yes加密數據? Yes? Yes? Yes外鍵支持? No? No? Yes…

AI時代的SD-WAN異地組網如何落地?

在全球化運營與數字化轉型浪潮下,企業分支機構、數據中心與云服務的跨地域互聯需求激增。傳統專線因成本高昂、部署緩慢、靈活性差等問題日益凸顯不足。SD-WAN以其智能化調度、顯著降本、敏捷部署和云網融合的核心優勢,成為實現高效、可靠、安全異地組網…

css中的color-mix()函數

color-mix() 是 CSS 顏色模塊(CSS Color Module Level 5)中引入的一個強大的顏色混合函數,用于在指定的顏色空間中混合兩種或多種顏色,生成新的顏色值。它解決了傳統顏色混合(如通過透明度疊加)在視覺一致性…

Github desktop介紹(GitHub官方推出的一款圖形化桌面工具,旨在簡化Git和GitHub的使用流程)

文章目錄**1. 簡化 Git 操作****2. 代碼版本控制****3. 團隊協作****4. 代碼托管與共享****5. 集成與擴展****6. 跨平臺支持****7. 適合的使用場景****總結**GitHub Desktop 是 GitHub 官方推出的一款圖形化桌面工具,旨在簡化 Git 和 GitHub 的使用流程,…

整數規劃-分支定界

內容來自:b站數學建模老哥 如:3.4,先找小于3的,再找大于4的 逐個

JetPack系列教程(六):Paging——讓分頁加載不再“禿”然

前言 在Android開發的世界里,分頁加載就像是一場永無止境的馬拉松,每次滾動到底部,都仿佛在提醒你:“嘿,朋友,還有更多數據等著你呢!”但別擔心,Google大佬們早就看透了我們的煩惱&a…

扎實基礎!深入理解Spring框架,解鎖Java開發新境界

大家好,今天想和大家聊聊Java開發路上繞不開的一個重要基石——Spring框架。很多朋友在接觸SpringBoot、SpringCloud這些現代化開發工具時,常常會感到吃力。究其原因,往往是對其底層的Spring核心機制理解不夠透徹。Spring是構建這些高效框架的…

Heterophily-aware Representation Learning on Heterogeneous Graphs

Heterophily-Aware Representation Learning on Heterogeneous Graphs (TPAMI 2025) 計算機科學 1區 I:18.6 top期刊 ?? 摘要 現實世界中的圖結構通常非常復雜,不僅具有全局結構上的異質性,還表現出局部鄰域內的強異質相似性(heterophily)。雖然越來越多的研究揭示了圖…

計算機視覺(7)-純視覺方案實現端到端軌跡規劃(思路梳理)

基于純視覺方案實現端到端軌跡規劃,需融合開源模型、自有數據及系統工程優化。以下提供一套從模型選型到部署落地的完整方案,結合前沿開源技術與工業實踐: 一、開源模型選型與組合策略 1. 感知-預測一體化模型 ViP3D(清華&#…

Nginx 屏蔽服務器名稱與版本信息(源碼級修改)

Nginx 屏蔽服務器名稱與版本信息(源碼級修改) 一、背景與目的 在生產環境部署 Nginx 時,默認配置會在 Server 響應頭中暴露服務類型(如 nginx)和版本號(如 nginx/1.25.4)。這些信息可能被攻擊者…