Pytorch框架中SGD&Adam優化器以及BP反向傳播入門思想及實現

因為這章內容比較多,分開來敘述,前面先講理論后面是講代碼。最重要的是代碼部分,結合代碼去理解思想。

SGD優化器

思想:

根據梯度,控制調整權重的幅度
在這里插入圖片描述

公式:

在這里插入圖片描述
權重(新) = 權重(舊) - 學習率 × 梯度

Adam優化器

思想:

在我看來,Adam優化器重點是能動態調整學習率,防止學習率較大時反復震蕩,比如說當梯度一直為正的時候,權重一直減小,這時直到梯度為負的時候,權重不應該一下子增長太多,而是應該緩慢增長。
在這里插入圖片描述
這個公式可以在Diy框架代碼中找到對應的代碼并進行了解釋。

優勢:

  1. 實現簡單,計算高效,對內存需求少
  2. 超參數具有很好的解釋性,且通常無需調整或僅需很少的微調
  3. 更新的步長能夠被限制在大致的范圍內(初始學習率)
  4. 能夠表現出自動調整學習率
  5. 很適合應用于大規模的數據及參數的場景
  6. 適用于不穩定目標函數
  7. 適用于梯度稀疏或梯度存在很大噪聲的問題

BP反向傳播

傳播過程:

1.根據輸入x和模型當前權重,計算預測值y’

2.根據y’和y使用loss函數計算loss

3.根據loss計算模型權重的梯度

4.使用梯度和學習率,根據優化器調整模型權重

Pytorch框架下實現代碼:

import torch
import torch.nn as nn
import numpy as np
import copy"""
基于pytorch的網絡編寫
實現梯度計算和反向傳播
加入激活函數
"""class TorchModel(nn.Module):def __init__(self, hidden_size):super(TorchModel, self).__init__()self.layer = nn.Linear(hidden_size, hidden_size, bias=False)    #線性層,輸入輸出維度都是hidden_sizeself.activation = torch.sigmoid     #套一個激活函數(不套也可以)self.loss = nn.functional.mse_loss  #loss采用均方差損失#當輸入真實標簽,返回loss值;無真實標簽,返回預測值def forward(self, x, y=None):y_pred = self.layer(x)          #將輸入放入線性層中,獲得預測值y_pred = self.activation(y_pred)    #將預測值激活if y is not None:return self.loss(y_pred, y)else:return y_predx = np.array([1, 2, 3, 4])  #輸入
y = np.array([3, 2, 4, 5])  #預期輸出#torch實驗
torch_model = TorchModel(len(x))    #給torchmodel函數傳入x的維度作為hidden_size
torch_model_w = torch_model.state_dict()["layer.weight"]
print(torch_model_w, "初始化權重")
numpy_model_w = copy.deepcopy(torch_model_w.numpy())    #拷貝一下,用于diy中對初始權重計算torch_x = torch.FloatTensor([x])
torch_y = torch.FloatTensor([y])#torch的前向計算過程,得到loss
torch_loss = torch_model.forward(torch_x, torch_y)
print("torch模型計算loss:", torch_loss)#設定優化器
learning_rate = 0.1
optimizer = torch.optim.SGD(torch_model.parameters(), lr=learning_rate)     #SGD優化器
#torch_model.parameters傳遞模型中所有的參數,也可以只有選擇想傳遞的參數
# optimizer = torch.optim.Adam(torch_model.parameters())		#Adam優化器
optimizer.zero_grad()   #先把優化器歸零#pytorch的反向傳播操作
torch_loss.backward()       #完成梯度計算
print(torch_model.layer.weight.grad, "torch 計算梯度")  #查看某層權重的梯度#torch梯度更新
optimizer.step()#查看更新后權重
update_torch_model_w = torch_model.state_dict()["layer.weight"]
print(update_torch_model_w, "torch更新后權重")

接下來是DIY手動框架實現

"""
手動實現梯度計算和反向傳播
加入激活函數
"""
#自定義模型,接受一個參數矩陣作為入參
class DiyModel:def __init__(self, weight):self.weight = weightdef forward(self, x, y=None):y_pred = np.dot(self.weight, x)y_pred = self.diy_sigmoid(y_pred)if y is not None:return self.diy_mse_loss(y_pred, y)else:return y_pred#sigmoiddef diy_sigmoid(self, x):return 1 / (1 + np.exp(-x))#手動實現mse,均方差lossdef diy_mse_loss(self, y_pred, y_true):return np.sum(np.square(y_pred - y_true)) / len(y_pred)#手動實現梯度計算def calculate_grad(self, y_pred, y_true, x):#前向過程與反向過程對比看# wx = np.dot(self.weight, x)# sigmoid_wx = self.diy_sigmoid(wx)# loss = self.diy_mse_loss(sigmoid_wx, y_true)#反向過程(通過前向的loss來獲得對權重w的導數)# 均方差函數 (y_pred - y_true) ^ 2 / n 的導數 = 2 * (y_pred - y_true) / ngrad_loss_sigmoid_wx = 2/len(x) * (y_pred - y_true)# sigmoid函數 y = 1/(1+e^(-x)) 的導數 = y * (1 - y)grad_sigmoid_wx_wx = y_pred * (1 - y_pred)# wx對w求導 = xgrad_wx_w = x#導數鏈式相乘grad = grad_loss_sigmoid_wx * grad_sigmoid_wx_wxgrad = np.dot(grad.reshape(len(x),1), grad_wx_w.reshape(1,len(x)))  #轉化為矩陣形式相乘return grad#sgd梯度更新
def diy_sgd(grad, weight, learning_rate):return weight - grad * learning_rate#adam梯度更新
def diy_adam(grad, weight):#參數應當放在外面,此處為保持后方代碼整潔簡單實現一步alpha = 1e-3  #學習率beta1 = 0.9   #超參數(推薦)beta2 = 0.999 #超參數(推薦)eps = 1e-8    #超參數t = 0         #初始化mt = 0        #初始化vt = 0        #初始化#開始計算t = t + 1gt = gradmt = beta1 * mt + (1 - beta1) * gt      #前面的mt累積在后面的mt中,使前面的梯度占高比例,本輪的占低比例vt = beta2 * vt + (1 - beta2) * gt ** 2mth = mt / (1 - beta1 ** t)     #分母不斷增大,mth不斷減小vth = vt / (1 - beta2 ** t)weight = weight - alpha * mth / (np.sqrt(vth) + eps)    #alpha學習率動態調整return weight#手動實現loss計算
diy_model = DiyModel(numpy_model_w)
diy_loss = diy_model.forward(x, y)
print("diy模型計算loss:", diy_loss)#手動實現反向傳播
grad = diy_model.calculate_grad(diy_model.forward(x), y, x)
print(grad, "diy 計算梯度")     #梯度的維度應與矩陣維度一致#手動梯度更新
diy_update_w = diy_sgd(grad, numpy_model_w, learning_rate)	#grad優化器
# diy_update_w = diy_adam(grad, numpy_model_w)		#adam優化器
print(diy_update_w, "diy更新權重")

運行結果:

SGD優化器情況下:

在這里插入圖片描述

Adam優化器的情況下:

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/389574.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/389574.shtml
英文地址,請注明出處:http://en.pswp.cn/news/389574.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

朱曄和你聊Spring系列S1E3:Spring咖啡罐里的豆子

標題中的咖啡罐指的是Spring容器,容器里裝的當然就是被稱作Bean的豆子。本文我們會以一個最基本的例子來熟悉Spring的容器管理和擴展點。閱讀PDF版本 為什么要讓容器來管理對象? 首先我們來聊聊這個問題,為什么我們要用Spring來管理對象&…

ab實驗置信度_為什么您的Ab測試需要置信區間

ab實驗置信度by Alos Bissuel, Vincent Grosbois and Benjamin HeymannAlosBissuel,Vincent Grosbois和Benjamin Heymann撰寫 The recent media debate on COVID-19 drugs is a unique occasion to discuss why decision making in an uncertain environment is a …

基于Pytorch的NLP入門任務思想及代碼實現:判斷文本中是否出現指定字

今天學了第一個基于Pytorch框架的NLP任務: 判斷文本中是否出現指定字 思路:(注意:這是基于字的算法) 任務:判斷文本中是否出現“xyz”,出現其中之一即可 訓練部分: 一&#xff…

erlang下lists模塊sort(排序)方法源碼解析(二)

上接erlang下lists模塊sort(排序)方法源碼解析(一),到目前為止,list列表已經被分割成N個列表,而且每個列表的元素是有序的(從大到小) 下面我們重點來看看mergel和rmergel模塊,因為我…

洛谷P4841 城市規劃(多項式求逆)

傳送門 這題太珂怕了……如果是我的話完全想不出來…… 題解 1 //minamoto2 #include<iostream>3 #include<cstdio>4 #include<algorithm>5 #define ll long long6 #define swap(x,y) (x^y,y^x,x^y)7 #define mul(x,y) (1ll*(x)*(y)%P)8 #define add(x,y) (x…

支撐阻力指標_使用k表示聚類以創建支撐和阻力

支撐阻力指標Note from Towards Data Science’s editors: While we allow independent authors to publish articles in accordance with our rules and guidelines, we do not endorse each author’s contribution. You should not rely on an author’s works without seek…

高版本(3.9版本)python在anaconda安裝opencv庫及skimage庫(scikit_image庫)諸多問題解決辦法

今天開始CV方向的學習&#xff0c;然而剛拿到基礎代碼的時候發現 from skimage.color import rgb2gray 和 import cv2標紅&#xff08;這里是因為我已經配置成功了&#xff0c;所以沒有紅標&#xff09;&#xff0c;我以為是單純兩個庫沒有下載&#xff0c;去pycharm中下載ski…

python 實現斐波那契數列

# coding:utf8 __author__ blueslidef fun(arg1,arg2,stop):if arg10:print(arg1,arg2)arg3 arg1arg2print(arg3)if arg3<stop:arg3 fun(arg2,arg3,stop)fun(0,1,100)轉載于:https://www.cnblogs.com/bluesl/p/9079705.html

單機安裝ZooKeeper

2019獨角獸企業重金招聘Python工程師標準>>> zookeeper下載、安裝以及配置環境變量 本節介紹單機的zookeeper安裝&#xff0c;官方下載地址如下&#xff1a; https://archive.apache.org/dist/zookeeper/ 我這里使用的是3.4.11版本&#xff0c;所以找到相應的版本點…

均線交易策略的回測 r_使用r創建交易策略并進行回測

均線交易策略的回測 rR Programming language is an open-source software developed by statisticians and it is widely used among Data Miners for developing Data Analysis. R can be best programmed and developed in RStudio which is an IDE (Integrated Development…

opencv入門課程:彩色圖像灰度化和二值化(采用skimage庫和opencv庫兩種方法)

用最簡單的辦法實現彩色圖像灰度化和二值化&#xff1a; 首先采用skimage庫&#xff08;skimage庫現在在scikit_image庫中&#xff09;實現&#xff1a; from skimage.color import rgb2gray import numpy as np import matplotlib.pyplot as plt""" skimage庫…

SVN中Revert changes from this revision 跟Revert to this revision

譬如有個文件&#xff0c;有十個版本&#xff0c;假定版本號是1&#xff0c;2&#xff0c;3&#xff0c;4&#xff0c;5&#xff0c;6&#xff0c;7&#xff0c;8&#xff0c;9&#xff0c;10。Revert to this revision&#xff1a; 如果是在版本6這里點擊“Revert to this rev…

歸 [拾葉集]

歸 心歸故鄉 想象行走在 鄉間恬靜小路上 讓那些疲憊的夢 都隨風飛散吧&#xff01; 不去想那些世俗 人來人往 熙熙攘攘 秋日午后 陽光下 細數落葉 來日方長 世上的路 有詩人、浪子 歌詠吟唱 世上的人 在欲望、信仰中 彷徨 彷徨又迷茫 親愛的人兒 快結束那 無休止的獨自流浪 莫要…

instagram分析以預測與安的限量版運動鞋轉售價格

Being a sneakerhead is a culture on its own and has its own industry. Every month Biggest brands introduce few select Limited Edition Sneakers which are sold in the markets according to Lottery System called ‘Raffle’. Which have created a new market of i…

opencv:用最鄰近插值和雙線性插值法實現上采樣(放大圖像)與下采樣(縮小圖像)

上采樣與下采樣 概念&#xff1a; 上采樣&#xff1a; 放大圖像&#xff08;或稱為上采樣&#xff08;upsampling&#xff09;或圖像插值&#xff08;interpolating&#xff09;&#xff09;的主要目的 是放大原圖像,從而可以顯示在更高分辨率的顯示設備上。 下采樣&#xff…

CSS魔法堂:那個被我們忽略的outline

前言 在CSS魔法堂&#xff1a;改變單選框顏色就這么吹毛求疵&#xff01;中我們要模擬原生單選框通過Tab鍵獲得焦點的效果&#xff0c;這里涉及到一個常常被忽略的屬性——outline&#xff0c;由于之前對其印象確實有些模糊&#xff0c;于是本文打算對其進行稍微深入的研究^_^ …

初創公司怎么做銷售數據分析_初創公司與Faang公司的數據科學

初創公司怎么做銷售數據分析介紹 (Introduction) In an increasingly technological world, data scientist and analyst roles have emerged, with responsibilities ranging from optimizing Yelp ratings to filtering Amazon recommendations and designing Facebook featu…

opencv:灰色和彩色圖像的像素直方圖及直方圖均值化的實現與展示

直方圖及直方圖均值化的理論&#xff0c;實現及展示 直方圖&#xff1a; 首先&#xff0c;我們來看看什么是直方圖&#xff1a; 理論概念&#xff1a; 在圖像處理中&#xff0c;經常用到直方圖&#xff0c;如顏色直方圖、灰度直方圖等。 圖像的灰度直方圖就描述了圖像中灰度分…

mysql.sock問題

Cant connect to local MySQL server through socket /tmp/mysql.sock 上述提示可能在啟動mysql時遇到&#xff0c;即在/tmp/mysql.sock位置找不到所需要的mysql.sock文件&#xff0c;主要是由于my.cnf文件里對mysql.sock的位置設定導致。 mysql.sock默認的是在/var/lib/mysql,…

交換機的基本原理配置(一)

1、配置主機名 在全局模式下輸入hostname 名字 然后回車即可立馬生效&#xff08;在生產環境交換機必須有自己唯一的名字&#xff09; Switch(config)#hostname jsh-sw1jsh-sw1(config)#2、顯示系統OS名稱及版本信息 特權模式下&#xff0c;輸入命令 show version Switch#show …