PyTorch中三角函數與特殊運算詳解和實戰場景示例

在 PyTorch 中,三角函數(如 sin, cos, tan 等)和一些特殊數學運算(如雙曲函數、反三角函數、hypot, atan2, clamp, lerp, sigmoid, softplus, special 模塊等)被廣泛用于科學計算、機器學習、深度學習中的前向推理或梯度計算中。


PyTorch 中的三角函數與特殊運算詳解


1. 常見三角函數

torch.sin(input)

  • 定義:對輸入張量每個元素計算正弦值
  • 參數input (Tensor) – 輸入張量(角度以弧度為單位)
  • 返回值:返回一個和 input 同形狀的張量,每個元素是 sin(x)
  • 示例
import torch
x = torch.tensor([0, torch.pi/2, torch.pi])
y = torch.sin(x)
print(y)  # tensor([0.0000, 1.0000, 0.0000])

torch.cos(input)

  • 類似于 sin,返回余弦值

torch.tan(input)

  • 返回正切值(注意:tan(π/2) 會出現無窮大)

torch.asin(input), torch.acos(input), torch.atan(input)

  • 反三角函數,輸出的是角度(單位:弧度)
  • 輸入值需在定義域內(asin/acos 的輸入必須在 [-1, 1]

torch.atan2(input, other)

  • 定義:返回 atan(input / other),但能處理分母為 0 的情形,按象限返回正確角度

  • 參數

    • input: 分子
    • other: 分母
  • 示例

a = torch.tensor([1.0])
b = torch.tensor([1.0])
print(torch.atan2(a, b))  # 輸出 π/4 ≈ 0.7854

torch.hypot(x, y)

  • 定義:計算 sqrt(x2 + y2)
  • 常用于二維向量的模長(歐幾里得范數)
x = torch.tensor([3.0])
y = torch.tensor([4.0])
print(torch.hypot(x, y))  # 輸出 tensor([5.])

2. 雙曲函數

torch.sinh, torch.cosh, torch.tanh

x = torch.tensor([0.0, 1.0])
print(torch.sinh(x))  # tensor([0.0000, 1.1752])
print(torch.tanh(x))  # tensor([0.0000, 0.7616])

3. 特殊數學函數(常用于神經網絡中)

torch.sigmoid(input)

  • 返回:1 / (1 + exp(-x))
  • 常用于二分類模型輸出層
x = torch.tensor([-1.0, 0.0, 1.0])
print(torch.sigmoid(x))  # tensor([0.2689, 0.5000, 0.7311])

torch.nn.functional.softplus(input)

  • 類似于平滑版的 ReLU:log(1 + exp(x))
  • 可用于避免 ReLU 的非可導性問題
import torch.nn.functional as F
x = torch.tensor([-1.0, 0.0, 1.0])
print(F.softplus(x))  # tensor([0.3133, 0.6931, 1.3133])

torch.lerp(start, end, weight)

  • 線性插值:(1 - weight) * start + weight * end
a = torch.tensor([0.0])
b = torch.tensor([10.0])
print(torch.lerp(a, b, 0.3))  # tensor([3.])

torch.clamp(input, min, max)

  • 限制張量的最小最大范圍
x = torch.tensor([-2.0, 0.5, 3.0])
print(torch.clamp(x, min=0.0, max=1.0))  # tensor([0.0, 0.5, 1.0])

4. torch.special 模塊(高級特殊函數)

import torch.special as S# gamma 函數:S.gamma(x) = (x-1)!
x = torch.tensor([1.0, 2.0, 3.0])
print(S.gamma(x))  # tensor([1., 1., 2.])# erf 函數(高斯誤差函數)
print(S.erf(torch.tensor([0.0, 1.0])))  # tensor([0.0000, 0.8427])

常用 torch.special 函數包括:

函數名功能
special.erf高斯誤差函數
special.gammalngamma 函數的對數
special.logitsigmoid 的反函數
special.expitsigmoid 函數(別名)
special.i0第一類貝塞爾函數
special.digammagamma 函數的導數

總結對比表

函數作用說明
torch.sin/cos/tan三角函數輸入為弧度,輸出 [-1, 1]
torch.asin/acos/atan反三角函數返回弧度
torch.atan2atan(y/x) 并考慮象限更安全的除法處理
torch.hypot計算 √(x2 + y2)常用于距離計算
torch.sigmoidS型函數常用于分類神經網絡
torch.softplus平滑 ReLU輸出始終 > 0
torch.clamp限定區間防止梯度爆炸或數值異常
torch.lerp線性插值用于圖像插值或平滑過渡
torch.special特殊函數模塊包含 gamma、貝塞爾等高級函數

使用 PyTorch 計算三角函數的梯度(自動求導演示)

PyTorch 中只要一個張量啟用了 .requires_grad=True,就可以使用 .backward() 自動求導。


示例 1:sin(x) 的梯度(導數為 cos(x)

import torchx = torch.tensor([torch.pi / 4], requires_grad=True)  # 45度,sin(π/4) ≈ 0.7071
y = torch.sin(x)       # y = sin(x)
y.backward()           # 計算 dy/dx
print("y =", y.item())
print("dy/dx =", x.grad.item())  # 應該為 cos(π/4) ≈ 0.7071

輸出(大約):

y = 0.7071
dy/dx = 0.7071

示例 2:tanh(x) 的梯度(導數為 1 - tanh2(x)

x = torch.tensor([1.0], requires_grad=True)
y = torch.tanh(x)
y.backward()
print("y =", y.item())
print("dy/dx =", x.grad.item())  # 約為 1 - tanh(x)^2 ≈ 0.4199

輸出:

y = 0.7616
dy/dx = 0.4199

示例 3:sigmoid(x) 的梯度(導數為 s(x)*(1-s(x))

x = torch.tensor([0.0], requires_grad=True)
y = torch.sigmoid(x)
y.backward()
print("y =", y.item())            # 0.5
print("dy/dx =", x.grad.item())  # 0.5 * (1 - 0.5) = 0.25

示例 4:組合函數 + 多元素求導

x = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
y = torch.sin(x) + torch.cos(x)
z = y.sum()     # 標量才能 .backward()
z.backward()
print("x.grad =", x.grad)

輸出大約為:

x.grad ≈ tensor([cos(0.1) - sin(0.1),cos(0.2) - sin(0.2),cos(0.3) - sin(0.3)])

注意事項

條件說明
requires_grad=True啟用自動求導
輸出必須是標量才能 .backward()否則需手動傳入梯度向量
.grad 只能用于葉子節點變量中間變量(如 y)無 .grad
多次 .backward() 需使用 retain_graph=True否則計算圖會被釋放

示例 5:對 special.erf() 求梯度

import torch.special as Sx = torch.tensor([0.5], requires_grad=True)
y = S.erf(x)
y.backward()
print("dy/dx =", x.grad.item())  # d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2)

梯度驗證工具(額外推薦)

可以使用 torch.autograd.gradcheck 做數值梯度驗證(需要使用 double 類型):

from torch.autograd import gradcheckx = torch.tensor([0.5], dtype=torch.double, requires_grad=True)
test = gradcheck(torch.sin, (x,), eps=1e-6, atol=1e-4)
print("Sin gradcheck:", test)

實際神經網絡中使用示例

下面我們深入介紹 在實際神經網絡中如何使用 PyTorch 的三角函數和特殊函數(如 sin、sigmoid、softplus 等),特別是它們在 loss 函數、自定義激活、周期性建模等場景中的實際用法


場景 1:周期性數據建模(比如時間、角度)用 sin/cos

應用背景:

對于 角度、時間(24小時) 等周期性輸入,用 sin/cos 編碼能避免 0° 和 360° 被認為相差很遠的問題。


示例:用 sin/cos 編碼輸入 + 簡單回歸網絡

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SinCosNet(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(2, 1)  # 輸入是 sin 和 cos,輸出是預測值def forward(self, x_deg):x_rad = x_deg * torch.pi / 180x = torch.stack([torch.sin(x_rad), torch.cos(x_rad)], dim=1)  # [N, 2]out = self.fc(x)return out# 模擬輸入:角度(以度為單位)
x_deg = torch.tensor([0.0, 90.0, 180.0, 270.0]).reshape(-1, 1)
y_true = torch.tensor([0.0, 1.0, 0.0, -1.0]).reshape(-1, 1)  # 例如正弦值作為目標model = SinCosNet()
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 前向 + 反向 + 更新
for epoch in range(200):pred = model(x_deg)loss = loss_fn(pred, y_true)optimizer.zero_grad()loss.backward()optimizer.step()print("預測值:", model(x_deg).detach().squeeze())

場景 2:使用 softplus 替代 ReLU 防止死神經

class SoftplusNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(10, 64)self.fc2 = nn.Linear(64, 1)def forward(self, x):x = F.softplus(self.fc1(x))  # 使用 softplus 而不是 ReLUreturn self.fc2(x)

適用場景:

  • 高穩定性訓練
  • 避免 ReLU 導致的神經元“死亡”
  • sigmoid 搭配時,數值更平滑

場景 3:自定義 Loss 函數中使用 torch.atan2, sin, cos

示例:角度差異損失(用于姿態估計、旋轉預測)

def angular_loss(pred_angle, target_angle):"""計算兩個角度之間的最小差(結果范圍在 [-pi, pi])"""diff = pred_angle - target_anglediff = torch.atan2(torch.sin(diff), torch.cos(diff))  # wrap 到 [-pi, pi]return torch.mean(diff ** 2)  # MSE 損失

應用背景:

  • 預測角度(如相機旋轉、姿態)
  • 避免直接使用 pred - target 導致 359° 和 0° 差異非常大

場景 4:使用 sigmoid + torch.special.logit 實現穩定反向函數 Loss

import torch.special as Sdef binary_target_loss(pred, target):pred = torch.clamp(pred, 1e-6, 1 - 1e-6)  # 避免 logit 無窮大loss = torch.mean((S.logit(pred) - S.logit(target)) ** 2)return loss

應用背景:

  • 對于二分類輸出,可以用 logit 空間計算差異,更穩定、更敏感

場景 5:周期函數的擬合(神經網絡擬合 sin 函數)

import matplotlib.pyplot as pltclass SineFitter(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Linear(1, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, 1))def forward(self, x):return self.net(x)x = torch.linspace(-2*torch.pi, 2*torch.pi, 200).unsqueeze(1)
y = torch.sin(x)model = SineFitter()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()for epoch in range(1000):pred = model(x)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()# 可視化
plt.plot(x.detach(), y, label="Ground Truth")
plt.plot(x.detach(), model(x).detach(), label="Predicted")
plt.legend()
plt.title("Sine Function Fitting")
plt.show()

總結:三角函數/特殊函數在神經網絡中的常見用途

應用領域使用的函數說明
角度建模sin, cos, atan2編碼/解碼周期性角度
激活函數softplus, tanh, sigmoid平滑激活、避免死神經
自定義損失函數atan2, logit, erf更穩定地處理周期性/概率性誤差
函數擬合sin, special.*網絡學習任意復雜函數,特別是周期/光波類
數據歸一化或變換clamp, lerp, special.logit控制數據范圍,避免梯度爆炸或損失異常

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/92197.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/92197.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/92197.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

論文閱讀: Mobile Edge Intelligence for Large LanguageModels: A Contemporary Survey

地址:Mobile Edge Intelligence for Large Language Models: A Contemporary Survey 摘要 設備端大型語言模型(LLMs)指在邊緣設備上運行 LLMs,與云端模式相比,其成本效益更高、延遲更低且更能保護隱私,因…

JavaWeb(蒼穹外賣)--學習筆記17(Websocket)

前言 本篇文章是學習B站黑馬程序員蒼穹外賣的學習筆記📑。我的學習路線是Java基礎語法-JavaWeb-做項目,管理端的功能學習完之后,就進入到了用戶端微信小程序的開發,🙌用戶下單并且支付成功后,需要第一時間通…

WebForms 簡介

WebForms 簡介 概述 WebForms 是微軟公司推出的一種用于構建動態網頁和應用程序的技術。自 2002 年推出以來,WebForms 成為 ASP.NET 技術棧中重要的組成部分。它允許開發者以類似于桌面應用程序的方式創建交互式網頁,極大地提高了 Web 開發的效率和體驗。 WebForms 的工作…

vsCode軟件中JS文件中啟用Emmet語法支持(React),外加安裝兩個常用插件

1.點擊vsCode軟件中的設置(就是那個齒輪圖標),如下圖2.在搜索框中輸入emmet,然后點擊添加項,填寫以下值:項:javascript 值:javascriptreact。如下圖3.可以安裝兩個常用插件&#xf…

【第2話:基礎知識】 自動駕駛中的世界坐標系、車輛坐標系、相機坐標系、像素坐標系概念及相互間的轉換公式推導

自動駕駛中的坐標系概念及相互間的轉換公式推導 在自動駕駛系統中,多個坐標系用于描述車輛、傳感器和環境的相對位置。這些坐標系之間的轉換是實現定位、感知和控制的關鍵。下面我將逐步解釋常見坐標系的概念,并推導相互轉換的公式。推導基于標準幾何變換…

深度拆解Dify:開源LLM開發平臺的架構密碼與技術突圍

注:此文章內容均節選自充電了么創始人,CEO兼CTO陳敬雷老師的新書《GPT多模態大模型與AI Agent智能體》(跟我一起學人工智能)【陳敬雷編著】【清華大學出版社】 清華《GPT多模態大模型與AI Agent智能體》書籍配套視頻課程【陳敬雷…

tomcat處理請求流程

1.瀏覽器在請求一個servlet時,會按照HTTP協議構造一個HTTP請求,通過Socket連接發送給Tomcat. 2.Tomcat通過不同的IO模型接收到Socket的字節流數據。 3.接收到數據后,按照HTTP協議解析字節流,得到HttpServletRequest對象 4.通過HttpServletRequest對象,也就是請求信息,找到該請求…

【音視頻】WebRTC 一對一通話-信令服

一、服務器配置 服務器在Ubuntu下搭建,使用C語言實現,由于需要使用WebSocket和前端通訊,同時需要解析JSON格式,因此引入了第三方庫:WebSocketpp和nlonlohmann,這兩個庫的具體配置方式可以參考我之前的博客…

Spring(以 Spring Boot 為核心)與 JDK、Maven、MyBatis-Plus、Tomcat 的版本對應關系及關鍵注意事項

以下是 Spring(以 Spring Boot 為核心)與 JDK、Maven、MyBatis-Plus、Tomcat 的版本對應關系及關鍵注意事項,基于最新技術生態整理: 一、Spring Boot 與 JDK 版本對應 Spring Boot 2.x 系列 最低要求:JDK 1.8推薦版本…

03-基于深度學習的鋼鐵缺陷檢測-yolo11-彩色版界面

目錄 項目介紹🎯 功能展示🌟 一、環境安裝🎆 環境配置說明📘 安裝指南說明🎥 環境安裝教學視頻 🌟 二、系統環境(框架/依賴庫)說明🧱 系統環境與依賴配置說明&#x1f4c…

24. 前端-js框架-Vue

文章目錄前言一、Vue介紹1. 學習導圖2. 特點3. 安裝1. 方式一:獨立版本2. 方式二:CDN方法3. 方式三:NPM方法(推薦使用)4. 搭建Vue的開發環境(大綱)5. 工程結構6. 安裝依賴資源7. 運行項目8. Vue…

Spring 的依賴注入DI是什么?

口語化答案好的,面試官,依賴注入(Dependency Injection,簡稱DI)是Spring框架實現控制反轉(IoC)的主要手段。DI的核心思想是將對象的依賴關系從對象內部抽離出來,通過外部注入的方式提…

匯川PLC通過ModbusTCP轉Profinet網關連接西門子PLC配置案例

本案例是匯川的PLC通過開疆智能研發的ModbusTCP轉Profient網關讀寫西門子1200PLC中的數據。匯川PLC作為ModbusTCP的客戶端網關作為服務器,在Profinet一側網關作為從站接收1200PLC的數據并轉成ModbusTCP協議被匯川PLC讀取。配置過程:匯川PLC配置Modbus TC…

【計組】數據的表示與運算

機器數與真值機器數真值編碼原碼特點表示范圍求真值方法反碼特點補碼特點表示范圍求真值方法移碼特點表示范圍求真值方法相互轉換原碼<->補碼補碼<->移碼原碼<->反碼反碼<->補碼移位左移右移邏輯右移算術右移符號擴展零擴展整數小數符號擴展運算器部件…

視頻水印技術中的變換域嵌入方法對比分析

1. 引言 隨著數字視頻技術的快速發展和網絡傳輸的普及,視頻內容的版權保護問題日益突出。視頻水印技術作為一種有效的版權保護手段,通過在視頻中嵌入不可見或半可見的標識信息,實現對視頻內容的所有權認證、完整性驗證和盜版追蹤。在視頻水印技術的發展歷程中,變換域水印因…

電動汽車電池管理系統設計與實現

電動汽車電池管理系統設計與實現 1. 引言 電動汽車電池管理系統(BMS)是確保電池組安全、高效運行的關鍵組件。本文將詳細介紹一個完整的BMS系統的MATLAB實現,包括狀態估計(SOC/SOH)、參數監測、電池平衡和保護功能。系統設計為模塊化結構,便于擴展和參數調整。 2. 系統架構…

JVM(Java Virtual Machine,Java 虛擬機)超詳細總結

一、JVM的基礎概念1、概述JVM是 Java 程序的運行基礎環境&#xff0c;是 Java 語言實現 “一次編寫&#xff0c;到處運行” &#xff08;"write once , run anywhere. "&#xff09;特性的關鍵組件&#xff0c;具體從以下幾個方面來理解&#xff1a;概念層面JVM 是一…

Balabolka軟件調用微軟離線自然語音合成進行文字轉語音下載安裝教程

首先&#xff0c;需要準備安裝包 Balabolka NaturalVoiceSAPIAdapterMicrosoftWindows.Voice.zh-CN.Xiaoxiao.1_1.0.9.0_x64__cw5n1h2txyewy.Msix MicrosoftWindows.Voice.zh-CN.Yunxi.1_1.0.4.0_x64__cw5n1h2txyewy.Msix借助上面這個工具&#xff1a;NaturalVoiceSAPIAdapter&…

Java修仙之路,十萬字吐血整理全網最完整Java學習筆記(高級篇)

導航&#xff1a; 【Java筆記踩坑匯總】Java基礎JavaWebSSMSpringBootSpringCloud瑞吉外賣/谷粒商城/學成在線設計模式面試題匯總性能調優/架構設計源碼解析 推薦視頻&#xff1a; 黑馬程序員全套Java教程_嗶哩嗶哩 尚硅谷Java入門視頻教程_嗶哩嗶哩 推薦書籍&#xff1a; 《Ja…

接口測試用例和接口測試模板

一、簡介 3天精通Postman接口測試&#xff0c;全套項目實戰教程&#xff01;&#xff01;接口測試區別于傳統意義上的系統測試&#xff0c;下面介紹接口測試用例和接口測試報告。 二、接口測試用例模板 功能測試用例最重要的兩個因素是測試步驟和預期結果&#xff0c;接口測試…