pytorch(四)用pytorch實現線性回歸

文章目錄

    • 代碼過程
    • 準備數據
    • 設計模型
    • 設計構造函數與優化器
    • 訓練過程
    • 訓練代碼和結果
    • pytorch中的Linear層的底層原理(個人喜歡,不用看)
      • 普通矩陣乘法實現
      • Linear層實現
    • 回調機制

代碼過程

訓練過程:

  1. 準備數據集
  2. 設計模型(用來計算 y ^ \hat y y^?
  3. 構造損失函數和優化器(API)
  4. 訓練周期(前饋、反饋、更新)

準備數據

這里的輸入輸出數據均表示為3×1的,也就是維度均為1

# 行表示實例數量,列表示維度feature
import torch
x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[2.0],[4.0],[6.0]])

設計模型

模型繼承Module類,并且必須要實現 init 和 forward 兩個方法,其中self.linear=torch.nn.Linear(1,1)表示實例化Linear類,這個類是可調用的,其__call__函數調用了 forward 方法

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel,self).__init__()# weight 和 bias 1 1 self.linear=torch.nn.Linear(1,1)def forward(self,x):# callabley_pred=self.linear(x)return y_pred# callable
model=LinearModel()

pytorch中的linear類是在某一個數據上應用線性轉換,其公式表達為 y = x w T + b y=xw^T+b y=xwT+b

class torch.nn.Linear(in_features,out_features,bias=True) :其中in_features和out_features分別表示輸入和輸出的數據的維度(列的數量),bias表示偏置,默認是true,該類有兩個參數

  • weight:可學習參數,值從均勻分布 U ( ? k , k ) U(-\sqrt k,\sqrt k) U(?k ?,k ?)中獲取,其中 k = 1 i n _ f e a t u r e s k=\frac{1}{in\_features} k=in_features1?
  • bias:shape和輸出的維度一樣,也是從分布 U ( ? k , k ) U(-\sqrt k,\sqrt k) U(?k ?,k ?)中初始化的
    在這里插入圖片描述

設計構造函數與優化器

# 構造損失函數和優化器
criterion=torch.nn.MSELoss(size_average=False)# w和b--->parameters
opyimizer=torch.optim.SGD(model.parameters(),lr=0.01)

在這里插入圖片描述
在這里插入圖片描述

訓練過程

# 訓練過程
for epoch in range(100):y_pred=model(x_data)loss=criterion(y_pred,y_data)# loss標量,自動調用__str__()print(epoch,loss)optimizer.zero_grad()# backwardloss.backward()# updateoptimizer.step()

訓練代碼和結果

# 行表示實例數量,列表示維度feature
import torch
x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[2.0],[4.0],[6.0]])class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel,self).__init__()# weight 和 bias 1 1 self.linear=torch.nn.Linear(1,1)def forward(self,x):# callabley_pred=self.linear(x)return y_pred# callable
model=LinearModel()# 構造損失函數和優化器
criterion=torch.nn.MSELoss(size_average=False)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)# 訓練過程
for epoch in range(100):y_pred=model(x_data)loss=criterion(y_pred,y_data)# loss標量,自動調用__str__()print(epoch,loss)optimizer.zero_grad()# backwardloss.backward()# updateoptimizer.step()# 打印信息
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())x_test=torch.Tensor([4.0])
y_test=model(x_test)
print('y_pred=',y_test.data)

在這里插入圖片描述


pytorch中的Linear層的底層原理(個人喜歡,不用看)

我們在課本使用到的線性函數的基本公式表達為 y = x w T + b y=xw^T+b y=xwT+b,但是在Linear層中,當輸入特征被Linear層接收是,它會接收后轉置,然后乘以權重矩陣,得到的是輸出特征的轉置,換句話說可以在底層使用Linear,它實際上做的是 y T = w x T + b y^T=wx^T+b yT=wxT+b。可以使用下面的案例進行驗證:

在這里插入圖片描述

普通矩陣乘法實現

很明顯,上面的圖標表示一個 3×4 的矩陣乘以 4×1 的矩陣,得到一個 3×1 的輸出矩陣,使用普通矩陣的乘法實現如下。

import torchin_features=torch.tensor([1,2,3,4],dtype=torch.float32)
weight_matrix=torch.tensor([[1,2,3,4],[2,3,4,5],[3,4,5,6]
],dtype=torch.float32)weight_matrix.matmul(in_features)# 矩陣乘法

實現截圖:
在這里插入圖片描述

Linear層實現

# 這里還是使用上面使用過的數據
import torch
in_features=torch.tensor([1,2,3,4],dtype=torch.float32)
weight_matrix=torch.tensor([[1,2,3,4],[2,3,4,5],[3,4,5,6]
],dtype=torch.float32)print(weight_matrix.matmul(in_features))# 矩陣乘法fc = torch.nn.Linear(in_features=4, out_features=3, bias=False)
# 這里是隨機一個權重矩陣
print('fc.weight',fc.weight)
fc(in_features)

輸出結果:
在這里插入圖片描述

print('fc.weight',fc.weight)# 使用上面的權重矩陣進行計算
fc.weight = torch.nn.Parameter(weight_matrix)
print('fc.weight',fc.weight)
fc(in_features)

結果截圖:
在這里插入圖片描述

可以看到上面截圖與下面的截圖的區別,一開始隨機一個權重的時候,進行運算,使用到前面提及到的權重矩陣后,Linear層進行運算之后,得到與使用普通矩陣乘法一樣的結果,相同的結果說明,Linear底層的實現與上面的矩陣乘法的邏輯是一致的

以上的論證可以說明,Linear的底層實現其實是 y T = w x T + b y^T=wx^T+b yT=wxT+b,而不是 y = x w T + b y=xw^T+b y=xwT+b,可能會有人好奇,為什么書本上都是寫的后者而不是寫前者,其實本質上二者都一樣,前者的轉置就是后者。

回調機制

在pytorch學習(一)線性模型中,第一個代碼中,我們沒有通過pytorch實現線性模型的時候,我們會顯式調用forward函數,計算前饋的值,我們是這樣寫的y_pred_val=forward(x_val),但是在使用pytorch之后,我們是這樣寫的y_pred=model(x_data),直接實例化一個對象,然后通過對象直接計算預測值(前饋值),但是并沒有使用到forward函數。這是因為pytorch模塊類中實現了python中一個特殊的函數,也就是回調函數

如果一個類實現了回調方法,那么只要對象實例被調用,這個特殊的方法也會被調用。我們不直接調用forward()方法,而是調用對象實例。在對象實例被調用之后,在底層調用了__ call __方法,然后調用了forward()方法。這適用于所有的PyTorch神經網絡模塊。

以上僅代表小白個人學習觀點,如有錯誤歡迎批評指正。

參考

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/718935.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/718935.shtml
英文地址,請注明出處:http://en.pswp.cn/news/718935.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

國圖公考:山東事業編考試即將開始

山東事業編考試時間為2024年3月10日-9.00-11.30分 考試科目為公基寫作 準考證打印時間為2024年3月5日9.00-3月10日9.30分 準考證打印入口:山東考試信息網 綜合類筆試在全省十六市均設置考點,參加考試的考生可憑借準考證和本人身份證參加筆試

Python爬蟲實戰(基礎篇)—13獲取《人民網》【最新】【國內】【國際】寫入Word(附完整代碼)

文章目錄 專欄導讀背景測試代碼分析請求網址請求參數代碼測試數據分析利用lxml+xpath進一步分析將獲取鏈接再獲取文章內容測試代碼寫入word完整代碼總結專欄導讀 ????本文已收錄于《Python基礎篇爬蟲》 ????本專欄專門針對于有爬蟲基礎準備的一套基礎教學,輕松掌握Py…

第 2 個 Java Web 應用工程(JSP JavaBean DB)(含源碼)(圖文版)

JavaBean 是一種符合特定約定的 Java 類,通常用于在 Java 應用程序中封裝數據以及提供對數據的訪問和修改方法。 本文示例:建立一個 Tomcat 工程,編寫一個 JSP 頁面,調用 JavaBean 訪問數據庫并顯示到頁面上,發布到 T…

【開源物聯網平臺】物聯網設備上云提供開箱即用接入SDK

一、項目介紹 IOTDeviceSDK是物聯網平臺提供的設備端軟件開發工具包,可簡化開發過程,實現設備快速接入各大物聯網平臺。 設備廠商獲取SDK后,根據需要選擇相應功能進行移植,即可快速集成IOTDeviceSDK,實現設備的接入。…

gradle中設置變量,在代碼中讀取

在app的gradlew文件中設置變量appModelCode,設置manifestPlaceholders android {def appModelCode 1 //1:模式1 2:模式2def appModelName "model1"if (appModelCode 1) {...}defaultConfig {applicationId appIdminSdk 26targetSdk 32versionCode app…

音視頻數字化(視頻線纜與接口)

目錄 1、DVI接口 2、DP接口 之前的文章【音視頻數字化(線纜與接口)】提到了部分視頻線纜,今天再補充幾個。 視頻模擬信號連接從蓮花頭的“復合”線開始,經歷了S端子、色差分量接口,通過亮度、色度盡量分離的辦法提高畫面質量,到VGA已經到了模擬的頂峰,實現了RGB的獨立…

android 推薦一個上拉加載更多,下拉刷新的框架(非常好用)

作者:scwang 大神 GitHub - scwang90/SmartRefreshLayout: 🔥下拉刷新、上拉加載、二級刷新、淘寶二樓、RefreshLayout、OverScroll,Android智能下拉刷新框架,支持越界回彈、越界拖動,具有極強的擴展性,…

一文讀懂Penpad 以 Fair Launch 方式推出的首個資產 PEN

隨著 2 月 28 日比特幣重新站上 6 萬美元的高峰后,標志著加密市場正在進入新一輪牛市周期。在 ETF 的促進作用下,加密市場不斷有新的資金流入,加密貨幣總市值不斷攀升。Layer2 市場率先做出了反應,有數據顯示,當前以太…

2020PAT--冬

The Closest Fibonacci Number The Fibonacci sequence Fn? is defined by Fn2?Fn1?Fn? for n≥0, with F0?0 and F1?1. The closest Fibonacci number is defined as the Fibonacci number with the smallest absolute difference with the given integer N. Your job…

Spring初始(相關基礎知識和概述)

Spring初始(相關基礎知識和概述) 一、Spring相關基礎知識(引入Spring)1.開閉原則OCP2.依賴倒置原則DIP3.控制反轉IoC 二、Spring概述1.Spring 8大模塊2.Spring特點2.Spring的常用jar文件 一、Spring相關基礎知識(引入S…

除微信視頻號下載器還有哪些可以應用可以下載視頻?

市面上有很多視頻號下載器,但猶豫部分視頻號下載器逐步失效,就有很多小伙伴問還有哪些可以應用可以下載視頻? 視頻下載助手 除視頻號視頻下載器以外,還有【視頻號下載助手】簡稱:視頻下載助手 比如說,抖音…

spring cloud 之 Netflix Eureka

1、Eureka 簡介 Eureka是Spring Cloud Netflix 微服務套件中的一個服務發現組件,本質上是一個基于REST的服務,主要用于AWS云來定位服務以實現中間層服務的負載均衡和故障轉移,它的設計理念就是“注冊中心”。 你可以認為它是一個存儲服務地址信息的大本…

18個驚艷的可視化大屏(第14輯):能源行業應用

能源行業涉及能源生產、轉化、儲存、輸送和使用的各個領域和環節,包括石油和天然氣行業、煤炭行業、核能行業、可再生能源行業和能源服務行業,本期貝格前端工場帶來能源行業可視化大屏界面供大家欣賞。 能源行業的組成 能源行業是指涉及能源生產、轉化、…

Android 11.0 禁止系統界面下拉狀態欄和通知欄 手機 平板 車載 TV 投影 通用

1、禁止systemUI下拉狀態欄和通知欄的核心代碼部分 framework/base/packages/apps/SystemUI/src/com/android/systemui/keyguard/KeyguardViewMediator.java framework/base/packages/apps/SystemUI/src/com/android/systemui/statusbar/phone/CollapsedStatusBarFragment.jav…

數字化轉型導師堅鵬:金融機構數字化運營

金融機構數字化運營 課程背景: 很多金融機構存在以下問題: 不清楚數字化運營對金融機構發展有什么影響? 不知道如何提升金融機構數字化運營能力? 不知道金融機構如何開展數字化運營工作? 課程特色:…

盤點全網哪些超乎想象的高科技工具?有哪些免費開源的最新AI智能工具?短視頻自媒體運營套裝?

盤點全網哪些超乎想象的高科技工具?有哪些免費開源的最新AI智能工具?短視頻自媒體運營套裝? 自媒體主要用來干什么? 可以通過短視頻吸引更多的觀眾和粉絲,提升自媒體賬號的影響力和知名度。 短視頻形式更加生動、直觀…

使用C++界面框架ImGUI開發一個簡單程序

簡介 ImGui 是一個用于C的用戶界面庫,跨平臺、無依賴,支持OpenGL、DirectX等多種渲染API,是一種即時UI(Immediate Mode User Interface)庫,保留模式與即時模式的區別參考保留模式與即時模式。ImGui渲染非常…

關于企業數字化轉型:再認識、再思考、再出發

近年來,隨著國家數字化政策不斷出臺、新興技術不斷進步、企業內生需求持續釋放,數字化轉型逐步成為企業實現高質量發展的必由之路,成為企業實現可持續發展乃至彎道超車的重要途徑。本文重點分析當下阻礙企業數字化轉型的難點,提出…

SPC 之 I-MR 控制圖

概述 1924 年,美國的休哈特博士應用統計數學理論將 3Sigma 原理運用于生產過程中,并發表了 著名的“控制圖法”,對產品特性和過程變量進行控制,開啟了統計過程控制新時代。 什么是控制圖 控制圖指示過程何時不受控制&#xff…

通過 Jenkins 經典 UI 創建一個基本流水線

通過 Jenkins 經典 UI 創建一個基本流水線 點擊左上的 新建任務。 在 輸入一個任務名稱字段,填寫你新建的流水線項目的名稱。 點擊 流水線,然后點擊頁面底部的 確定 打開流水線配置頁 點擊菜單的流水線 選項卡讓頁面向下滾動到 流水線 部分 在 流水線 …