前饋神經網絡優化器

引用的知乎上的文章內容,現在有些地方還不太明白,留待以后查看。

import math
import numpy as np
import matplotlib.pyplot as pltRATIO = 3   # 橢圓的長寬比
LIMIT = 1.2 # 圖像的坐標軸范圍class PlotComparaison(object):"""多種優化器來優化函數 x1^2 + x2^2 * RATIO^2.每次參數改變為(d1, d2).梯度為(dx1, dx2)t+1次迭代,標準GD,d1_{t+1} = - eta * dx1d2_{t+1} = - eta * dx2帶Momentum,d1_{t+1} = eta * (mu * d1_t - dx1_{t+1})d2_{t+1} = eta * (mu * d2_t - dx2_{t+1})    帶Nesterov Momentum,d1_{t+1} = eta * (mu * d1_t - dx1^{nag}_{t+1})d2_{t+1} = eta * (mu * d2_t - dx2^{nag}_{t+1})其中(dx1^{nag}, dx2^{nag})為(x1 + eta * mu * d1_t, x2 + eta * mu * d2_t)處的梯度RMSProp,w1_{t+1} = beta2 * w1_t + (1 - beta2) * dx1_t^2w2_{t+1} = beta2 * w2_t + (1 - beta2) * dx2_t^2d1_{t+1} = - eta * dx1_t / (sqrt(w1_{t+1}) + epsilon)d2_{t+1} = - eta * dx2_t / (sqrt(w2_{t+1}) + epsilon)Adam,每次參數改變為(d1, d2)v1_{t+1} = beta1 * v1_t + (1 - beta1) * dx1_tv2_{t+1} = beta1 * v2_t + (1 - beta1) * dx2_tw1_{t+1} = beta2 * w1_t + (1 - beta2) * dx1_t^2w2_{t+1} = beta2 * w2_t + (1 - beta2) * dx2_t^2v1_corrected = v1_{t+1} / (1 - beta1^{t+1})v2_corrected = v2_{t+1} / (1 - beta1^{t+1})w1_corrected = w1_{t+1} / (1 - beta2^{t+1})w2_corrected = w2_{t+1} / (1 - beta2^{t+1})d1_{t+1} = - eta * v1_corrected / (sqrt(w1_corrected) + epsilon)d2_{t+1} = - eta * v2_corrected / (sqrt(w2_corrected) + epsilon)"""def __init__(self, eta=0.1, mu=0.9, beta1=0.9, beta2=0.99, epsilon=1e-10, angles=None, contour_values=None,stop_condition=1e-4):# 全部算法的學習率self.eta = eta# 啟發式學習的終止條件self.stop_condition = stop_condition# Nesterov Momentum超參數self.mu = mu# RMSProp超參數self.beta1 = beta1self.beta2 = beta2self.epsilon = epsilon# 用正態分布隨機生成初始點self.x1_init, self.x2_init = np.random.uniform(LIMIT / 2, LIMIT), np.random.uniform(LIMIT / 2, LIMIT) / RATIOself.x1, self.x2 = self.x1_init, self.x2_init# 等高線相關if angles == None:angles = np.arange(0, 2 * math.pi, 0.01)self.angles = anglesif contour_values == None:contour_values = [0.25 * i for i in range(1, 5)]self.contour_values = contour_valuessetattr(self, "contour_colors", None)def draw_common(self, title):"""畫等高線,最優點和設置圖片各種屬性"""# 坐標軸尺度一致plt.gca().set_aspect(1)# 根據等高線的值生成坐標和顏色# 海拔越高顏色越深num_contour = len(self.contour_values)if not self.contour_colors:self.contour_colors = [(i / num_contour, i / num_contour, i / num_contour) for i in range(num_contour)]self.contour_colors.reverse()self.contours = [[list(map(lambda x: math.sin(x) * math.sqrt(val), self.angles)),list(map(lambda x: math.cos(x) * math.sqrt(val) / RATIO, self.angles))]for val in self.contour_values]# 畫等高線for i in range(num_contour):plt.plot(self.contours[i][0],self.contours[i][1],linewidth=1,linestyle='-',color=self.contour_colors[i],label="y={}".format(round(self.contour_values[i], 2)))# 畫最優點plt.text(0, 0, 'x*')# 圖片標題plt.title(title)# 設置坐標軸名字和范圍plt.xlabel("x1")plt.ylabel("x2")plt.xlim((-LIMIT, LIMIT))plt.ylim((-LIMIT, LIMIT))# 顯示圖例plt.legend(loc=1)def forward_gd(self):"""SGD一次迭代"""self.d1 = -self.eta * self.dx1self.d2 = -self.eta * self.dx2self.ite += 1def draw_gd(self, num_ite=5):"""畫基礎SGD的迭代優化.包括每次迭代的點,以及表示每次迭代改變的箭頭"""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)# 畫每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_gd()# 迭代的箭頭plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐標為({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的點if self.loss < self.stop_condition:breakdef forward_momentum(self):"""帶Momentum的SGD一次迭代"""self.d1 = self.eta * (self.mu * self.d1_pre - self.dx1)self.d2 = self.eta * (self.mu * self.d2_pre - self.dx2)self.ite += 1self.d1_pre, self.d2_pre = self.d1, self.d2def draw_momentum(self, num_ite=5):"""畫帶Momentum的迭代優化."""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)setattr(self, "d1_pre", 0)setattr(self, "d2_pre", 0)# 畫每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_momentum()# 迭代的箭頭plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐標為({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的點if self.loss < self.stop_condition:breakdef forward_nag(self):"""Nesterov Accelerated的SGD一次迭代"""self.d1 = self.eta * (self.mu * self.d1_pre - self.dx1_nag)self.d2 = self.eta * (self.mu * self.d2_pre - self.dx2_nag)self.ite += 1self.d1_pre, self.d2_pre = self.d1, self.d2def draw_nag(self, num_ite=5):"""畫Nesterov Accelerated的迭代優化."""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)setattr(self, "d1_pre", 0)setattr(self, "d2_pre", 0)# 畫每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_nag()# 迭代的箭頭plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐標為({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的點if self.loss < self.stop_condition:breakdef forward_rmsprop(self):"""RMSProp一次迭代"""w1 = self.beta2 * self.w1_pre + (1 - self.beta2) * (self.dx1 ** 2)w2 = self.beta2 * self.w2_pre + (1 - self.beta2) * (self.dx2 ** 2)self.ite += 1self.w1_pre, self.w2_pre = w1, w2self.d1 = -self.eta * self.dx1 / (math.sqrt(w1) + self.epsilon)self.d2 = -self.eta * self.dx2 / (math.sqrt(w2) + self.epsilon)def draw_rmsprop(self, num_ite=5):"""畫RMSProp的迭代優化."""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)setattr(self, "w1_pre", 0)setattr(self, "w2_pre", 0)# 畫每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_rmsprop()# 迭代的箭頭plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐標為({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的點if self.loss < self.stop_condition:breakdef forward_adam(self):"""AdaM一次迭代"""w1 = self.beta2 * self.w1_pre + (1 - self.beta2) * (self.dx1 ** 2)w2 = self.beta2 * self.w2_pre + (1 - self.beta2) * (self.dx2 ** 2)v1 = self.beta1 * self.v1_pre + (1 - self.beta1) * self.dx1v2 = self.beta1 * self.v2_pre + (1 - self.beta1) * self.dx2self.ite += 1self.v1_pre, self.v2_pre = v1, v2self.w1_pre, self.w2_pre = w1, w2v1_corr = v1 / (1 - math.pow(self.beta1, self.ite))v2_corr = v2 / (1 - math.pow(self.beta1, self.ite))w1_corr = w1 / (1 - math.pow(self.beta2, self.ite))w2_corr = w2 / (1 - math.pow(self.beta2, self.ite))self.d1 = -self.eta * v1_corr / (math.sqrt(w1_corr) + self.epsilon)self.d2 = -self.eta * v2_corr / (math.sqrt(w2_corr) + self.epsilon)def draw_adam(self, num_ite=5):"""畫AdaM的迭代優化."""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)setattr(self, "w1_pre", 0)setattr(self, "w2_pre", 0)setattr(self, "v1_pre", 0)setattr(self, "v2_pre", 0)# 畫每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_adam()# 迭代的箭頭plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐標為({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的點if self.loss < self.stop_condition:break@propertydef dx1(self, x1=None):return self.x1 * 2@propertydef dx2(self):return self.x2 * 2 * (RATIO ** 2)@propertydef dx1_nag(self, x1=None):return (self.x1 + self.eta * self.mu * self.d1_pre) * 2@propertydef dx2_nag(self):return (self.x2 + self.eta * self.mu * self.d2_pre) * 2 * (RATIO ** 2)@propertydef loss(self):return self.x1 ** 2 + (RATIO * self.x2) ** 2def rms(self, x):return math.sqrt(x + self.epsilon)def show(self):# 設置圖片大小plt.figure(figsize=(20, 20))# 展示plt.show()def main(num_ite=15):xixi = PlotComparaison()print("起始點為({}, {})".format(xixi.x1_init, xixi.x2_init))xixi.draw_momentum(num_ite)xixi.draw_common("Optimize x1^2+x2^2*{} Using SGD With Momentum".format(RATIO ** 2))xixi.show()xixi.draw_rmsprop(num_ite)xixi.draw_common("Optimize x1^2+x2^2*{} Using RMSProp".format(RATIO ** 2))xixi.show()xixi.draw_adam(num_ite)xixi.draw_common("Optimize x1^2+x2^2*{} Using AdaM".format(RATIO ** 2))xixi.show()main()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/41354.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/41354.shtml
英文地址,請注明出處:http://en.pswp.cn/news/41354.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Python爬蟲的應用場景與技術難點:如何提高數據抓取的效率與準確性

作為專業爬蟲程序員&#xff0c;我們在數據抓取過程中常常面臨效率低下和準確性不高的問題。但不用擔心&#xff01;本文將與大家分享Python爬蟲的應用場景與技術難點&#xff0c;并提供一些實際操作價值的解決方案。讓我們一起來探索如何提高數據抓取的效率與準確性吧&#xf…

python3實現線性規劃求解

Background 對于數學規劃問題&#xff0c;有很多的實現。MatlabYALMIPCPLEX這個組合應該是比較主流的&#xff0c;尤其是在電力相關系統中占據著比較重要的地位。MATLAB是一個強大的數值計算工具&#xff0c;用于數學建模、算法開發和數據分析。Yalmip是一個MATLAB工具箱&#…

MongoDB:MySQL,Redis,ES,MongoDB的應用場景

簡單明了說明MySQL,ES,MongoDB的各自特點,應用場景,以及MongoDB如何使用的第一章節. 一. SQL與NoSQL SQL被稱為結構化查詢語言.是傳統意義上的數據庫,數據之間存在很明確的關聯關系,例如主外鍵關聯,這種結構可以確保數據的完整性(數據沒有缺失并且正確).但是正因為這種嚴密的結…

神經網絡基礎-神經網絡補充概念-34-正則化

概念 正則化是一種用于控制模型復雜度并防止過擬合的技術&#xff0c;在機器學習和深度學習中廣泛應用。它通過在損失函數中添加一項懲罰項來限制模型的參數&#xff0c;從而使模型更傾向于選擇簡單的參數配置。 理解 L1 正則化&#xff08;L1 Regularization&#xff09;&a…

數據分析 | Boosting與Bagging的區別

Boosting與Bagging的區別 Bagging思想專注于降低方差&#xff0c;操作起來較為簡單&#xff0c;而Boosting思想專注于降低整體偏差來降低泛化誤差&#xff0c;在模型效果方面的突出表現制霸整個弱分類器集成的領域。具體區別體現在如下五點&#xff1a; 弱評估器&#xff1a;Ba…

vb數控加工技術教學素材資源庫的設計和構建

摘 要 20世紀以來,社會生產力迅速發展,科學技術突飛猛進,人們進行信息交流的深度與廣度不斷增加,信息量急劇增長,傳統的信息處理與決策的手段已不能適應社會的需要,信息的重要性和信息處理問題的緊迫性空前提高了,面對著日益復雜和不斷發展,變化的社會環境,特別是企業…

Windows上使用dump文件調試

dump文件 dump文件記錄當前程序運行某一時刻的信息&#xff0c;包括內存&#xff0c;線程&#xff0c;線程棧&#xff0c;變量等等&#xff0c;相當于調試程序時運行到某個斷點上&#xff0c;把程序運行的信息記錄下來。可以通過Windbg打開dump&#xff0c;查看程序運行的變量…

mysql 修改存儲路徑,重啟失敗授權

目錄 停掉mysql修改mysql 配置文件my.cnf目錄授權重啟mysql 停掉mysql 修改mysql 配置文件my.cnf 更改mysql 存儲位置 到/data/mysql_data目錄下&#xff1a; datadir/data/mysql/mysql_data/socket/data/mysql/mysql_data/mysql.sockmysql 默認路么徑在 /var/lib/mysql/ 防止…

go_并發編程(1)

go并發編程 一、 并發介紹1&#xff0c;進程和線程2&#xff0c;并發和并行3&#xff0c;協程和線程4&#xff0c;goroutine 二、 Goroutine1&#xff0c;使用goroutine1&#xff09;啟動單個goroutine2&#xff09;啟動多個goroutine 2&#xff0c;goroutine與線程3&#xff0…

在 React 中獲取數據的6種方法

一、前言 數據獲取是任何 react 應用程序的核心方面。對于 React 開發人員來說&#xff0c;了解不同的數據獲取方法以及哪些用例最適合他們很重要。 但首先&#xff0c;讓我們了解 JavaScript Promises。 簡而言之&#xff0c;promise 是一個 JavaScript 對象&#xff0c;它將…

Python Web:Django、Flask和FastAPI框架對比

原文&#xff1a;百度安全驗證 Django、Flask和FastAPI是Python Web框架中的三個主要代表。這些框架都有著各自的優點和缺點&#xff0c;適合不同類型和規模的應用程序。 1. Django&#xff1a; Django是一個全功能的Web框架&#xff0c;它提供了很多內置的應用程序和工具&am…

排序+運算>直接運算的效率的原因分析

大家好,我是愛編程的喵喵。雙985碩士畢業,現擔任全棧工程師一職,熱衷于將數據思維應用到工作與生活中。從事機器學習以及相關的前后端開發工作。曾在阿里云、科大訊飛、CCF等比賽獲得多次Top名次。現為CSDN博客專家、人工智能領域優質創作者。喜歡通過博客創作的方式對所學的…

ADIS16470和ADIS16500從到手到讀出完整數據,附例程

由于保密原因&#xff0c;不能上傳我這邊的代碼&#xff0c;我所用的開發環境是IAR&#xff0c; 下邊轉載別的博主的文章&#xff0c;他用的是MDK 下文的博主給了你一個很好的思路&#xff0c;特此提出表揚 最下方是我做的一些手冊批注&#xff0c;方便大家了解這個東西 原文鏈…

如何利用 ChatGPT 進行自動數據清理和預處理

推薦&#xff1a;使用 NSDT場景編輯器助你快速搭建可二次編輯的3D應用場景 ChatGPT 已經成為一把可用于多種應用的瑞士軍刀&#xff0c;并且有大量的空間將 ChatGPT 集成到數據科學工作流程中。 如果您曾經在真實數據集上訓練過機器學習模型&#xff0c;您就會知道數據清理和預…

有沒有比讀寫鎖更快的鎖

在之前的文章中&#xff0c;我們介紹了讀寫鎖&#xff0c;學習完之后你應該已經知道了讀寫鎖允許多個線程同時訪問共享變量&#xff0c;適用于讀多寫少的場景。那么在讀多寫少的場景中還有沒有更快的技術方案呢&#xff1f;還真有&#xff0c;在Java1.8這個版本里提供了一種叫S…

Docker安裝Skywalking APM分布式追蹤系統

Skywalking是一個應用性能管理(APM)系統&#xff0c;具有服務器性能監測&#xff0c;應用程序間調用關系及性能監測等功能&#xff0c;Skywalking分為服務端、管理界面、以及嵌入到程序中的探針部分&#xff0c;由程序中的探針采集各類調用數據發送給服務端保存&#xff0c;在管…

novnc 和 vnc server 如何實現通信?原理?

參考&#xff1a;https://www.codenong.com/js0f3b351a156c/

隨機微分方程

應用隨機過程|第7章 隨機微分方程 見知乎&#xff1a;https://zhuanlan.zhihu.com/p/348366892?utm_sourceqq&utm_mediumsocial&utm_oi1315073218793488384

復習3-5天【80天學習完《深入理解計算機系統》】第七天

專注 效率 記憶 預習 筆記 復習 做題 歡迎觀看我的博客&#xff0c;如有問題交流&#xff0c;歡迎評論區留言&#xff0c;一定盡快回復&#xff01;&#xff08;大家可以去看我的專欄&#xff0c;是所有文章的目錄&#xff09;   文章字體風格&#xff1a; 紅色文字表示&#…

Linux與bash(基礎內容一)

一、常見的linux命令&#xff1a; 1、文件&#xff1a; &#xff08;1&#xff09;常見的文件命令&#xff1a; &#xff08;2&#xff09;文件屬性&#xff1a; &#xff08;3&#xff09;修改文件屬性&#xff1a; 查看文件的屬性&#xff1a; ls -l 查看文件的屬性 ls …