PyTorch之nn.Module與nn.functional用法區別

文章目錄

  • 1. nn.Module
  • 2. nn.functional
    • 2.1 基本用法
    • 2.2 常用函數
  • 3. nn.Module 與 nn.functional
    • 3.1 主要區別
    • 3.2 具體樣例:nn.ReLU() 與 F.relu()
  • 參考資料

1. nn.Module

在PyTorch中,nn.Module 類扮演著核心角色,它是構建任何自定義神經網絡層、復雜模塊或完整神經網絡架構的基礎構建塊。通過繼承 nn.Module 并在其子類中定義模型結構和前向傳播邏輯(forward() 方法),開發者能夠方便地搭建并訓練深度學習模型。

關于 nn.Module 的更多介紹可以參考博客:PyTorch之nn.Module、nn.Sequential、nn.ModuleList使用詳解

這里,我們基于nn.Module創建一個簡單的神經網絡模型,實現代碼如下:

import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(MyModel, self).__init__()self.layer1 = nn.Linear(input_size, hidden_size)self.layer2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.layer1(x))x = self.layer2(x)return x

2. nn.functional

nn.functional 是PyTorch中一個重要的模塊,它包含了許多用于構建神經網絡的函數。與 nn.Module 不同,nn.functional 中的函數不具有可學習的參數。這些函數通常用于執行各種非線性操作、損失函數、激活函數等。

2.1 基本用法

如何在神經網絡中使用nn.functional?

在PyTorch中,你可以輕松地在神經網絡中使用 nn.functional 函數。通常,你只需將輸入數據傳遞給這些函數,并將它們作為網絡的一部分。

以下是一個簡單的示例,演示如何在一個全連接神經網絡中使用ReLU激活函數:

import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc1 = nn.Linear(64, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return x

在上述示例中,我們首先導入nn.functional 模塊,然后在網絡的forward 方法中使用F.relu 函數作為激活函數。

nn.functional 的主要優勢是它的計算效率和靈活性,因為它允許你以函數的方式直接調用這些操作,而不需要創建額外的層。

2.2 常用函數

(1)激活函數

激活函數是神經網絡中的關鍵組件,它們引入非線性性,使網絡能夠擬合復雜的數據。以下是一些常見的激活函數:

  • ReLU(Rectified Linear Unit)
    ReLU是一種簡單而有效的激活函數,它將輸入值小于零的部分設為零,大于零的部分保持不變。它的數學表達式如下:
output = F.relu(input)
  • Sigmoid
    Sigmoid函數將輸入值映射到0和1之間,常用于二分類問題的輸出層。它的數學表達式如下:
output = F.sigmoid(input)
  • Tanh(雙曲正切)
    Tanh函數將輸入值映射到-1和1之間,它具有零中心化的特性,通常在循環神經網絡中使用。它的數學表達式如下:
output = F.tanh(input)

(2)損失函數

損失函數用于度量模型的預測與真實標簽之間的差距。PyTorch的nn.functional 模塊包含了各種常用的損失函數,例如:

  • 交叉熵損失(Cross-Entropy Loss)
    交叉熵損失通常用于多分類問題,計算模型的預測分布與真實分布之間的差異。它的數學表達式如下:
loss = F.cross_entropy(input, target)
  • 均方誤差損失(Mean Squared Error Loss)
    均方誤差損失通常用于回歸問題,度量模型的預測值與真實值之間的平方差。它的數學表達式如下:
loss = F.mse_loss(input, target)
  • L1 損失
    L1損失度量預測值與真實值之間的絕對差距,通常用于稀疏性正則化。它的數學表達式如下:
loss = F.l1_loss(input, target)

(3)非線性操作

nn.functional 模塊還包含了許多非線性操作,如池化、歸一化等。

  • 最大池化(Max Pooling)
    最大池化是一種用于減小特征圖尺寸的操作,通常用于卷積神經網絡中。它的數學表達式如下:
output = F.max_pool2d(input, kernel_size)
  • 批量歸一化(Batch Normalization)
    批量歸一化是一種用于提高訓練穩定性和加速收斂的技術。它的數學表達式如下:
output = F.batch_norm(input, mean, std, weight, bias)

3. nn.Module 與 nn.functional

3.1 主要區別

nn.Module 與 nn.functional 的主要區別在于:

  • nn.Module實現的layers是一個特殊的類,都是由class Layer(nn.Module)定義,會自動提取可學習的參數;
  • nn.functional中的函數更像是純函數,由def function(input)定義。

注意:

  1. 如果模型有可學習的參數時,最好使用nn.Module。
  2. 激活函數(ReLU、sigmoid、Tanh)、池化(MaxPool)等層沒有可學習的參數,可以使用對應的functional函數。
  3. 卷積、全連接等有可學習參數的網絡建議使用nn.Module。
  4. dropout沒有可學習參數,但建議使用nn.Dropout而不是nn.functional.dropout。

3.2 具體樣例:nn.ReLU() 與 F.relu()

nn.ReLU() :

import torch.nn as nn
'''
nn.ReLU()

F.relu():

import torch.nn.functional as F
'''
out = F.relu(input)

其實這兩種方法都是使用relu激活,只是使用的場景不一樣,F.relu()是函數調用,一般使用在foreward函數里。而nn.ReLU()是模塊調用,一般在定義網絡層的時候使用。

當用print(net)輸出時,nn.ReLU()會有對應的層,而F.ReLU()是沒有輸出的。

import torch.nn as nn
import torch.nn.functional as Fclass NET1(nn.Module):def __init__(self):super(NET1, self).__init__()self.conv = nn.Conv2d(3, 16, 3, 1, 1)self.bn = nn.BatchNorm2d(16)self.relu = nn.ReLU()  # 模塊的激活函數def forward(self, x):out = self.conv(x)x = self.bn(x)out = self.relu()return outclass NET2(nn.Module):def __init__(self):super(NET2, self).__init__()self.conv = nn.Conv2d(3, 16, 3, 1, 1)self.bn = nn.BatchNorm2d(16)def forward(self, x):x = self.conv(x)x = self.bn(x)out = F.relu(x)  # 函數的激活函數return outnet1 = NET1()
net2 = NET2()
print(net1)
print(net2)

在這里插入圖片描述

參考資料

  • PyTorch的nn.Module類的詳細介紹
  • PyTorch nn.functional 模塊詳解:探索神經網絡的魔法工具箱
  • pytorch:F.relu() 與 nn.ReLU() 的區別

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/37501.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/37501.shtml
英文地址,請注明出處:http://en.pswp.cn/web/37501.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Spring Boot 源碼學習】初識 ConfigurableEnvironment

《Spring Boot 源碼學習系列》 初識 ConfigurableEnvironment 一、引言二、主要內容2.1 Environment2.1.1 配置文件(profiles)2.1.2 屬性(properties) 2.2 ConfigurablePropertyResolver2.2.1 屬性類型轉換配置2.2.2 占位符配置2.…

wxss和css有什么區別?

WXSS(WeiXin Style Sheets)和CSS(Cascading Style Sheets)在功能和應用上有很多相似之處,但針對微信小程序的特殊需求,WXSS對CSS進行了一些擴展和修改。以下是WXSS和CSS之間的主要區別: 尺寸單…

Mybatis實現流程

一&#xff0c;UserDAO 接口定義 首先&#xff0c;定義 UserDAO接口&#xff0c;包含 getList()方法,定義類型為List<User>&#xff1a; package dao;import model.User; import java.util.List;public interface UserDAO {List<User> getList(); }二&#xff0c…

Python--進程基礎

創建進程 os.fork() 該方法只能在linux和mac os中使用&#xff0c;因為其主要基于系統的fork來實現。window中沒有這個方法。 通過os.fork()方法會創建一個子進程&#xff0c;子進程的程序集為該語句下方的所有語句。 import os??print("主進程的PID為:" , os.g…

Python pdfkit wkhtmltopdf html轉換pdf 黑體字體亂碼

wkhtmltopdf 黑體在html轉換pdf時&#xff0c;黑體亂碼&#xff0c;分析可能wkhtmltopdf對黑體字體不太兼容&#xff1b; 1.html內容如下 <html> <head> <meta http-equiv"content-type" content"text/html;charsetutf-8"> </head&…

DreamView數據流

DreamView數據流 查看DV中界面啟動dag&#xff0c;/apollo/modules/dreamview_plus/conf/hmi_modes/pnc.pb.txt可以看到點擊界面的planning按鈕&#xff0c;后臺其實啟動的是/apollo/modules/planning/planning_component/dag/planning.dag和/apollo/modules/external_command…

語音識別應用Python示例

語音識別是將語音信號轉換為文本的技術&#xff0c;是人工智能領域的重要研究方向之一。下面是一個基于Python的簡單語音識別應用的代碼示例。 首先&#xff0c;需要安裝Python的語音識別庫SpeechRecognition。可以使用以下命令進行安裝&#xff1a; pip install SpeechRecog…

版本號比較

版本號比較&#xff1a; 注意&#xff1a; 不可以直接使用字符串比較的方法進行版本號比較。例如 2.29.1 > 2.3.0 是 false 的 版本號比較可以參考以下代碼&#xff1a; function compareVersion(v1, v2) {v1 v1.split(.)v2 v2.split(.)const len Math.max(v1.length, …

Oracle連接mysql

oracle使用的11g&#xff0c;在一臺windows服務器&#xff1b;mysql使用的是5.7版本&#xff0c;在另一臺windows服務器&#xff0c;這兩個服務器之間的網絡是互通的。做BI時&#xff0c;要獲取不同數據源的數據&#xff0c;這些數據源可能是Oracle&#xff0c;也可能是sqlserv…

springboot基礎入門2(profile應用)

Profile應用 一、何為Profile二、profile配置方式1.多profile文件方式2.yml多文檔方式 三、加載順序1. file:./config/: 當前項目下的/config目錄下2. file:./ &#xff1a;當前項目的根目錄3. classpath:/config/:classpath的/config目錄4. classpath:/ : classpath的根目錄 四…

【設計模式】【創建型5-2】【工廠方法模式】

文章目錄 工廠方法模式工廠方法模式的結構示例產品接口具體產品工廠接口具體工廠客戶端代碼 實際的使用 工廠方法模式 工廠方法模式的結構 產品&#xff08;Product&#xff09;&#xff1a;定義工廠方法所創建的對象的接口。 具體產品&#xff08;ConcreteProduct&#xff0…

Redis 集群模式

一、集群模式概述 Redis 中哨兵模式雖然提高了系統的可用性&#xff0c;但是真正存儲數據的還是主節點和從節點&#xff0c;并且每個節點都存儲了全量的數據&#xff0c;此時&#xff0c;如果數據量過大&#xff0c;接近或超出了 主節點 / 從節點機器的物理內存&#xff0c;就…

個人網站制作 Part 28 添加用戶活動跟蹤功能 | Web開發項目添加頁面緩存

文章目錄 &#x1f469;?&#x1f4bb; 基礎Web開發練手項目系列&#xff1a;個人網站制作&#x1f680; 添加用戶活動跟蹤功能&#x1f528;使用分析工具&#x1f527;步驟 1: 選擇分析工具&#x1f527;步驟 2: 注冊Google Analytics賬戶&#x1f527;步驟 3: 獲取Analytics…

Java面試題--JVM大廠篇之深入了解G1 GC:高并發、響應時間敏感應用的最佳選擇

引言&#xff1a; 在現代Java應用的性能優化中&#xff0c;垃圾回收器&#xff08;GC&#xff09;的選擇至關重要。對于高并發、響應時間敏感的應用而言&#xff0c;G1 GC&#xff08;Garbage-First Garbage Collector&#xff09;無疑是一個強大的工具。本文將深入探討G1 GC適…

李一桐遭遇蜈蚣驚魂

李一桐遭遇“蜈蚣驚魂”&#xff01;劉宇寧展現真男人本色在娛樂圈的幕后&#xff0c;總有一些心跳加速的驚險。近日&#xff0c;李一桐在拍戲時遭遇了一場“蜈蚣驚魂”&#xff0c;讓無數粉絲和網友為她捏了一把冷汗。而在這場驚險的遭遇中&#xff0c;劉宇寧展現出了真男人的…

NOI大綱——普及組——二叉搜索樹

二叉搜索樹 二叉搜索樹&#xff08;Binary Search Tree&#xff0c;簡稱BST&#xff09;是一種特殊的二叉樹&#xff0c;它具有以下幾個特點&#xff1a; 節點的左子樹上的所有節點的值都小于或等于該節點的值。節點的右子樹上的所有節點的值都大于或等于該節點的值。每個節點…

ActiveMq工具之管理頁面說明

文章目錄 安裝ActiveMQ一: 訪問管理頁面二: 進入管理頁面&#xff0c;主頁三: Queues頁說明四: Topics頁說明五: Subscribers頁說明 安裝ActiveMQ wget https://archive.apache.org/dist//activemq/5.13.3/apache-activemq-5.13.3-bin.tar.gz wget https://mirrors.huaweiclou…

為什么越來越多的企業選擇外包?賦能企業未來

軟件開發過程包括設計需求、設計方案、產品研發、產品交付、后期維護&#xff0c;許多企業并沒有軟件開發的專業能力與工作經驗&#xff0c;將軟件開發工作進行外包是比較節約成本的&#xff0c;企業能少走不少彎路。 YesPMP平臺&#xff08;一站式軟件外包、項目外包服務-YesP…

UWA Pipeline 2.6.1版本更新

UWA Pipeline是專為游戲開發團隊設計的本地協作平臺&#xff0c;旨在幫助團隊建立專業的DevOps研發交付流水線。本平臺提供了可視化的CI/CD操作界面&#xff0c;高可用的自動化測試和無縫集成的UWA性能保障服務等核心功能。 在最新的Pipeline更新中&#xff0c;UWA引入了參數配…

protobufjs解析proto消息出錯RangeError: index out of range: 2499 + 10 > 2499解決辦法

使用websocket通訊傳輸protobuf消息的時候&#xff0c;decode的時候出錯了&#xff1a; RangeError: index out of range: 2499 10 > 2499 Error: invalid wire type 4 at offset 1986 出現這種錯誤的時候&#xff0c;99%是因為proto里面的消息類型和服務端發送的消息類型不…