深入理解Softmax函數及其在PyTorch中的實現

Softmax函數簡介

Softmax函數在機器學習和深度學習中,被廣泛用于多分類問題的輸出層。它將一個實數向量轉換為概率分布,使得每個元素介于0和1之間,且所有元素之和為1。

Softmax函數的定義

給定一個長度為 K K K的輸入向量 z = [ z 1 , z 2 , … , z K ] \boldsymbol{z} = [z_1, z_2, \dots, z_K] z=[z1?,z2?,,zK?],Softmax函數 σ ( z ) \sigma(\boldsymbol{z}) σ(z)定義為:

σ ( z ) i = e z i ∑ j = 1 K e z j , 對于所有? i = 1 , 2 , … , K \sigma(\boldsymbol{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}, \quad \text{對于所有 } i = 1, 2, \dots, K σ(z)i?=j=1K?ezj?ezi??,對于所有?i=1,2,,K

其中:

  • e e e是自然對數的底數,約為2.71828。
  • σ ( z ) i \sigma(\boldsymbol{z})_i σ(z)i?是輸入向量第 i i i個分量對應的Softmax輸出。

Softmax函數的特點

  1. 將輸出轉換為概率分布:Softmax的輸出向量中的每個元素都在 ( 0 , 1 ) (0, 1) (0,1)之間,并且所有元素的和為1,這使得輸出可以視為各類別的概率。

  2. 強調較大的值:Softmax函數會放大輸入向量中較大的元素對應的概率,同時壓縮較小的元素對應的概率。這種特性有助于突出模型認為更有可能的類別。

  3. 可微性:Softmax函數是可微的,這對于基于梯度的優化算法(如反向傳播)非常重要。


數值穩定性的問題

在實際計算中,為了防止指數函數計算過程中可能出現的數值溢出,通常會對輸入向量進行調整。常見的做法是在計算Softmax之前,從輸入向量的每個元素中減去向量的最大值:

σ ( z ) i = e z i ? z max ∑ j = 1 K e z j ? z max \sigma(\boldsymbol{z})_i = \frac{e^{z_i - z_{\text{max}}}}{\sum_{j=1}^{K} e^{z_j - z_{\text{max}}}} σ(z)i?=j=1K?ezj??zmax?ezi??zmax??

其中, z max = max ? { z 1 , z 2 , … , z K } z_{\text{max}} = \max\{z_1, z_2, \dots, z_K\} zmax?=max{z1?,z2?,,zK?}。這種調整不會改變Softmax的輸出結果,但能提高計算的數值穩定性。


Softmax函數的應用場景

  1. 多分類問題:在神經網絡的最后一層,Softmax函數常用于將模型的線性輸出轉換為概率分布,以進行多分類預測。

  2. 注意力機制:在深度學習中的注意力模型中,Softmax用于計算注意力權重,以突顯重要的輸入特征。

  3. 語言模型:在自然語言處理任務中,Softmax函數用于預測下一個詞的概率分布。


Softmax函數的示例計算

假設有一個三類別分類問題,神經網絡的輸出為一個長度為3的向量:

z = [ z 1 , z 2 , z 3 ] = [ 2.0 , 1.0 , 0.1 ] \boldsymbol{z} = [z_1, z_2, z_3] = [2.0, 1.0, 0.1] z=[z1?,z2?,z3?]=[2.0,1.0,0.1]

我們想使用Softmax函數將其轉換為概率分布。

步驟1:計算每個元素的指數

e z 1 = e 2.0 = 7.3891 e z 2 = e 1.0 = 2.7183 e z 3 = e 0.1 = 1.1052 \begin{align*} e^{z_1} &= e^{2.0} = 7.3891 \\ e^{z_2} &= e^{1.0} = 2.7183 \\ e^{z_3} &= e^{0.1} = 1.1052 \end{align*} ez1?ez2?ez3??=e2.0=7.3891=e1.0=2.7183=e0.1=1.1052?

步驟2:計算指數和

sum = e z 1 + e z 2 + e z 3 = 7.3891 + 2.7183 + 1.1052 = 11.2126 \text{sum} = e^{z_1} + e^{z_2} + e^{z_3} = 7.3891 + 2.7183 + 1.1052 = 11.2126 sum=ez1?+ez2?+ez3?=7.3891+2.7183+1.1052=11.2126

步驟3:計算Softmax輸出

σ 1 = e z 1 sum = 7.3891 11.2126 = 0.6590 σ 2 = e z 2 sum = 2.7183 11.2126 = 0.2424 σ 3 = e z 3 sum = 1.1052 11.2126 = 0.0986 \begin{align*} \sigma_1 &= \frac{e^{z_1}}{\text{sum}} = \frac{7.3891}{11.2126} = 0.6590 \\ \sigma_2 &= \frac{e^{z_2}}{\text{sum}} = \frac{2.7183}{11.2126} = 0.2424 \\ \sigma_3 &= \frac{e^{z_3}}{\text{sum}} = \frac{1.1052}{11.2126} = 0.0986 \end{align*} σ1?σ2?σ3??=sumez1??=11.21267.3891?=0.6590=sumez2??=11.21262.7183?=0.2424=sumez3??=11.21261.1052?=0.0986?

因此,經過Softmax函數后,輸出概率分布為:

σ ( z ) = [ 0.6590 , 0.2424 , 0.0986 ] \sigma(\boldsymbol{z}) = [0.6590, 0.2424, 0.0986] σ(z)=[0.6590,0.2424,0.0986]

這表示模型預測第一個類別的概率約為65.9%,第二個類別約為24.24%,第三個類別約為9.86%。


使用PyTorch實現Softmax函數

在PyTorch中,可以通過多種方式實現Softmax函數。以下將通過示例演示如何使用torch.nn.functional.softmaxtorch.nn.Softmax

創建輸入數據

首先,創建一個示例輸入張量:

import torch
import torch.nn as nn
import torch.nn.functional as F# 創建一個輸入張量,形狀為 (batch_size, features)
input_tensor = torch.tensor([[2.0, 1.0, 0.1],[1.0, 3.0, 0.2]])
print("輸入張量:")
print(input_tensor)

輸出:

輸入張量:
tensor([[2.0000, 1.0000, 0.1000],[1.0000, 3.0000, 0.2000]])

方法一:使用torch.nn.functional.softmax

利用PyTorch中torch.nn.functional.softmax函數直接對輸入數據應用Softmax。

# 在維度1上(即特征維)應用Softmax
softmax_output = F.softmax(input_tensor, dim=1)
print("\nSoftmax輸出:")
print(softmax_output)

輸出:

Softmax輸出:
tensor([[0.6590, 0.2424, 0.0986],[0.1065, 0.8726, 0.0209]])

方法二:使用torch.nn.Softmax模塊

也可以使用torch.nn中的Softmax模塊。

# 創建一個Softmax層實例
softmax = nn.Softmax(dim=1)# 對輸入張量應用Softmax層
softmax_output_module = softmax(input_tensor)
print("\n使用nn.Softmax模塊的輸出:")
print(softmax_output_module)

輸出:

使用nn.Softmax模塊的輸出:
tensor([[0.6590, 0.2424, 0.0986],[0.1065, 0.8726, 0.0209]])

在神經網絡模型中應用Softmax

構建一個簡單的神經網絡模型,在最后一層使用Softmax激活函數。

class SimpleNetwork(nn.Module):def __init__(self, input_size, num_classes):super(SimpleNetwork, self).__init__()self.layer1 = nn.Linear(input_size, 5)self.layer2 = nn.Linear(5, num_classes)# 使用LogSoftmax提高數值穩定性self.softmax = nn.LogSoftmax(dim=1)def forward(self, x):x = F.relu(self.layer1(x))x = self.layer2(x)x = self.softmax(x)return x# 定義輸入大小和類別數
input_size = 3
num_classes = 3# 創建模型實例
model = SimpleNetwork(input_size, num_classes)# 查看模型結構
print("\n模型結構:")
print(model)

輸出:

模型結構:
SimpleNetwork((layer1): Linear(in_features=3, out_features=5, bias=True)(layer2): Linear(in_features=5, out_features=3, bias=True)(softmax): LogSoftmax(dim=1)
)

前向傳播:

# 將輸入數據轉換為浮點型張量
input_data = input_tensor.float()# 前向傳播
output = model(input_data)
print("\n模型輸出(對數概率):")
print(output)

輸出:

模型輸出(對數概率):
tensor([[-1.2443, -0.7140, -1.2645],[-1.3689, -0.6535, -1.5142]], grad_fn=<LogSoftmaxBackward0>)

轉換為概率:

# 取指數,轉換為概率
probabilities = torch.exp(output)
print("\n模型輸出(概率):")
print(probabilities)

輸出:

模型輸出(概率):
tensor([[0.2882, 0.4898, 0.2220],[0.2541, 0.5204, 0.2255]], grad_fn=<ExpBackward0>)

預測類別:

# 獲取每個樣本概率最大的類別索引
predicted_classes = torch.argmax(probabilities, dim=1)
print("\n預測的類別:")
print(predicted_classes)

輸出:

預測的類別:
tensor([1, 1])

torch.nn.functional.softmaxtorch.nn.Softmax的區別

函數式API與模塊化API的設計理念

PyTorch提供了兩種API:

  1. 函數式API (torch.nn.functional)

    • 特點:無狀態(Stateless),不包含可學習的參數。
    • 使用方式:直接調用函數。
    • 適用場景:需要在forward方法中靈活應用各種操作。
  2. 模塊化API (torch.nn.Module)

    • 特點:有狀態(Stateful),可能包含可學習的參數,即使某些模塊沒有參數(如Softmax),但繼承自nn.Module
    • 使用方式:需要先實例化,再在前向傳播中調用。
    • 適用場景:構建模型時,統一管理各個層和操作。

具體到Softmax的實現

  • torch.nn.functional.softmax(函數)

    • 使用示例

      import torch.nn.functional as F
      output = F.softmax(input_tensor, dim=1)
      
    • 特點:直接調用,簡潔靈活。

  • torch.nn.Softmax(模塊)

    • 使用示例

      import torch.nn as nn
      softmax = nn.Softmax(dim=1)
      output = softmax(input_tensor)
      
    • 特點:作為模型的一層,便于與其他層組合,保持代碼結構一致。

為什么存在兩個實現?

提供兩種實現方式是為了滿足不同開發者的需求和編程風格。

  • 使用nn.Softmax的優勢

    • 在模型定義階段明確各層,結構清晰。
    • 便于使用nn.Sequential構建順序模型。
    • 統一管理模型的各個部分。
  • 使用F.softmax的優勢

    • 代碼簡潔,直接調用函數。
    • 適用于需要在forward中進行靈活操作的情況。

使用示例

使用nn.Softmax
import torch
import torch.nn as nn# 定義模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.layer = nn.Linear(10, 5)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.layer(x)x = self.softmax(x)return x# 實例化和使用
model = MyModel()
input_tensor = torch.randn(2, 10)
output = model(input_tensor)
print(output)
使用F.softmax
import torch
import torch.nn as nn
import torch.nn.functional as F# 定義模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.layer = nn.Linear(10, 5)def forward(self, x):x = self.layer(x)x = F.softmax(x, dim=1)return x# 實例化和使用
model = MyModel()
input_tensor = torch.randn(2, 10)
output = model(input_tensor)
print(output)

總結

Softmax函數在深度學習中起著關鍵作用,尤其在多分類任務中。PyTorch為了滿足不同的開發需求,提供了torch.nn.functional.softmaxtorch.nn.Softmax兩種實現方式。

  • F.softmax:函數式API,靈活簡潔,適合在forward方法中直接調用。

  • nn.Softmax:模塊化API,便于模型結構的統一管理,適合在模型初始化時定義各個層。

在實際開發中,選擇適合你的項目和團隊的方式。如果更喜歡模塊化的代碼結構,使用nn.Softmax;如果追求簡潔和靈活,使用F.softmax。同時,要注意數值穩定性的問題,尤其是在計算損失函數時,建議使用nn.LogSoftmaxnn.NLLLoss結合使用。


參考文獻

  • PyTorch官方文檔 - Softmax函數
  • PyTorch官方文檔 - nn.Softmax
  • PyTorch官方教程 - 構建神經網絡
  • PyTorch論壇 - Softmax激活函數:nn.Softmax vs F.softmax

希望本文能幫助讀者深入理解Softmax函數及其在PyTorch中的實現和應用。如有任何疑問,歡迎交流討論!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/77608.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/77608.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/77608.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Vue 3 響應式更新問題解析

在 Vue 3 中&#xff0c;即使使用 reactive 或 ref 創建的響應式數據&#xff0c;當數據量很大時也可能出現更新不及時的情況。以下是原因和解決方案&#xff1a; 核心原因 ??響應式系統優化機制??&#xff1a; Vue 3 使用 Proxy 實現響應式&#xff0c;比 Vue 2 更高效但為…

異形遮罩之QML中的 `OpacityMask` 實戰

文章目錄 &#x1f327;? 傳統實現的問題&#x1f449; 效果圖 &#x1f308; 使用 OpacityMask 的理想方案&#x1f449;代碼如下&#x1f3af; 最終效果&#xff1a; ? 延伸應用&#x1f9e0; 總結 在 UI 設計中&#xff0c;經常希望實現一些“異形區域”擁有統一透明度或顏…

數據可視化 —— 堆形圖應用(大全)

一、案例一&#xff1a;溫度堆積圖 # 導入 matplotlib 庫中的 pyplot 模塊&#xff0c;這個模塊提供了類似于 MATLAB 的繪圖接口&#xff0c; # 方便我們創建各種類型的可視化圖表&#xff0c;比如折線圖、柱狀圖、散點圖等 import matplotlib.pyplot as plt # 導入 numpy 庫&…

python工程中的包管理(requirements.txt)

pip install -r requirements.txtpython工程通過requirements.txt來管理依賴庫版本&#xff0c;上述命令&#xff0c;可以一把安裝依賴庫&#xff0c;類似java中maven的pom.xml文件。 參考 [](

操作系統 3.4-段頁結合的實際內存管理

段與頁結合的初步思路 虛擬內存的引入&#xff1a; 為了結合段和頁的優勢&#xff0c;操作系統引入了虛擬內存的概念。虛擬內存是一段地址空間&#xff0c;它映射到物理內存上&#xff0c;但對用戶程序是透明的。 段到虛擬內存的映射&#xff1a; 用戶程序中的段首先映射到虛…

【Amazon EC2】為何基于瀏覽器的EC2 Instance Connect 客戶端連接不上EC2實例

文章目錄 前言&#x1f4d6;一、報錯先知?二、問題復現&#x1f62f;三、解決辦法&#x1f3b2;四、驗證結果&#x1f44d;五、參考鏈接&#x1f517; 前言&#x1f4d6; 這篇文章將講述我在 Amazon EC2 上使用 RHEL9 AMI 時無法連接到 EC2 實例時所遇到的麻煩&#x1f616; …

Python學習筆記(二)(字符串)

文章目錄 編寫簡單的程序一、標識符 (Identifiers)及關鍵字命名規則&#xff1a;命名慣例&#xff1a;關鍵字 二、變量與賦值 (Variables & Assignment)變量定義&#xff1a;多重賦值&#xff1a;變量交換&#xff1a;&#xff08;很方便喲&#xff09; 三、輸入與輸出 (In…

Hydra Columnar:一個開源的PostgreSQL列式存儲引擎

Hydra Columnar 是一個 PostgreSQL 列式存儲插件&#xff0c;專為分析型&#xff08;OLAP&#xff09;工作負載設計&#xff0c;旨在提升大規模分析查詢和批量更新的效率。 Hydra Columnar 以擴展插件的方式提供&#xff0c;主要特點包括&#xff1a; 采用列式存儲&#xff0c…

es的告警信息

Elasticsearch&#xff08;ES&#xff09;是一個開源的分布式搜索和分析引擎&#xff0c;在運行過程中可能會產生多種告警信息&#xff0c;以提示用戶系統中存在的潛在問題或異常情況。以下是一些常見的 ES 告警信息及其含義和處理方法&#xff1a; 集群健康狀態告警 信息示例…

健康與好身體筆記

文章目錄 保證睡眠飯后百步走&#xff0c;活到九十九補充鈣質一副好腸胃肚子咕咕叫 健康和工作的取舍 以前對健康沒概念&#xff0c;但是隨著年齡增長&#xff0c;健康問題凸顯出來。 持續維護該文檔&#xff0c;健康是個永恒的話題。 保證睡眠 一是心態要好&#xff0c;沾枕…

vue實現在線進制轉換

vue實現在線進制轉換 主要功能包括&#xff1a; 1.支持2-36進制之間的轉換。 2.支持整數和浮點數的轉換。 3.輸入驗證&#xff08;雖然可能存在不嚴格的情況&#xff09;。 4.錯誤提示。 5.結果展示&#xff0c;包括大寫字母。 6.用戶友好的界面&#xff0c;包括下拉菜單、輸…

智體知識庫:poplang編程語言是什么?

問&#xff1a;poplang語言是什么 Poplang 語言簡介 Poplang&#xff08;OPCode-Oriented Programming Language&#xff09;是一種面向操作碼&#xff08;Opcode&#xff09;的輕量級編程語言&#xff0c;主要用于智體&#xff08;Agent&#xff09;系統中的自動化任務處理、…

二分查找5:852. 山脈數組的峰頂索引

鏈接&#xff1a;852. 山脈數組的峰頂索引 - 力扣&#xff08;LeetCode&#xff09; 題解&#xff1a; 事實證明&#xff0c;二分查找不局限于有序數組&#xff0c;非有序的數組也同樣適用 二分查找主要思想在于二段性&#xff0c;即將數組分為兩段。本體就可以將數組分為ar…

下列軟件包有未滿足的依賴關系: python3-catkin-pkg : 沖突: catkin 但是 0.8.10-

下列軟件包有未滿足的依賴關系: python3-catkin-pkg : 沖突: catkin 但是 0.8.10- 解決&#xff1a; 1. 確認當前的包狀態 首先&#xff0c;運行以下命令來查看當前安裝的catkin和python3-catkin-pkg版本&#xff0c;以及它們之間的依賴關系&#xff1a; dpkg -l | grep ca…

深度學習:AI 大模型時代的智能引擎

當 Deepspeek 以逼真到難辨真假的語音合成和視頻生成技術橫空出世&#xff0c;瞬間引發了全球對 AI 倫理與技術邊界的激烈討論。從偽造名人演講、制造虛假新聞&#xff0c;到影視行業的特效革新&#xff0c;這項技術以驚人的速度滲透進大眾視野。但在 Deepspeek 強大功能的背后…

醫學分割新標桿!雙路徑PGM-UNet:CNN+Mamba實現病灶毫厘級捕捉

一、引言&#xff1a;醫學圖像分割的挑戰與機遇 醫學圖像分割是輔助疾病診斷和治療規劃的關鍵技術&#xff0c;但傳統方法常受限于復雜病理特征和微小結構。現有深度學習模型&#xff08;如CNN和Transformer&#xff09;雖各有優勢&#xff0c;但CNN難以建模長距離依賴&…

CV - 目標檢測

物體檢測 目標檢測和圖片分類的區別&#xff1a; 圖像分類&#xff08;Image Classification&#xff09; 目的&#xff1a;圖像分類的目的是識別出圖像中主要物體的類別。它試圖回答“圖像是什么&#xff1f;”的問題。 輸出&#xff1a;通常輸出是一個標簽或一組概率值&am…

高并發秒殺系統設計:關鍵技術解析與典型陷阱規避

電商、在線票務等眾多互聯網業務場景中&#xff0c;高并發秒殺活動屢見不鮮。這類活動往往在短時間內會涌入海量的用戶請求&#xff0c;對系統架構的性能、穩定性和可用性提出了極高的挑戰。曾經&#xff0c;高并發秒殺架構設計讓許多開發者望而生畏&#xff0c;然而&#xff0…

藍橋杯--結束

沖刺題單 基礎 一、簡單模擬&#xff08;循環數組日期進制&#xff09; &#xff08;一&#xff09;日期模擬 知識點 1.把月份寫為數組&#xff0c;二月默認為28天。 2.寫一個判斷閏年的方法&#xff0c;然后循環年份的時候判斷并更新二月的天數 3.對于星期數的計算&#…

13、nRF52xx藍牙學習(GPIOTE組件方式的任務配置)

下面再來探討下驅動庫如何實現任務的配置&#xff0c;驅動庫的實現步驟應該和寄存器方式對應&#xff0c;關 鍵點就是如何調用驅動庫的函數。 本例里同樣的對比寄存器方式編寫兩路的 GPOITE 任務輸出&#xff0c;一路配置為輸出翻轉&#xff0c;一路設 置為輸出低電平。和 …