梯度下降代碼

整體流程

數據預處理:標準化->加一列全為1的偏置項

訓練:梯度下降,將數學公式轉換成代碼

預測

模型代碼?

import numpy as np# 標準化函數:對特征做均值-方差標準化
# 返回標準化后的特征、新數據的均值和標準差,用于后續預測def standard(feats):new_feats = np.copy(feats).astype(float)mean = np.mean(new_feats, axis=0)std = np.std(new_feats, axis=0)std[std == 0] = 1new_feats = (new_feats - mean) / stdreturn new_feats, mean, stdclass LinearRegression:def __init__(self, data, labels):# 對訓練數據進行標準化new_data, mean, std = standard(data)# 存儲用于預測的均值和標準差self.mean = meanself.std = std# 樣本數 m 和 原始特征數 nm, n = new_data.shape# 在特征矩陣前加一列 1 作為偏置項X = np.hstack((np.ones((m, 1)), new_data))  # shape (m, n+1)self.X = X                # 訓練特征 (m, n+1)self.y = labels           # 訓練標簽 (m, 1)self.m = m                # 樣本數self.n = n + 1            # 特征數(含偏置)# 初始化參數 thetaself.theta = np.zeros((self.n, 1))def train(self, alpha, num_iterations=500):"""執行梯度下降:param alpha: 學習率:param num_iterations: 迭代次數:return: 學習到的 theta 和每次迭代的損失歷史"""cost_history = []for _ in range(num_iterations):self.gradient_step(alpha)cost_history.append(self.cost_function())return self.theta, cost_historydef gradient_step(self, alpha):# 計算預測值predictions = self.X.dot(self.theta)          # shape (m,1)# 計算誤差delta = predictions - self.y                  # shape (m,1)# 計算梯度并更新 thetagrad = (self.X.T.dot(delta)) / self.m         # shape (n+1,1)self.theta -= alpha * graddef cost_function(self):# 計算當前 theta 下的損失delta = self.X.dot(self.theta) - self.y       # shape (m,1)return float((delta.T.dot(delta)) / (2 * self.m))def predict(self, data):"""對新數據進行預測:param data: 新數據,shape (m_new, n):return: 預測值,shape (m_new, 1)"""# 確保輸入為二維數組data = np.array(data, ndmin=2)# 使用訓練時的均值和標準差進行標準化new_data = (data - self.mean) / self.std# 加入偏置項m_new = new_data.shape[0]X_new = np.hstack((np.ones((m_new, 1)), new_data))# 返回預測結果return X_new.dot(self.theta)

測試代碼

import numpy as np
import pandas as pd
import matplotlib.pyplot as pltfrom linear_regression import LinearRegression
data = pd.read_csv('../data/world-happiness-report-2017.csv')train_data = data.sample(frac = 0.8)
test_data = data.drop(train_data.index)
input_param_name = 'Economy..GDP.per.Capita.'
output_param_name = 'Happiness.Score'
# 取出城市gdp的值和對應的幸福指數
x_train = train_data[[input_param_name]].values
y_train = train_data[[output_param_name]].values
x_test = test_data[input_param_name].values
y_test = test_data[output_param_name].valuesnum_iterations = 500
learning_rate = 0.01
# 訓練
# x_train是gdp值,y_train是幸福指數
linear_regression = LinearRegression(x_train,y_train)
# 梯度下降比率,訓練輪數
(theta,cost_history) = linear_regression.train(learning_rate,num_iterations)print ('開始時的損失:',cost_history[0])
print ('訓練后的損失:',cost_history[-1])plt.plot(range(num_iterations),cost_history)
plt.xlabel('Iter')
plt.ylabel('cost')
plt.title('GD')
plt.show()predictions_num = 100
# 最小值,最大值,多少個等間隔的數,然后做成列向量的形式
x_predictions = np.linspace(x_train.min(),x_train.max(),predictions_num).reshape(predictions_num,1)y_predictions = linear_regression.predict(x_predictions)plt.scatter(x_train,y_train,label='Train data')
plt.scatter(x_test,y_test,label='test data')
plt.plot(x_predictions,y_predictions,'r',label = 'Prediction')
plt.xlabel(input_param_name)
plt.ylabel(output_param_name)
plt.title('Happy')
plt.legend()
plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/79466.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/79466.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/79466.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

RAG 實戰|用 StarRocks + DeepSeek 構建智能問答與企業知識庫

文章作者: 石強,鏡舟科技解決方案架構師 趙恒,StarRocks TSC Member 👉 加入 StarRocks x AI 技術討論社區 https://mp.weixin.qq.com/s/61WKxjHiB-pIwdItbRPnPA RAG 和向量索引簡介 RAG(Retrieval-Augmented Gen…

從零開始學A2A一:A2A 協議的高級應用與優化

A2A 協議的高級應用與優化 學習目標 掌握 A2A 高級功能 理解多用戶支持機制掌握長期任務管理方法學習服務性能優化技巧 理解與 MCP 的差異 分析多智能體場景下的優勢掌握不同場景的選擇策略 第一部分:多用戶支持機制 1. 用戶隔離架構 #mermaid-svg-Awx5UVYtqOF…

【C++】入門基礎【上】

目錄 一、C的發展歷史二、C學習書籍推薦三、C的第一個程序1、命名空間namespace2、命名空間的使用3、頭文件<iostream>是干什么的&#xff1f; 個人主頁<—請點擊 C專欄<—請點擊 一、C的發展歷史 C的起源可以追溯到1979年&#xff0c;當時Bjarne Stroustrup(本…

1panel第三方應用商店(本地商店)配置和使用

文章目錄 引言資源網站實戰操作說明 引言 1Panel 提供了一個應用提交開發環境&#xff0c;開發者可以通過提交應用的方式將自己的應用推送到 1Panel 的應用商店中&#xff0c;供其他用戶使用。由此衍生了一種本地應用商店的概念&#xff0c;用戶可以自行編寫應用配置并上傳到自…

Evidential Deep Learning和證據理論教材的區別(主要是概念)

最近終于徹底搞懂了Evidential Deep Learning&#xff0c;之前有很多看不是特別明白的地方&#xff0c;原來是和證據理論教材&#xff08;是的&#xff0c;不只是國內老師寫的&#xff0c;和國外的老師寫的教材出入也比較大&#xff09;的說法有很多不一樣&#xff0c;所以特地…

text-decoration: underline;不生效

必須得紀念一下&#xff0c;在給文本加下劃線時&#xff0c;發現在win電腦不生效&#xff0c;部分mac也不生效&#xff0c;只有個別的mac生效了&#xff0c;思考了以下幾種方面&#xff1a; 1.兼容性問題&#xff1f; 因為是electron項目&#xff0c;不存在瀏覽器兼容性問題&…

VUE SSR(服務端渲染)

&#x1f916; 作者簡介&#xff1a;水煮白菜王&#xff0c;一位前端勸退師 &#x1f47b; &#x1f440; 文章專欄&#xff1a; 前端專欄 &#xff0c;記錄一下平時在博客寫作中&#xff0c;總結出的一些開發技巧和知識歸納總結?。 感謝支持&#x1f495;&#x1f495;&#…

ARCGIS國土超級工具集1.5更新說明

ARCGIS國土超級工具集V1.5版本更新說明&#xff1a;因作者近段時間工作比較忙及正在編寫ARCGISPro國土超級工具集&#xff08;截圖附后&#xff09;的原因&#xff0c;故本次更新為小更新&#xff08;沒有增加新功能&#xff0c;只更新了已有的工具&#xff09;。本次更新主要修…

劉鑫煒履新共工新聞社新媒體研究院院長,賦能媒體融合新征程

2025年4月18日&#xff0c;大灣區經濟網戰略媒體共工新聞社正式對外宣布一項重要人事任命&#xff1a;聘任螞蟻全媒體總編劉鑫煒為新媒體研究院第一任院長。這一舉措&#xff0c;無疑是對劉鑫煒在新媒體領域卓越專業能力與突出行業貢獻的又一次高度認可&#xff0c;也預示著共工…

java基礎從入門到上手(九):Java - List、Set、Map

一、List集合 List 是一種用于存儲有序元素的集合接口&#xff0c;它是 java.util 包中的一部分&#xff0c;并且繼承自 Collection 接口。List 接口提供了多種方法&#xff0c;用于按索引操作元素&#xff0c;允許元素重復&#xff0c;并且保持插入順序。常用的 List 實現類包…

UWP發展歷程

通用Windows平臺(UWP)發展歷程 引言 通用Windows平臺(Universal Windows Platform, UWP)是微軟為實現"一次編寫&#xff0c;處處運行"的愿景而打造的現代應用程序平臺。作為微軟統一Windows生態系統的核心戰略組成部分&#xff0c;UWP代表了從傳統Win32應用向現代應…

git忽略已跟蹤的文件/指定文件

在項目開發中&#xff0c;有時候我們并不需要git跟蹤所有文件&#xff0c;而是需要忽略掉某些指定的文件或文件夾&#xff0c;怎么操作呢&#xff1f;我們分兩種情況討論&#xff1a; 1. 要忽略的文件之前并未被git跟蹤 這種情況常用的方法是在項目的根目錄下創建和編輯.gitig…

AI 組件庫是什么?如何影響UI的開發?

AI組件庫是基于人工智能技術構建的、面向用戶界面&#xff08;UI&#xff09;開發的預制模塊集合。它們結合了傳統UI組件&#xff08;如按鈕、表單、圖表&#xff09;與AI能力&#xff08;如機器學習、自然語言處理、計算機視覺&#xff09;&#xff0c;旨在簡化開發流程并增強…

【Win】 cmd 執行curl命令時,輸出 ‘命令管道位置 1 的 cmdlet Invoke-WebRequest 請為以下參數提供值: Uri: ’ ?

1.原因&#xff1a; 有一個名為 Invoke-WebRequest 的 CmdLet&#xff0c;其別名為 curl。因此&#xff0c;當您執行此命令時&#xff0c;它會嘗試使用 Invoke-WebRequest&#xff0c;而不是使用 curl。 2.解決辦法 在cmd中輸入如下命令刪除這個curl別名&#xff1a; Remov…

UE5 UE循環體里怎么寫延遲

注&#xff1a;需要修改UE循環藍圖節點或者自己新建個藍圖宏庫把UE循環節點的原來代碼粘貼進去修改。 一、For Loop With Delay 二、For Each Loop With Delay 示例使用&#xff1a; 標注參考出處&#xff1a;分享UE5自制Loop with delay宏&#xff0c;在loop循環中添加執行…

IP檢測工具“ipjiance”

目錄 IP質量檢測 應用場景 對網絡安全的貢獻 對網絡管理的幫助 對用戶決策的輔助作用 IP質量檢測 檢測IP的網絡提供商&#xff1a;通過ASN&#xff08;自治系統編號&#xff09;識別IP地址所屬的網絡運營商&#xff0c;例如電信、移動、聯通等。 識別網絡類型&#xff1…

[工具]Java xml 轉 Json

[工具]Java xml 轉 Json 依賴 <!-- https://mvnrepository.com/artifact/cn.hutool/hutool-all --> <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.8.37</version> </dependen…

vue3 傳參 傳入變量名

背景&#xff1a; 需求是&#xff1a;在vue框架中&#xff0c;接口傳參我們需要穿“變量名”&#xff0c;而不是字符串 通俗點說法是&#xff1a;在網絡接口請求的時候&#xff0c;要傳屬性名 效果展示&#xff1a; vue2核心代碼&#xff1a; this[_keyParam] vue3核心代碼&…

spring響應式編程系列:總體流程

目錄 示例 程序流程 just subscribe new LambdaMonoSubscriber ???????MonoJust.subscribe ???????new Operators.ScalarSubscription ???????onSubscribe ???????request ???????onNext 時序圖 類圖 數據發布者 MonoJust …

基于slimBOXtv 9.16 V2-晶晨S905L3A/ S905L3AB-Mod ATV-Android9.0-線刷通刷固件包

基于slimBOXtv 9.16 V2-晶晨S905L3A&#xff0f; S905L3AB-Mod ATV-Android9.0-線刷通刷固件包&#xff0c;基于SlimBOXtv 9 修改而來&#xff0c;貼近于原生ATV&#xff0c;僅支持晶晨S905L3A&#xff0f; S905L3AB芯片刷機。 適用型號&#xff1a;M401A、CM311-1a、CM311-1s…