【單層神經網絡】基于MXNet的線性回歸實現(底層實現)

寫在前面

  1. 剛開始先從普通的尋優算法開始,熟悉一下學習訓練過程
  2. 下面將使用梯度下降法尋優,但這大概只能是局部最優,它并不是一個十分優秀的尋優算法

整體流程

  1. 生成訓練數據集(實際工程中,需要從實際對象身上采集數據)
  2. 確定模型及其參數(輸入輸出個數、階次,偏置等)
  3. 確定學習方式(損失函數、優化算法,學習率,訓練次數,終止條件等)
  4. 讀取數據集(不同的讀取方式會影響最終的訓練效果)
  5. 訓練模型

完整程序及注釋

from IPython import display
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random'''
獲取(生成)訓練集
'''
input_num = 2				# 輸入個數
examples_num = 1000			# 生成樣本個數
# 確定真實模型參數
real_W = [10.9, -8.7]		
real_bias = 6.5	features = nd.random.normal(scale=1, shape=(examples_num, input_num))       # 標準差=1,均值缺省=0
labels = real_W[0]*features[:,0] + real_W[1]*features[:,1] + real_bias		# 根據特征和參數生成對應標簽
labels_noise = labels + nd.random.normal(scale=0.1, shape=labels.shape)		# 為標簽附加噪聲,模擬真實情況# 繪制標簽和特征的散點圖(矢量圖)
# def use_svg_display():
#     display.set_matplotlib_formats('svg')# def set_figure_size(figsize=(3.5,2.5)):
#     use_svg_display()
#     plt.rcParams['figure.figsize'] = figsize# set_figure_size()
# plt.scatter(features[:,0].asnumpy(), labels_noise.asnumpy(), 1)
# plt.scatter(features[:,1].asnumpy(), labels_noise.asnumpy(), 1)
# plt.show()# 創建一個迭代器(確定從數據集獲取數據的方式)
def data_iter(batch_size, features, labels):num = len(features)indices = list(range(num))                                  # 生成索引數組random.shuffle(indices)                                     # 打亂indices# 該遍歷方式同時確保了隨機采樣和無遺漏for i in range(0, num, batch_size):j = nd.array(indices[i: min(i+batch_size, num)])        # 對indices從i開始取,取batch_size個樣本,并轉換為列表yield features.take(j), labels.take(j)                  # take方法使用索引數組,從features和labels提取所需數據"""
訓練的基礎準備
"""
# 聲明訓練變量,并賦高斯隨機初始值
w = nd.random.normal(scale=0.01, shape=(input_num))
b = nd.zeros(shape=(1,))
# b = nd.zeros(1)       # 不同寫法,等價于上面的
w.attach_grad()         # 為需要迭代的參數申請求梯度空間
b.attach_grad()# 定義模型
def linreg(X, w, b):return nd.dot(X,w)+b# 定義損失函數
def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) **2 /2# 定義尋優算法
def sgd(params, learning_rate, batch_size):for param in params:# 新參數 = 原參數 - 學習率*當前批量的參數梯度/當前批量的大小param[:] = param - learning_rate * param.grad / batch_size# 確定超參數和學習方式
lr = 0.03
num_iterations = 5
net = linreg				# 目標模型
loss = squared_loss			# 代價函數(損失函數)
batch_size = 10				# 每次隨機小批量的大小'''
開始訓練
'''
for iteration in range(num_iterations):		# 確定迭代次數for x, y in data_iter(batch_size, features, labels):with autograd.record():l = loss(net(x,w,b), y)			# 求當前小批量的總損失l.backward()						# 求梯度sgd([w,b], lr, batch_size)			# 梯度更新參數train_l = loss(net(features,w,b), labels)print("iteration %d, loss %f" % (iteration+1, train_l.mean().asnumpy()))
# 打印比較真實參數和訓練得到的參數
print("real_w " + str(real_W) + "\n train_w " + str(w))
print("real_w " + str(real_bias) + "\n train_b " + str(b))

具體程序解釋

param[:] = param - learning_rate * param.grad / batch_size
將batch_size與參數調整相關聯的原因,是為了使得每次更新的步長不受批次大小的影響
具體來說,當計算一批數據的損失函數的梯度時,實際上是將這批數據中每個樣本對損失函數的貢獻累加起來。這意味著如果批次較大,梯度的模也會相應增大
故更新權值時,使用的是數據集的平均梯度,而不是總和

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/894648.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/894648.shtml
英文地址,請注明出處:http://en.pswp.cn/news/894648.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

本地快速部署DeepSeek-R1模型——2025新年賀歲

一晃年初六了,春節長假余額馬上歸零了。今天下午在我的電腦上成功部署了DeepSeek-R1模型,抽個時間和大家簡單分享一下過程: 概述 DeepSeek模型 是一家由中國知名量化私募巨頭幻方量化創立的人工智能公司,致力于開發高效、高性能…

C++11詳解(一) -- 列表初始化,右值引用和移動語義

文章目錄 1.列表初始化1.1 C98傳統的{}1.2 C11中的{}1.3 C11中的std::initializer_list 2.右值引用和移動語義2.1左值和右值2.2左值引用和右值引用2.3 引用延長生命周期2.4左值和右值的參數匹配問題2.5右值引用和移動語義的使用場景2.5.1左值引用主要使用場景2.5.2移動構造和移…

在K8S中,pending狀態一般由什么原因導致的?

在Kubernetes中,資源或Pod處于Pending狀態可能有多種原因引起。以下是一些常見的原因和詳細解釋: 資源不足 概述:當集群中的資源不足以滿足Pod或服務的需求時,它們可能會被至于Pending狀態。這通常涉及到CPU、內存、存儲或其他資…

手寫MVVM框架-構建虛擬dom樹

MVVM的核心之一就是虛擬dom樹,我們這一章節就先構建一個虛擬dom樹 首先我們需要創建一個VNode的類 // 當前類的位置是src/vnode/index.js export default class VNode{constructor(tag, // 標簽名稱(英文大寫)ele, // 對應真實節點children,…

linux內核源代碼中__init的作用?

在 Linux 內核源代碼中,__init是一個特殊的宏,用于標記在內核初始化階段使用的變量或函數。這個宏的作用是告訴內核編譯器和鏈接器,被標記的變量或函數只在內核的初始化階段使用,在系統啟動完成后就不再需要了。因此,這…

【大數據技術】教程03:本機PyCharm遠程連接虛擬機Python

本機PyCharm遠程連接虛擬機Python 注意:本文需要使用PyCharm專業版。 pycharm-professional-2024.1.4VMware Workstation Pro 16CentOS-Stream-10-latest-x86_64-dvd1.iso寫在前面 本文主要介紹如何使用本地PyCharm遠程連接虛擬機,運行Python腳本,提高編程效率。 注意: …

pytorch實現門控循環單元 (GRU)

人工智能例子匯總:AI常見的算法和例子-CSDN博客 特性GRULSTM計算效率更快,參數更少相對較慢,參數更多結構復雜度只有兩個門(更新門和重置門)三個門(輸入門、遺忘門、輸出門)處理長時依賴一般適…

PAT甲級1032、sharing

題目 To store English words, one method is to use linked lists and store a word letter by letter. To save some space, we may let the words share the same sublist if they share the same suffix. For example, loading and being are stored as showed in Figure …

最小生成樹kruskal算法

文章目錄 kruskal算法的思想模板 kruskal算法的思想 模板 #include <bits/stdc.h> #define lowbit(x) ((x)&(-x)) #define int long long #define endl \n #define PII pair<int,int> #define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0); using na…

為何在Kubernetes容器中以root身份運行存在風險?

作者&#xff1a;馬辛瓦西奧內克&#xff08;Marcin Wasiucionek&#xff09; 引言 在Kubernetes安全領域&#xff0c;一個常見的建議是讓容器以非root用戶身份運行。但是&#xff0c;在容器中以root身份運行&#xff0c;實際會帶來哪些安全隱患呢&#xff1f;在Docker鏡像和…

js --- 獲取時間戳

介紹 使用js獲取當前時間戳 語法 Date.now()

ConcurrentHashMap線程安全:分段鎖 到 synchronized + CAS

專欄系列文章地址&#xff1a;https://blog.csdn.net/qq_26437925/article/details/145290162 本文目標&#xff1a; 理解ConcurrentHashMap為什么線程安全&#xff1b;ConcurrentHashMap的具體細節還需要進一步研究 目錄 ConcurrentHashMap介紹JDK7的分段鎖實現JDK8的synchr…

Vue和Java使用AES加密傳輸

背景&#xff1a;Vue對參數進行加密&#xff0c;對響應進行解密。Java對參數進行解密&#xff0c;對響應進行解密。不攔截文件上傳類請求、GET請求。 【1】前端配置 安裝crypto npm install crypto-js編寫加解密工具類encrypt.js import CryptoJS from crypto-jsconst KEY …

開發板目錄 /usr/lib/fonts/ 中的字體文件 msyh.ttc 的介紹【微軟雅黑(Microsoft YaHei)】

本文是博文 https://blog.csdn.net/wenhao_ir/article/details/145433648 的延伸擴展。 本文是博文 https://blog.csdn.net/wenhao_ir/article/details/145433648 的延伸擴展。 問&#xff1a;運行 ls /usr/lib/fonts/ 發現有一個名叫 msyh.ttc 的字體文件&#xff0c;能介紹…

[ESP32:Vscode+PlatformIO]新建工程 常用配置與設置

2025-1-29 一、新建工程 選擇一個要創建工程文件夾的地方&#xff0c;在空白處鼠標右鍵選擇通過Code打開 打開Vscode&#xff0c;點擊platformIO圖標&#xff0c;選擇PIO Home下的open&#xff0c;最后點擊new project 按照下圖進行設置 第一個是工程文件夾的名稱 第二個是…

述評:如果抗拒特朗普的“普征關稅”

題 記 美國總統特朗普宣布對美國三大貿易夥伴——中國、墨西哥和加拿大&#xff0c;分別征收10%、25%的關稅。 他威脅說&#xff0c;如果這三個國家不解決他對非法移民和毒品走私的擔憂&#xff0c;他就要征收進口稅。 去年&#xff0c;中國、墨西哥和加拿大這三個國家&#…

九. Redis 持久化-AOF(詳細講解說明,一個配置一個說明分析,步步講解到位 2)

九. Redis 持久化-AOF(詳細講解說明&#xff0c;一個配置一個說明分析&#xff0c;步步講解到位 2) 文章目錄 九. Redis 持久化-AOF(詳細講解說明&#xff0c;一個配置一個說明分析&#xff0c;步步講解到位 2)1. Redis 持久化 AOF 概述2. AOF 持久化流程3. AOF 的配置4. AOF 啟…

C++11新特性之long long超長整形

1.介紹 long long 超長整形是C11標準新添加的&#xff0c;用于表示更大范圍整數的類型。 2.用法 占用空間&#xff1a;至少64位&#xff08;8個字節&#xff09;。 對于有符號long long 整形&#xff0c;后綴用“LL”或“II”標識。例如&#xff0c;“10LL”就表示有符號超長整…

瀏覽器查詢所有的存儲信息,以及清除的語法

要在瀏覽器的控制臺中查看所有的存儲&#xff08;例如 localStorage、sessionStorage 和 cookies&#xff09;&#xff0c;你可以使用瀏覽器開發者工具的 "Application" 標簽頁。以下是操作步驟&#xff1a; 1. 打開開發者工具 在 Chrome 或 Edge 瀏覽器中&#xf…

基于Springboot框架的學術期刊遴選服務-項目演示

項目介紹 本課程演示的是一款 基于Javaweb的水果超市管理系統&#xff0c;主要針對計算機相關專業的正在做畢設的學生與需要項目實戰練習的 Java 學習者。 1.包含&#xff1a;項目源碼、項目文檔、數據庫腳本、軟件工具等所有資料 2.帶你從零開始部署運行本套系統 3.該項目附…