Pytorch 自定義激活函數前向與反向傳播 Tanh

看完這篇,你基本上可以自定義前向與反向傳播,可以自己定義自己的算子

文章目錄

    • Tanh
      • 公式
      • 求導過程
      • 優點:
      • 缺點:
      • 自定義Tanh
      • 與Torch定義的比較
      • 可視化

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F%matplotlib inlineplt.rcParams['figure.figsize'] = (7, 3.5)
plt.rcParams['figure.dpi'] = 150
plt.rcParams['axes.unicode_minus'] = False  #解決坐標軸負數的鉛顯示問題

Tanh

公式

tanh?(x)=sinh?(x)cosh?(x)=ex?e?xex+e?x\tanh(x) = \frac{\sinh(x)}{\cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}tanh(x)=cosh(x)sinh(x)?=ex+e?xex?e?x?

tanh?(x)=2σ(2x)?1\tanh(x) = 2 \sigma(2x) - 1 tanh(x)=2σ(2x)?1

求導過程

tanh?′(x)=(ex?e?xex+e?x)′=[(ex?e?x)(ex+e?x)?1]′=(ex+e?x)(ex+e?x)?1+(ex?e?x)(?1)(ex+e?x)?2(ex?e?x)=1?(ex?e?x)2(ex+e?x)?2=1?(ex?e?x)2(ex+e?x)2=1?tanh?2(x)\begin{aligned} \tanh'(x) =& \big(\frac{e^x - e^{-x}}{e^x + e^{-x}}\big)' \\ =& \big[(e^x - e^{-x})(e^x + e^{-x})^{-1}\big]' \\ =& (e^x + e^{-x})(e^x + e^{-x})^{-1} + (e^x - e^{-x})(-1)(e^x + e^{-x})^{-2} (e^x - e^{-x}) \\ =& 1-(e^x - e^{-x})^2(e^x + e^{-x})^{-2} \\ =& 1 - \frac{(e^x - e^{-x})^2}{(e^x + e^{-x})^2} \\ =& 1- \tanh^2(x) \\ \end{aligned}tanh(x)======?(ex+e?xex?e?x?)[(ex?e?x)(ex+e?x)?1](ex+e?x)(ex+e?x)?1+(ex?e?x)(?1)(ex+e?x)?2(ex?e?x)1?(ex?e?x)2(ex+e?x)?21?(ex+e?x)2(ex?e?x)2?1?tanh2(x)?

優點:

Tanh也稱為雙切正切函數,取值范圍為[-1,1]。tanh在特征相差明顯時的效果會很好,在循環過程中會不斷擴大特征效果。與 sigmoid 的區別是,tanh 是 0 均值的,因此實際應用中 tanh 會比 sigmoid 更好。文獻 [LeCun, Y., et al., Backpropagation applied to handwritten zip code recognition. Neural computation, 1989. 1(4): p. 541-551.] 中提到tanh 網絡的收斂速度要比sigmoid快,因為tanh 的輸出均值比 sigmoid 更接近 0,SGD會更接近 natural gradient[4](一種二次優化技術),從而降低所需的迭代次數。非常優秀,幾乎適合所有的場景

缺點:

  • 該導數在正負飽和區的梯度都會接近于0值,會造成梯度消失。還有其更復雜的冪運算。

自定義Tanh

class SelfDefinedTanh(torch.autograd.Function):@staticmethoddef forward(ctx, inp):exp_x = torch.exp(inp)exp_x_ = torch.exp(-inp)result = torch.divide((exp_x - exp_x_), (exp_x + exp_x_))ctx.save_for_backward(result)return result@staticmethoddef backward(ctx, grad_output):# ctx.saved_tensors is tuple (tensors, grad_fn)result, = ctx.saved_tensorsreturn grad_output * (1 - result.pow(2))class Tanh(nn.Module):def __init__(self):super().__init__()def forward(self, x):out = SelfDefinedTanh.apply(x)return out
def tanh_sigmoid(x):"""according to the equation"""# 2 * torch.sigmoid(2 * x) -1 return torch.mul(torch.sigmoid(torch.mul(x, 2)), 2) - 1

與Torch定義的比較

# self defined
torch.manual_seed(0)tanh = Tanh()  # SelfDefinedTanh
inp = torch.randn(5, requires_grad=True)
out = tanh((inp + 1).pow(2))print(f'Out is\n{out}')out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nFirst call\n{inp.grad}")out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
Out is
tensor([1.0000, 0.4615, 0.8831, 0.9855, 0.0071],grad_fn=<SelfDefinedTanhBackward>)First call
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])Second call
tensor([ 1.0178e-04,  2.2243e+00, -1.0382e+00,  1.8053e-01, -3.3807e-01])Call after zeroing gradients
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])
# self defined tanh_sigmoid
torch.manual_seed(0)inp = torch.randn(5, requires_grad=True)
out = tanh_sigmoid((inp + 1).pow(2))print(f'Out is\n{out}')out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nFirst call\n{inp.grad}")out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
Out is
tensor([1.0000, 0.4615, 0.8831, 0.9855, 0.0071], grad_fn=<SubBackward0>)First call
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])Second call
tensor([ 1.0178e-04,  2.2243e+00, -1.0382e+00,  1.8053e-01, -3.3807e-01])Call after zeroing gradients
tensor([ 5.0889e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])
# torch defined
torch.manual_seed(0)inp = torch.randn(5, requires_grad=True)
out = torch.tanh((inp + 1).pow(2))print(f'Out is\n{out}')out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nFirst call\n{inp.grad}")out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
Out is
tensor([1.0000, 0.4615, 0.8831, 0.9855, 0.0071], grad_fn=<TanhBackward>)First call
tensor([ 5.0283e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])Second call
tensor([ 1.0057e-04,  2.2243e+00, -1.0382e+00,  1.8053e-01, -3.3807e-01])Call after zeroing gradients
tensor([ 5.0283e-05,  1.1121e+00, -5.1911e-01,  9.0267e-02, -1.6904e-01])

從上3個結果,可以看出,不管是經過sigmoid來計算,還是公式定義都可以得到一樣的output與gradient。但在輸入的值較大時,torch應該是減去一個小值,使得梯度更小。

可視化

# visualization
inp = torch.arange(-8, 8, 0.1, requires_grad=True)
out = tanh(inp)
out.sum().backward()inp_grad = inp.gradplt.plot(inp.detach().numpy(),out.detach().numpy(),label=r"$\tanh(x)$",alpha=0.7)
plt.plot(inp.detach().numpy(),inp_grad.numpy(),label=r"$\tanh'(x)$",alpha=0.5)
plt.grid()
plt.legend()
plt.show()

請添加圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/260477.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/260477.shtml
英文地址,請注明出處:http://en.pswp.cn/news/260477.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

multi mysql_mysqld_multi 的使用方法

mysqld_multi 的使用方法:官方文檔&#xff1a;https://dev.mysql.com/doc/refman/5.7/en/mysqld-multi.html 【文檔有些問題&#xff0c;按照它的這個配置&#xff0c;mysqld_multi無法關閉實例】mysqld_multi無法關閉實例的解決方法&#xff1a;https://bugs.mysql.com/bug…

vsftp 無法啟動,500 OOPS: bad bool value in config file for: anonymous_enable

朋友的FTP啟動不了&#xff0c;叫我幫他看&#xff0c;啟動時出現以下錯誤信息&#xff1a; 500 OOPS: bad bool value in config file for: anonymous_enable 看似配置文件錯誤&#xff0c;看了一下配置相應的行&#xff1a; anonymous_enableNO 語句沒什么錯誤&#xff0c;不…

HDU ACM 1181 變形課 (廣搜BFS + 動態數組vector)-------第一次使用動態數組vector

http://acm.hdu.edu.cn/showproblem.php?pid1181 題意&#xff1a;給我若干個單詞,若單詞A的結尾與單詞B的開頭相同,則表示A能變成B,判斷能不能從b開頭變成m結尾. 如: big-got-them 第一次使用動態數組vector View Code 1 #include <iostream>2 #include <vector>…

Max Sum 杭電 1003

2019獨角獸企業重金招聘Python工程師標準>>> #題目概述 題目的意思是給你一個數列&#xff0c;找到一個子數列&#xff0c;這個子數列的和是所有子數列中和最大的。 當然把數列的所有數都列出來肯定不現實。 黑黑&#xff0c;不知道正不正確&#xff0c;我是先從第一…

shiro反序列化工具_Apache Shiro 1.2.4反序列化漏洞(CVE-2016-4437)源碼解析

Apache ShiroApache Shiro是一個功能強大且靈活的開源安全框架,主要功能包括用戶認證、授權、會話管理以及加密。在了解該漏洞之前,建議學習下Apache Shiro是怎么使用.debug環境jdk1.8Apache Shiro 1.2.4測試demo本地debug需要以下maven依賴<!-- https://mvnrepository.com/…

window 下的mysql_Windows下MySQL下載安裝、配置與使用

用過MySQL之后&#xff0c;不論容量的話&#xff0c;發現比其他兩個(sql server 、oracle)好用的多&#xff0c;一下子就喜歡上了。下面給那些還不知道怎么弄的童鞋們寫下具體的方法步驟。(我這個寫得有點太詳細了&#xff0c;甚至有些繁瑣&#xff0c;有很多步驟在其他的教程文…

H264視頻通過RTMP直播

http://blog.csdn.net/firehood_/article/details/8783589 前面的文章中提到了通過RTSP&#xff08;Real Time Streaming Protocol&#xff09;的方式來實現視頻的直播&#xff0c;但RTSP方式的一個弊端是如果需要支持客戶端通過網頁來訪問&#xff0c;就需要在在頁面中嵌入一個…

Pytorch 自定義激活函數前向與反向傳播 ReLu系列 含優點與缺點

文章目錄ReLu公式求導過程優點&#xff1a;缺點&#xff1a;自定義ReLu與Torch定義的比較可視化Leaky ReLu PReLu公式求導過程優點&#xff1a;缺點&#xff1a;自定義LeakyReLu與Torch定義的比較可視化自定義PReLuELU公式求導過程優點缺點自定義LeakyReLu與Torch定義的比較可視…

手勢處理

在ios開發中&#xff0c;需用到對于手指的不同操作&#xff0c;以手指點擊為例&#xff1a;分為單指單擊、單指多擊、多指單擊、多指多擊。對于這些事件進行不同的操作處理&#xff0c;由于使用系統自帶的方法通過判斷touches不太容易處理&#xff0c;而且會有事件之間的沖突。…

mybatis select count(*) 一直返回0 mysql_Mybatis教程1:MyBatis快速入門

點擊上方“Java技術前線”&#xff0c;選擇“置頂或者星標”與你一起成長一、Mybatis介紹MyBatis是一個支持普通*SQL*查詢&#xff0c;存儲過程和高級映射的優秀持久層框架。MyBatis消除了幾乎所有的JDBC代碼和參數的手工設置以及對結果集的檢索封裝。MyBatis可以使用簡單的XML…

css預處理器sass使用教程(多圖預警)

css預處理器賦予了css動態語言的特性&#xff0c;如變量、函數、運算、繼承、嵌套等&#xff0c;有助于更好地組織管理樣式文件&#xff0c;以及更高效地開發項目。css預處理器可以更方便的維護和管理css代碼&#xff0c;讓整個網頁變得更加靈活可變。對于預處理器&#xff0c;…

mysql 主從優點_MySql主從配置實踐及其優勢淺談

1、增加兩個MySQL,我將C:\xampp\mysql下的MYSQL復制了一份&#xff0c;放到D:\Mysql2\Mysql5.1修改my.ini(linux下應該是my.cnf)&#xff1a;[client]port 3307[mysqld]port 3307basedirD:/Mysql2/Mysql5.1/mysqldatadirD:/Mysql2/Mysql5.1/mysql/data/之后&#xff0c;再增加…

python 多線程并發編程(生產者、消費者模式),邊讀圖像,邊處理圖像,處理完后保存圖像實現提高處理效率

文章目錄需求實現先導入本次需要用到的包一些輔助函數如下函數是得到指定后綴的文件如下的函數一個是讀圖像&#xff0c;一個是把RGB轉成BGR下面是主要的幾個處理函數在上面幾個函數構建對應的處理函數main函數按順序執行結果需求 本次的需求是邊讀圖像&#xff0c;邊處理圖像…

Sharepoint學習筆記—Site Definition系列-- 2、創建Content Type

Sharepoint本身就是一個豐富的大容器&#xff0c;里面存儲的所有信息我們可以稱其為“內容(Content)”&#xff0c;為了便于管理這些Conent&#xff0c;按照人類的正常邏輯就必然想到的是對此進行“分類”。分類所涉及到的層面又必然包括: 1、分類的標準或特征描述{即&#xf…

arduino byte轉string_Java數組轉List集合的三駕馬車

點擊上方 藍字關注我們來源&#xff1a;blog.csdn.net/x541211190/article/details/79597236前言本文中的代碼命名有的可能不太規范&#xff0c;是因為沒法排版的問題&#xff0c;小仙已經很努力去解決了&#xff0c;希望各位能多多點贊、分享。好了&#xff0c;不多bb了(不要讓…

ES6筆記(4)-- Symbol類型

系列文章 -- ES6筆記系列 Symbol是什么&#xff1f;中文意思是標志、記號&#xff0c;顧名思義&#xff0c;它可以用了做記號。 是的&#xff0c;它是一種標記的方法&#xff0c;被ES6引入作為一種新的數據類型&#xff0c;表示獨一無二的值。 由此&#xff0c;JS的數據類型多了…

mysql類型說明_MYSQL 數據類型說明

MySQL支持大量的列類型&#xff0c;它可以被分為3類&#xff1a;數字類型、日期和時間類型以及字符串(字符)類型。本節首先給出可用類型的一個概述&#xff0c;并且總結每個列類型的存儲需求&#xff0c;然后提供每個類中的類型性質的更詳細的描述。概述有意簡化&#xff0c;更…

LeetCode OJ - Convert Sorted List to Binary Search Tree

題目&#xff1a; Given a singly linked list where elements are sorted in ascending order, convert it to a height balanced BST. 解題思路&#xff1a; 注意是讓構造平衡二叉搜索樹。 每次將鏈表從中間斷開&#xff0c;分成左右兩部分。左邊部分用來構造左子樹&#xff…

手把手教你如下在Linux下如何寫一個C語言代碼,編譯并運行

文章目錄手把手教你如下在Linux下如何寫一個C語言代碼&#xff0c;編譯并運行打開Ubuntu終端創建 helloworld.c編譯C文件手把手教你如下在Linux下如何寫一個C語言代碼&#xff0c;編譯并運行 打開Ubuntu終端 我這里的終端是Windows下的WSL&#xff0c;如果有疑問&#xff0c;…

郵件群發工具的編寫(二)數據的保存

數據的保存與讀取 人類是在不斷探索與改進中進步的 上一篇&#xff0c;郵件群發工具的編寫&#xff08;一&#xff09;郵件地址提取&#xff0c;我們講到了郵箱的提取。 那么這一篇&#xff0c;講一下提取完的郵箱信息的保存和讀取。 首先&#xff0c;我希望對上一篇郵箱提取類…