深度學習——回歸實戰

線性回歸:

線性:自變量和應變量之間是線性關系,如:y = wx +b

回歸:擬合一條曲線,使真實值和擬合值差距盡可能小

目標:求解參數w和b? ? ? ? ?所用算法:梯度下降算法

梯度下降:向著梯度方向(下降最快的方向)走一步,不停迭代

梯度下降過程:

訓練循環(核心步驟):

  • 前向傳播
    • 將訓練集中的一批數據(一個批次)輸入到模型中,通過模型的各層計算得到輸出。
    • 這個過程中,數據按照模型的架構順序依次通過各層,每層根據其權重和偏置對數據進行計算,如在全連接層中,計算是通過矩陣乘法和加法實現的。
  • 計算損失
    • 使用定義好的損失函數,計算模型預測輸出與該批次數據真實標簽之間的損失。
    • 損失值反映了當前模型在這一批次數據上的預測誤差大小。
  • 反向傳播
    • 根據計算得到的損失,通過鏈式法則從最后一層開始,逐層計算損失對每個模型參數(權重和偏置)的梯度。
    • 反向傳播算法使得模型能夠知道每個參數對損失的影響程度,從而為參數更新提供依據。
  • 更新參數
    • 使用選擇的優化器,根據計算得到的梯度更新模型的參數。例如,在使用 SGD 優化器時,參數更新公式為:參數 = 參數 - 學習率 * 梯度。
    • 更新后的參數將用于下一次迭代的前向傳播,通過不斷重復這個過程,模型的參數逐漸調整,使得損失函數不斷減小。
      import torch #深度學習框架
      import matplotlib.pyplot as plt #畫圖
      import random #隨機def create_data(w, b, data_num):   #生成數據x = torch.normal(0, 1, (data_num, len(w)))   #平均數為0,方差為1,長度為data_num,寬度為len(w)y = torch.matmul(x, w) + b #通過矩陣乘法將輸入數據x與權重w相乘,然后加上偏置項b,生成新的輸出ynoise = torch.normal(0, 0.01, y.shape)   #噪聲要加到y上y+= noise      #模擬真實數據的不確定性、防止模型過擬合return x, ynum = 500   # 生成的數據數量true_w = torch.tensor([8.1, 2, 2, 4])  # 真實的權重
      true_b = torch.tensor(1.1)   # 真實的偏置X, Y = create_data(true_w, true_b, num)  # 生成數據plt.scatter(X[:, 0], Y, 1)  #對x張量進行切片,選擇所有行、第一列
      plt.show()def data_provider(data, label, batchsize):   #每次訪問這個函數,就能提供一批數據,傳入參數依次為:數據、標簽、步長length = len(label)   # 獲取標簽數據的長度,由于數據和標簽一一對應,以此代表整個數據集的長度indices = list(range(length))  #  創建一個從0到length - 1的索引列表,每個索引對應數據集中的一個樣本,用于后續操作random.shuffle(indices)  # 對索引列表進行隨機打亂,確保每次取數據批次時是隨機順序,避免數據順序依賴,增強模型訓練效果# 按照指定的批量大小batchsize遍歷整個數據集,每次取出一個批次的數據范圍for each in range(0, length, batchsize):get_indices = indices[each: each+batchsize]  # 從打亂后的索引列表中取出當前批次對應的索引范圍get_data = data[get_indices]   # 根據取出的索引范圍從數據張量data中獲取當前批次的數據get_label = label[get_indices]  # 根據取出的索引范圍從標簽張量label中獲取當前批次對應的標簽yield get_data, get_label  # 使用yield關鍵字返回當前批次的數據和標簽,使函數成為生成器,下次調用繼續返回下一批次batchsize = 16  #步長設置為16
      # for batch_x, batch_y in data_provider(X, Y, batchsize):
      #     print(batch_x, batch_y)
      #     #break# 定義函數fun,用于根據輸入數據x、權重w和偏置b進行線性變換計算,得到預測輸出
      def fun(x, w, b):pred_y = torch.matmul(x, w) + b   # 使用torch.matmul對輸入數據x和權重w進行矩陣乘法運算,然后加上偏置b,得到預測輸出pred_yreturn pred_y# 定義函數maeloss,用于計算預測值pre_y和真實值y之間的平均絕對誤差(MAE)損失
      def maeloss(pre_y, y):return torch.sum(abs(pre_y-y))/len(y)  # 先計算預測值和真實值之間差值的絕對值,再對所有差值的絕對值求和,最后除以數據數量len(y),得到平均絕對誤差作為損失值并返回# 定義函數sgd,實現隨機梯度下降算法,用于更新模型參數
      def sgd(paras, lr):          #隨機梯度下降,更新參數# 使用torch.no_grad()上下文管理器,在這個范圍內的操作不會進行梯度計算,因為參數更新階段不需要對更新操作本身計算梯度with torch.no_grad():  #屬于這句代碼的部分,不計算梯度for para in paras:# 遍歷要更新的參數列表paras中的每一個參數# 根據隨機梯度下降算法規則,將當前參數para減去其梯度para.grad與學習率lr的乘積,實現參數更新(注意要用 -= 操作符進行原位更新)para -= para.grad* lr  #不能寫成  para = para - para.grad*lrpara.grad.zero_()     #使用過的梯度,歸0,避免下一次迭代時梯度累積,導致參數更新錯誤lr = 0.03  #學習率
      w_0 = torch.normal(0, 0.01, true_w.shape, requires_grad=True)   # 使用torch.normal函數按照正態分布來初始化權重參數w_0,其中均值為0,方差為0.01,形狀與之前定義的真實權重true_w保持一致,并且設置requires_grad為True,意味著這個參數在后續的計算中需要跟蹤計算梯度,以便進行自動求導來更新它
      b_0 = torch.tensor(0.01, requires_grad=True)  # 初始化偏置參數b_0,將其設置為值是0.01的標量張量,同時設置requires_grad為True,這樣該參數就能參與到梯度計算以及后續的參數更新過程中
      print(w_0, b_0)epochs  = 50#訓練多少輪# 按訓練輪數epochs循環,每輪訓練模型
      for epoch in range(epochs):data_loss = 0   # 初始化本輪累計損失為0for batch_x, batch_y in data_provider(X, Y, batchsize):   # 按批次獲取數據,循環處理每批次pred_y = fun(batch_x, w_0, b_0)   # 用當前參數預測本批次數據,得預測值loss = maeloss(pred_y, batch_y)  # 計算預測值與真實值的損失loss.backward()  # 反向傳播求梯度sgd([w_0, b_0], lr)  # 用sgd更新模型參數data_loss += loss  # 累加本批次損失到本輪累計損失print("epoch %03d: loss: %.6f"%(epoch, data_loss))  # 打印本輪輪數和累計損失,觀察訓練情況print("真實的函數值是", true_w, true_b)
      print("訓練得到的參數值是", w_0, b_0)idx = 0  # 初始化一個索引變量idx為0,這個索引通常用于選擇數據張量X中的某一列數據,后續可能用于可視化等相關操作
      plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())  # 繪制擬合直線,取X第idx列數據轉numpy作橫坐標,按線性回歸公式用訓練參數算縱坐標來繪制
      plt.scatter(X[:, idx], Y, 1)  # 繪制散點圖,以輸入數據張量X的第idx列數據作為橫坐標,以對應的輸出數據Y作為縱坐標,展示原始數據的分布情況
      plt.show()
      Tips:
      1、yield get_data, get_label當函數執行到?yield?語句時,函數會暫停執行,將?yield?后面的值返回給調用者,但函數并沒有結束。下次調用這個函數時,它會從上次暫停的地方繼續執行,直到遇到下一個?yield?或者函數結束。這意味著在一個生成器函數中,可以通過多個?yield?語句多次返回不同的值。就像在?data_provider?函數中,每次調用會返回一個新的數據批次和標簽批次,直到所有批次都返回完。
      2、para.grad.zero_() 如果我們不清零這個梯度,在第二次訓練批次進行反向傳播時,新計算出來的梯度會和第一次遺留下來的梯度相加。就好像你在走迷宮,第一次得到的指示(梯度)是向左走三步,但是你沒記住這個指示,第二次又得到一個指示(新的梯度)是向右走兩步,但是你把兩次的指示混在一起,變成了向左走一步(假設梯度相加的情況),這樣就會讓你的方向(參數更新方向)變得混亂。
      3、plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())①繪制一條直線,用于表示擬合的線性關系(在二維平面上展示線性回歸擬合情況)。②首先從輸入數據張量X中取出第idx列數據(通過X[:,?idx]),并將其轉換為numpy數組(使用detach().numpy()方法,目的是從計算圖中分離出來并轉為numpy格式方便繪圖),作為橫坐標。然后根據線性回歸的公式y?=?w?*?x?+?b,計算對應的縱坐標,這里使用當前訓練得到的權重w_0的第idx個元素(w_0[idx])和偏置b_0,同樣轉換為numpy數組后參與計算,以此繪制出擬合的直線。
      

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/65438.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/65438.shtml
英文地址,請注明出處:http://en.pswp.cn/web/65438.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Angular 最新版本和 Vue 對比完整指南

1. Angular 最新版本 當前 Angular 最新穩定版本是 Angular 17(2024年初) 2. 主要區別對比表 特性 | Angular | Vue 框架類型 | 完整框架 | 漸進式框架 默認語言 | TypeScript | JavaScript/TypeScript 數據處理 | RxJS | Promise/async/await 架構特點 | 依賴注入,…

單片機-串轉并-74HC595芯片

1、74HC595芯片介紹 74HC595 是一個 8 位串行輸入、并行輸出的位移緩存器,其中并行輸出為三態輸出(即高電平、低電平和高阻抗)。 15 和 1 到 7 腳 QA--QH:并行數據輸出 9 腳 QH 非:串行數據輸出 10 腳 SCLK 非&#x…

探索AI在地質科研繪圖中的應用:ChatGPT與Midjourney繪圖流程與效果對比

文章目錄 個人感受一、AI繪圖流程1.1 Midjourney(1)環境配置(2)生成prompt(3)完善prompt(4)開始繪圖(5)后處理 1.2 ChatGPT不合理的出圖結果解決方案 二、主題…

【微服務】6、限流 熔斷

線程隔離與容錯處理 本視頻主要講解了在購物車業務中,因商品微服務響應慢導致的問題及解決方案,重點介紹了線程隔離后查詢購物車業務不可用的情況,以及如何通過Fallback邏輯進行緩解,包括配置Feign調用為簇點資源、添加Fallback邏…

springboot+vue使用easyExcel實現導出功能

vue部分 // 導出計算數據exportDataHandle(id) {this.$http({url: this.$http.adornUrl(/xxx/xxx/exportCalDataExcel),method: post,data: this.$http.adornData({id: id}),responseType: blob, // 重要:告訴axios我們希望接收二進制數據}).then(({data}) > {c…

25年01月HarmonyOS應用基礎認證最新題庫

判斷題 “一次開發,多端部署”指的是一個工程,一次開發上架,多端按需部署。為了實現這一目的,HarmonyOS提供了多端開發環境,多端開發能力以及多端分發機制。 答案:正確 《鴻蒙生態應用開發白皮書》全面闡釋…

ELK實戰(最詳細)

一、什么是ELK ELK是三個產品的簡稱:ElasticSearch(簡稱ES) 、Logstash 、Kibana 。其中: ElasticSearch:是一個開源分布式搜索引擎Logstash :是一個數據收集引擎,支持日志搜集、分析、過濾,支持大量數據…

Dubbo-筆記隨記一

一、實戰 1 . Springboot整合 1.1 服務提供者 1.1.1 依賴 <dependency><groupId>org.apache.dubbo</groupId><artifactId>dubbo-spring-boot-starter</artifactId><version>3.2.10</version></dependency><dependency&g…

git tag

文章目錄 1.簡介2.格式3.選項4.示例參考文獻 1.簡介 同大多數 VCS 一樣&#xff0c;Git 也可以對某一時間點的版本打上標簽&#xff0c;用于版本的發布管理。 一個版本發布時&#xff0c;我們可以為當前版本打上類似于 v.1.0.1、v.1.0.2 這樣的 Tag。一個 Tag 指向一個 Commi…

ETCD滲透利用指南

目錄 未指定使用put操作報錯 未指定操作版本使用get報錯 首先etcd分為兩個版本v2和v3&#xff0c;不同的API結果無論是訪問URL還是使用etcdctl進行通信&#xff0c;都會導致問題&#xff0c;例如使用etcdctl和v3進行通信&#xff0c;如果沒有實名ETCDCTL_API3指定API版本會直接…

使用VUE3創建個人靜態主頁

使用VUE3創建個人靜態主頁 &#x1f31f; 前言&#x1f60e;體驗&#x1f528; 具體實現? 核心功能&#x1f3d7;? 項目結構&#x1f680; 用這個項目部署 Git Page &#x1f4d6; 參考 &#x1f31f; 前言 作為開發者或者內容創作者&#xff0c;我們經常需要創建靜態網頁&a…

Lua語言中常用的字符串操作函數

string.sub(s, i, j) 功能: 截取字符串 s 中從位置 i 到位置 j 的子字符串。 local s "Hello, Lua!" print(string.sub(s, 1, 5)) -- 輸出 "Hello" print(string.sub(s, 8, 11)) -- 輸出 "Lua!" string.len(s) 功能&#xff1a;將字符串長度…

llm大模型學習

llm大模型 混合專家模型&#xff08;MoE&#xff09;MoE結構路由router專家expertSwitch Transformer的典型MOE模型最后MoE總結 混合專家模型&#xff08;MoE&#xff09; 模型規模是提升LLM大語言模型性能的關鍵因素&#xff0c;但也會增加計算成本。Mixture of Experts (MoE…

Linux入門攻堅——43、keepalived入門-1

Linux Cluster&#xff08;Linux集群的類型&#xff09;&#xff1a;LB、HA、HPC&#xff0c;分別是負載均衡集群、高可用性集群、高性能集群。 LB&#xff1a;lvs&#xff0c;nginx HA&#xff1a;keepalived&#xff0c;heartbeat&#xff0c;corosync&#xff0c;cman HP&am…

HTML5 動畫效果:淡入淡出(Fade In/Out)詳解

HTML5 動畫效果&#xff1a;淡入淡出&#xff08;Fade In/Out&#xff09;詳解 淡入淡出&#xff08;Fade In/Out&#xff09;是一種常見的動畫效果&#xff0c;使元素逐漸顯現或消失&#xff0c;增強用戶體驗。以下是淡入淡出的詳細介紹及實現示例。 1. 淡入淡出的特點 平滑…

YOLOv8/YOLOv11改進 添加CBAM、GAM、SimAM、EMA、CAA、ECA、CA等多種注意力機制

目錄 前言 CBAM GAM SimAM EMA CAA ECA CA 添加方法 YAML文件添加 使用改進訓練 前言 本篇文章將為大家介紹Ultralytics/YOLOv8/YOLOv11中常用注意力機制的添加&#xff0c;可以滿足一些簡單的漲點需求。本文僅寫方法&#xff0c;原理不多講解&#xff0c;需要可跳…

Go語言的 的多態性(Polymorphism)基礎知識

Go語言的多態性&#xff08;Polymorphism&#xff09;基礎知識 在編程語言中&#xff0c;多態性是一個核心概念&#xff0c;它允許同一接口被不同的數據類型所實現&#xff0c;從而在不影響代碼結構的情況下增強代碼的靈活性和可擴展性。在Go語言中&#xff0c;多態性通過接口…

nginx運行之后顯示的是上一個項目,如何解決

重啟 Nginx 使配置生效 修改 Nginx 配置后&#xff0c;你需要重新加載或重啟 Nginx&#xff0c;以使配置生效。執行以下命令&#xff1a; sudo nginx -t # 測試配置是否正確 sudo systemctl restart nginx # 重啟 Nginxbash 復制代碼 檢查瀏覽器緩存 瀏覽器可能緩存了舊…

與 Oracle Dataguard 相關的進程及作用分析

與 Oracle Dataguard 相關的進程及作用分析 目錄 與 Oracle Dataguard 相關的進程及作用分析與 Oracle Dataguard 相關的進程及作用分析一、主庫的進程1、LGWR 進程2、ARCH進程3、LNS 進程 二、備庫的進程1、RFS 進程2、ARCH3、MRP&#xff08;Managed Recovery Process&#x…

【C語言】_指針與數組

目錄 1. 數組名的含義 1.1 數組名與數組首元素的地址的聯系 1.3 數組名與首元素地址相異的情況 2. 使用指針訪問數組 3. 一維數組傳參的本質 3.1 代碼示例1&#xff1a;函數體內計算sz&#xff08;sz不作實參傳遞&#xff09; 3.2 代碼示例2&#xff1a;sz作為實參傳遞 3…