梯度下降原理及Python實現

梯度下降算法是一個很基本的算法,在機器學習和優化中有著非常重要的作用,本文首先介紹了梯度下降的基本概念,然后使用python實現了一個基本的梯度下降算法。梯度下降有很多的變種,本文只介紹最基礎的梯度下降,也就是批梯度下降。


實際應用例子就不詳細說了,網上關于梯度下降的應用例子很多,最多的就是NG課上的預測房價例子:?
假設有一個房屋銷售的數據如下:

面積(m^2) 銷售價錢(萬元)

面積(m^2) 銷售價錢(萬元)
123 250
150 320
87 180

根據上面的房價我們可以做這樣一個圖:

這里寫圖片描述

于是我們的目標就是去擬合這個圖,使得新的樣本數據進來以后我們可以方便進行預測:?
這里寫圖片描述

對于最基本的線性回歸問題,公式如下:?
這里寫圖片描述?
x是自變量,比如說房子面積。θ是權重參數,也就是我們需要去梯度下降求解的具體值。

在這兒,我們需要引入損失函數(Loss function 或者叫 cost function),目的是為了在梯度下降時用來衡量我們更新后的參數是否是向著正確的方向前進,如圖損失函數(m表示訓練集樣本數量):?
這里寫圖片描述?
下圖直觀顯示了我們梯度下降的方向,就是希望從最高處一直下降到最低出:?
這里寫圖片描述

梯度下降更新權重參數的過程中我們需要對損失函數求偏導數:?
這里寫圖片描述?
求完偏導數以后就可以進行參數更新了:?
這里寫圖片描述?
偽代碼如圖所示:?
這里寫圖片描述

好了,下面到了代碼實現環節,我們用Python來實現一個梯度下降算法,求解:

y=2x1+x2+3
,也就是求解:
y=ax1+bx2+c
中的a,b,c三個參數 。

下面是代碼:

import numpy as np
import matplotlib.pyplot as plt
#y=2 * (x1) + (x2) + 3 
rate = 0.001
x_train = np.array([    [1, 2],    [2, 1],    [2, 3],    [3, 5],    [1, 3],    [4, 2],    [7, 3],    [4, 5],    [11, 3],    [8, 7]    ])
y_train = np.array([7, 8, 10, 14, 8, 13, 20, 16, 28, 26])
x_test  = np.array([    [1, 4],    [2, 2],    [2, 5],    [5, 3],    [1, 5],    [4, 1]    ])a = np.random.normal()
b = np.random.normal()
c = np.random.normal()def h(x):return a*x[0]+b*x[1]+cfor i in range(10000):sum_a=0sum_b=0sum_c=0for x, y in zip(x_train, y_train):sum_a = sum_a + rate*(y-h(x))*x[0]sum_b = sum_b + rate*(y-h(x))*x[1]sum_c = sum_c + rate*(y-h(x))a = a + sum_ab = b + sum_bc = c + sum_cplt.plot([h(xi) for xi in x_test])print(a)
print(b)
print(c)result=[h(xi) for xi in x_train]
print(result)result=[h(xi) for xi in x_test]
print(result)plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39

x_train是訓練集x,y_train是訓練集y, x_test是測試集x,運行后得到如下的圖,圖片顯示了算法對于測試集y的預測在每一輪迭代中是如何變化的:?
這里寫圖片描述

我們可以看到,線段是在逐漸逼近的,訓練數據越多,迭代次數越多就越逼近真實值。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/387283.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/387283.shtml
英文地址,請注明出處:http://en.pswp.cn/news/387283.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

dagger2的初次使用

一、使用前準備 1、打開app的build.gradle文件: 頂部停用apt插件: //添加如下代碼,應用apt插件 apply plugin: com.neenbedankt.android-apt dependencies中添加依賴: //Dagger2compile com.google.dagger:dagger:2.4apt com.goog…

Storm教程2安裝部署

Storm 安裝部署 部署Storm集群需要依次完成的安裝步驟: 1.安裝jdk6及以上版本;   2. 搭建Zookeeper集群;   3. 安裝Storm依賴庫;   4. 下載并解壓Storm發布版本;   5. 修改storm.yaml配置文件;   6…

matplotlib一些常用知識點的整理,

本文作為學習過程中對matplotlib一些常用知識點的整理,方便查找。 強烈推薦ipython 無論你工作在什么項目上,IPython都是值得推薦的。利用ipython --pylab,可以進入PyLab模式,已經導入了matplotlib庫與相關軟件包(例如…

JAVA課程09

package 月份輸出;import java.util.*;public class 月份輸出 {public static void main(String[] args) {// TODO Auto-generated method stubScanner sc new Scanner(System.in);int s sc.nextInt();String a[] {"January","February","March&q…

Storm教程3編程接口

Spouts Spout是Stream的消息產生源,Spout組件的實現可以通過繼承BaseRichSpout類或者其他Spout類來完成,也可以通過實現IRichSpout接口來實現。 需要根據情況實現Spout類中重要的幾個方法有: open方法 當一個Task被初始化的時候會調用此…

梳理操作系統概論

1、用一張圖總結操作系統的結構、功能特征、采用的技術和提供服務方式等。 2、用一張圖描述CPU的工作原理。 3、用一張圖描述系統程序與應用程序、特權指令與非特權指令、CPU狀態、PSW及中斷是如何協同工作的? 轉載于:https://www.cnblogs.com/ljgljg/p/10503190.ht…

機器學習01簡介

Machine Learning 是人工智能的核心,主要使用歸納、綜合而不是演繹。 讓計算機模擬人類行為,以獲取新的知識或技能 重新組織已有的知識結構使之不斷改善自身性能 一個程序能從經驗 E 中學習,解決任務 T,達到性能度量值P&#xf…

位置指紋法的實現(KNN)

基本原理 位置指紋法可以看作是分類或回歸問題(特征是RSS向量,標簽是位置),監督式機器學習方法可以從數據中訓練出一個從特征到標簽的映射關系模型。kNN是一種很簡單的監督式機器學習算法,可以用來做分類或回歸。 對于…

室內定位系列 ——WiFi位置指紋(譯)

摘要 GPS難以解決室內環境下的一些定位問題,大部分室內環境下都存在WiFi,因此利用WiFi進行定位無需額外部署硬件設備,是一個非常節省成本的方法。然而WiFi并不是專門為定位而設計的,傳統的基于時間和角度的定位方法并不適用于WiFi…

機器學習02線性回歸、多項式回歸、正規方程

單變量線性回歸(Linear Regression with One Variable) 預測器表達式: 選擇合適的參數(parameters)θ0 和 θ1,其決定了直線相對于訓練集的準確程度。 建模誤差(modeling error)&a…

最大乘積

給定一個無序數組,包含正數、負數和0,要求從中找出3個數的乘積,使得乘積最大,要求時間復雜度:O(n),空間復雜度:O(1) def solve():n input()a input().split()for i in range(len(a)):a[i] in…

機器學習03Logistic回歸

邏輯回歸 (Logistic Regression) 目前最流行,使用最廣泛的一種學習算法。 分類問題,要預測的變量 y 是離散的值。 邏輯回歸算法的性質是:它的輸出值永遠在 0 到 1 之間。 邏輯回歸模型的假設是: 其中&a…

基礎架構系列匯總

為了方便查找,把基礎架構系統文章按時間正序整理了一下,記錄如下: 1. 基礎架構之日志管理平臺搭建及java&net使用 2. 基礎架構之日志管理平臺及釘釘&郵件告警通知 3. 基礎架構之分布式配置中心 4. 基礎架構之分布式任務平臺 5. 基礎架…

CNN理解比較好的文章

什么是卷積神經網絡?為什么它們很重要? 卷積神經網絡(ConvNets 或者 CNNs)屬于神經網絡的范疇,已經在諸如圖像識別和分類的領域證明了其高效的能力。卷積神經網絡可以成功識別人臉、物體和交通信號,從而為機…

Windows 安裝Angular CLI

1、安裝nvm npm cnpm nrm(onenote筆記上有記錄) 參考:https://blog.csdn.net/tyro_java/article/details/51232458 提示:如果發現配置完后,出現類似“npm不是內部命令……”等信息。 可采取如下措施進行解決—— 檢查環…

機器學習04正則化

正則化(Regularization) 過擬合問題(Overfitting): 如果有非常多的特征,通過學習得到的假設可能能夠非常好地適應訓練集 :代價函數可能幾乎為 0), 但是可能會不能推廣到…

Adaboost算法

概述 一句話概述Adaboost算法的話就是:把多個簡單的分類器結合起來形成個復雜的分類器。也就是“三個臭皮匠頂一個諸葛亮”的道理。 可能僅看上面這句話還沒什么概念,那下面我引用個例子。 如下圖所示: 在D1這個數據集中有兩類數據“”和“-”…

Codeforces 408D Long Path (DP)

題目: One day, little Vasya found himself in a maze consisting of (n??1) rooms, numbered from 1 to (n??1). Initially, Vasya is at the first room and to get out of the maze, he needs to get to the (n??1)-th one. The maze is organized as fol…

機器學習05神經網絡--表示

神經網絡:表示(Neural Networks: Representation) 如今的神經網絡對于許多應用來說是最先進的技術。 對于現代機器學習應用,它是最有效的技術方法。 神經網絡模型是許多邏輯單元按照不同層級組織起來的網絡, 每一層…

邏輯回歸(Logistic Regression, LR)又稱為邏輯回歸分析,是分類和預測算法中的一種。通過歷史數據的表現對未來結果發生的概率進行預測。例如,我們可以將購買的概率設置為因變量,將用戶的

邏輯回歸(Logistic Regression, LR)又稱為邏輯回歸分析,是分類和預測算法中的一種。通過歷史數據的表現對未來結果發生的概率進行預測。例如,我們可以將購買的概率設置為因變量,將用戶的特征屬性,例如性別,年齡&#x…