神經網絡基礎-神經網絡補充概念-23-神經網絡的梯度下降法

概念

神經網絡的梯度下降法是訓練神經網絡的核心優化算法之一。它通過調整神經網絡的權重和偏差,以最小化損失函數,從而使神經網絡能夠逐漸逼近目標函數的最優值。

步驟

1損失函數(Loss Function):
首先,我們定義一個損失函數,用來衡量神經網絡預測值與真實標簽之間的差距。常見的損失函數包括均方誤差(Mean Squared Error)和交叉熵(Cross-Entropy)等。

2初始化參數:
在訓練之前,需要隨機初始化神經網絡的權重和偏差。

4前向傳播:
通過前向傳播計算神經網絡的輸出,根據輸入數據、權重和偏差計算每一層的激活值和預測值。

5計算損失:
使用損失函數計算預測值與真實標簽之間的差距。

6反向傳播:
反向傳播是梯度下降法的關鍵步驟。它從輸出層開始,計算每一層的誤差梯度,然后根據鏈式法則將梯度傳遞回每一層。這樣,可以得到關于權重和偏差的梯度信息,指導參數的更新。

7更新參數:
使用梯度信息,按照一定的學習率(learning rate)更新神經網絡的權重和偏差。通常采用如下更新規則:新權重 = 舊權重 - 學習率 × 梯度。

8重復迭代:
重復執行前向傳播、計算損失、反向傳播和參數更新步驟,直到損失函數收斂或達到預定的迭代次數。

9評估模型:
在訓練過程中,可以周期性地評估模型在驗證集上的性能,以防止過擬合并選擇合適的模型。

python實現

import numpy as np# 定義 sigmoid 激活函數及其導數
def sigmoid(x):return 1 / (1 + np.exp(-x))def sigmoid_derivative(x):return x * (1 - x)# 設置隨機種子以保證可重復性
np.random.seed(42)# 生成模擬數據
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0], [1], [1], [0]])# 初始化權重和偏差
input_size = 2
output_size = 1
hidden_size = 4weights_input_hidden = np.random.uniform(-1, 1, (input_size, hidden_size))
bias_hidden = np.zeros((1, hidden_size))weights_hidden_output = np.random.uniform(-1, 1, (hidden_size, output_size))
bias_output = np.zeros((1, output_size))# 設置學習率和迭代次數
learning_rate = 0.1
epochs = 10000# 訓練神經網絡
for epoch in range(epochs):# 前向傳播hidden_input = np.dot(X, weights_input_hidden) + bias_hiddenhidden_output = sigmoid(hidden_input)final_input = np.dot(hidden_output, weights_hidden_output) + bias_outputfinal_output = sigmoid(final_input)# 計算損失loss = np.mean(0.5 * (y - final_output) ** 2)# 反向傳播d_output = (y - final_output) * sigmoid_derivative(final_output)d_hidden = d_output.dot(weights_hidden_output.T) * sigmoid_derivative(hidden_output)# 更新權重和偏差weights_hidden_output += hidden_output.T.dot(d_output) * learning_ratebias_output += np.sum(d_output, axis=0, keepdims=True) * learning_rateweights_input_hidden += X.T.dot(d_hidden) * learning_ratebias_hidden += np.sum(d_hidden, axis=0, keepdims=True) * learning_rateif epoch % 1000 == 0:print(f'Epoch {epoch}, Loss: {loss}')# 打印訓練后的權重和偏差
print('Final weights_input_hidden:', weights_input_hidden)
print('Final bias_hidden:', bias_hidden)
print('Final weights_hidden_output:', weights_hidden_output)
print('Final bias_output:', bias_output)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/40754.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/40754.shtml
英文地址,請注明出處:http://en.pswp.cn/news/40754.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Springboot多路數據源

1、多路數據源配置 (1)SpringBootMyBatis-PlusOracle實現多數據源配置 https://blog.csdn.net/weixin_44812604/article/details/127386828 (2)SpringBootMybatis搭建Oracle多數據源配置簡述 https://blog.csdn.net/HJW_233/arti…

網絡安全 Day29-運維安全項目-iptables防火墻

iptables防火墻 1. 防火墻概述2. 防火墻2.1 防火墻種類及使用說明2.2 必須熟悉的名詞2.3 iptables 執行過程※※※※※2.4 表與鏈※※※※※2.4.1 簡介2.4.2 每個表說明2.4.2.1 filter表 :star::star::star::star::star:2.4.2.2 nat表 2.5 環境準備及命令2.6 案例01&#xff1a…

神經網絡基礎-神經網絡補充概念-31-參數與超參數

概念 參數(Parameters): 參數是模型內部學習的變量,它們通過訓練過程自動調整以最小化損失函數。在神經網絡中,參數通常是連接權重(weights)和偏置(biases),…

ChatGLM2-6B安裝部署(詳盡版)

1、環境部署 安裝Anaconda3 安裝GIT 安裝GUDA 11.8 安裝NVIDIA 圖形化驅動 522.25版本,如果電腦本身是更高版本則不用更新 1.1、檢查CUDA 運行cmd或者Anaconda,運行以下命令 nvidia-smi CUDA Version是版本信息,Dricer Version是圖形化…

LeetCode 160.相交鏈表

文章目錄 💡題目分析💡解題思路🚩步驟一:找尾節點🚩步驟二:判斷尾節點是否相等🚩步驟三:找交點🍄思路1🍄思路2 🔔接口源碼 題目鏈接👉…

Ubuntu下mysql安裝及遠程連接支持配置

1.安裝 下載mysql-server(必須加sudo) sudo apt update sudo apt install mysql-server 查看mysql的狀態 sudo service mysql status 通過如下命令開啟mysql sudo service mysql start 2.配置 第一次安裝mysql后,為root設置一個密碼 …

Linux -- 進階 Autofs應用 : 光驅自動掛載 操作詳解

服務端自動掛載光驅 第一步 : 關閉安全軟件,安裝自動掛載軟件 [rootserver ~]# setenforce 0 [rootserver ~]# systemctl stop firewalld [rootserver ~]# yum install autofs -y 第二步 : 修改 autofs 主配置文件, 計劃掛載光…

C++之map的emplace與pair插入鍵值對用例(一百七十四)

簡介: CSDN博客專家,專注Android/Linux系統,分享多mic語音方案、音視頻、編解碼等技術,與大家一起成長! 優質專欄:Audio工程師進階系列【原創干貨持續更新中……】🚀 人生格言: 人生…

213、仿真-基于51單片機智能電表電能表用電量電費報警Proteus仿真設計(程序+Proteus仿真+原理圖+配套資料等)

畢設幫助、開題指導、技術解答(有償)見文未 目錄 一、硬件設計 二、設計功能 三、Proteus仿真圖 四、原理圖 五、程序源碼 資料包括: 需要完整的資料可以點擊下面的名片加下我,找我要資源壓縮包的百度網盤下載地址及提取碼。 方案選擇 單片機的選…

uniapp tabbar 瀏覽器調試顯示 真機不顯示

解決方案,把tabBar里面的單位全改為px,rpx是不會顯示的! 注意了,改完一定要重新運行,不然無效,坑爹 "tabBar": {"borderStyle": "black","selectedColor": &quo…

java-JVM內存區域JVM運行時內存

一. JVM 內存區域 JVM 內存區域主要分為線程私有區域【程序計數器、虛擬機棧、本地方法區】、線程共享區域【JAVA 堆、方法區】、直接內存。線程私有數據區域生命周期與線程相同, 依賴用戶線程的啟動/結束 而 創建/銷毀(在 HotspotVM 內, 每個線程都與操作系統的本地線程直接映…

SwiftUI 動畫進階:實現行星繞圓周軌道運動

0. 概覽 SwiftUI 動畫對于優秀 App 可以說是布帛菽粟。利用美妙的動畫我們不僅可以活躍界面元素,更可以單獨打造出一整套生動有機的世界,激活無限可能。 如上圖所示,我們用動畫粗略實現了一個小太陽系:8大行星圍繞太陽旋轉,而衛星們圍繞各個行星旋轉。 在本篇博文中,您將…

vue3實現防抖、單頁面引入、全局引入、全局掛載

文章目錄 代碼實現單頁面引入全局引入使用 代碼實現 const debounce (fn: any, delay: number) > {let timer: any undefined;return (item: any) > {if (timer) clearTimeout(timer);timer setTimeout(() > fn(item), delay);} };export default debounce;單頁面…

Python + Selenium 處理瀏覽器Cookie

工作中遇到這么一個場景:自動化測試登錄的時候需要輸入動態驗證碼,由于某些原因,需要從一個已登錄的機器上,復制cookie過來,到自動化這邊繞過登錄。 瀏覽器的F12里復制出來的cookie內容是文本格式的: uui…

【第二講---初識SLAM】

SLAM簡介 視覺SLAM,主要指的是利用相機完成建圖和定位問題。如果傳感器是激光,那么就稱為激光SLAM。 定位(明白自身狀態(即位置))建圖(了解外在環境)。 視覺SLAM中使用的相機與常見…

VB+SQL銀行設備管理系統設計與實現

摘要 隨著銀行卡的普及,很多地方安裝了大量的存款機、取款機和POS機等銀行自助設備。銀行設備管理系統可以有效的記錄銀行設備的安裝和使用情況,規范對自助設備的管理,從而為用戶提供更加穩定和優質的服務。 本文介紹了銀行設備管理系統的設計和開發過程,詳細闡述了整個應…

Flink之Task解析

Flink之Task解析 對Flink的Task進行解析前,我們首先要清楚幾個角色TaskManager、Slot、Task、Subtask、TaskChain分別是什么 角色注釋TaskManager在Flink中TaskManager就是一個管理task的進程,每個節點只有一個TaskManagerSlotSlot就是TaskManager中的槽位,一個TaskManager中可…

數據結構單鏈表

單鏈表 1 鏈表的概念及結構 概念:鏈表是一種物理存儲結構上非連續、非順序的存儲結構,數據元素的邏輯順序是通過鏈表中的指針鏈 接次序實現的 。 在我們開始講鏈表之前,我們是寫了順序表,順序表就是類似一個數組的東西&#xff0…

上海虛擬展廳制作平臺怎么選,蛙色3DVR 助力行業發展

引言: 在數字化時代,虛擬展廳成為了企業宣傳的重要手段。而作為一家位于上海的實力平臺,上海蛙色3DVR憑借其卓越的功能和創新的技術,成為了企業展示和宣傳的首選。 一、虛擬展廳的優勢 虛擬展廳的崛起是指隨著科技的進步&#x…

36_windows環境debug Nginx 源碼-使用 VSCode 和WSL

文章目錄 配置 WSL編譯 NginxVSCode 安裝插件launch.json配置 WSL sudo apt-get -y install gcc cmake sudo apt-get -y install pcre sudo apt-get -y install libpcre3 libpcre3-dev sudo apt-get