python 線性回歸模型

教材鏈接-3.2. 線性回歸的從零開始實現

c++實現

該博客僅用于記錄一下自己的代碼,可與c++實現作為對照

from d2l import torch as d2l
import torch
import random
# nn是神經網絡的縮寫
from torch import nn
from torch.utils import data# 加載訓練數據  
# 加載訓練數據集 
simples = torch.load('datas.pt')
# 這里是加載了訓練和測試數據集的真實權重和偏差,僅作為最后訓練結果的驗證使用
tw, tb = torch.load('wb.pt')
# 加載測試數據集  
tests = torch.load('test.pt')
# 獲取訓練數據集的樣本數量  
simple_num = simples.shape[0]# 獲取數據讀取迭代器  
def data_iter(batch_size, features, labels):# 計算數據的總數量num_examples = len(features)# 創建一個包含數據索引的列表  indices = list(range(num_examples))# 隨機打亂索引列表,以實現隨機讀取樣本,對訓練結果意義不明# random.shuffle(indices)# 遍歷打亂后的indices,每次取出batch_size個索引,用于構建一個小批量數據  for i in range(0, num_examples, batch_size):# 獲取當前批次的索引號并以張量形式存儲batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])# 根據索引從特征和標簽中提取數據  yield features[batch_indices], labels[batch_indices]
# 在Python中,yield 是一個關鍵字,用于定義一個生成器(generator)。生成器是一種特殊的迭代器,它允許你定義一個可以記住上一次返回時在函數體中的位置的函數。對生成器函數的第二次(或第n次)調用將恢復函數的執行,并繼續從上次掛起的位置開始。# 定義一個函數來加載并批量處理數據,返回數據獲取迭代器 
def load_array(data_arrays, batch_size, is_train=True):  #@save"""構造一個PyTorch數據迭代器"""# 使用TensorDataset將多個tensor組合成一個數據集  dataset = data.TensorDataset(*data_arrays)# 使用DataLoader加載數據集,并指定批量大小和是否打亂數據return data.DataLoader(dataset, batch_size, shuffle=is_train)# 定義線性回歸模型  
def linreg(X, w, b):  #@save"""線性回歸模型"""# 使用矩陣乘法計算預測值,并加上偏差  return torch.matmul(X, w) + b# 定義平方損失函數  
def squared_loss(y_hat, y):  #@save"""均方損失"""# 計算預測值與實際值之間的平方差,并除以2(方便梯度計算)return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2# 定義交叉熵損失函數,線性回歸模型用不到
def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])# 定義一個魯棒的損失函數,結合了平方損失和絕對值損失
def robust_loss(y_hat, y, delta=1.0):residual = torch.abs(y_hat - y)return torch.where(residual<delta, 0.5* residual **2, delta*(residual-0.5*delta))# 絕對值損失函數  
def abs_loss(y_hat, y):return torch.abs(y_hat - y.reshape(y_hat.shape))# 定義隨機梯度下降函數  
def sgd(params, lr, batch_size):  #@save"""小批量隨機梯度下降"""with torch.no_grad():# 遍歷模型參數 for param in params:# 更新參數值,使用學習率lr乘以參數的梯度,并除以批量大小 param -= lr * param.grad / batch_size# 清除參數的梯度,為下一輪迭代做準備  param.grad.zero_()# 數據標準化處理  
def standard(X):X_mean = torch.mean(X, dim=0)X_std = torch.std(X, dim=0)return (X-X_mean)/X_std# 數據最小最大歸一化處理  
def min_max(X):X_min = torch.min(X, dim=0)[0]X_max = torch.max(X, dim=0)[0]return (X-X_min)/(X_max-X_min)# 不進行任何處理,直接返回輸入
def noProcess(X):return X
#Linear Regression Implementation from Scratch
if __name__ == '__main__':# 設置學習率和訓練輪數  lr = 0.03num_epochs = 20# 這里其實net變量并沒有定義為一個神經網絡模型,而是一個函數  # 但為了與后續代碼保持一致,我們仍然使用net來表示這個線性回歸函數# loss同理net = linregloss = squared_loss# 使用不進行任何處理的數據處理方式  data_process = noProcess# 將數據分成50個批次,計算每批數據的數量 batch_size = simple_num // 50# 提取特征和標簽 # 提取最后一列作為標簽  label = simples[:,-1]# 提取除最后一列外的所有列作為特征,并使用data_process進行處理feature=data_process(simples[:, :-1])# 初始化權重和偏差,權重使用正態分布初始化,偏差初始化為0  w = torch.normal(0, 1, size=(feature.shape[1], 1), requires_grad=True)# w = torch.tensor([0.3], requires_grad=True)b = torch.tensor([0.0], requires_grad=True)timer = d2l.Timer()# 開始訓練  for epoch in range(num_epochs):# 通過data_iter遍歷數據進行一輪訓練for X,y in data_iter(batch_size, feature, label):# 計算預測值y_hat = net(X, w, b)# 計算損失l = loss(y_hat, y)# 反向傳播計算梯度  l.sum().backward()# 使用隨機梯度下降更新參數sgd([w,b], lr, batch_size)# 一輪訓練結束后,計算整個訓練集上的損失,用以監控訓練效果# with torch.no_grad(): 告訴 PyTorch 在這個上下文內不要計算梯度,從而節省內存并加速計算。with torch.no_grad():label_hat = net(feature, w, b)epoch_loss = loss(label_hat, label)if epoch%5 == 0:print(f'in epoch{epoch+1}, loss is {epoch_loss.sum()}')# 在訓練完成后,計算測試集上的預測值和損失  # 提取測試集的特征和標簽 test_feature = data_process(tests[:, :-1])test_label = tests[:, -1]# 計算測試集上的預測值和損失 test_label_hat = net(test_feature, w, b)label_loss = loss(test_label_hat, test_label)print(f'in test epoch, loss is {label_loss.mean()}')print(f'true_w={tw}, true_b={tb}, w={w}, b={b}')print(f' {num_epochs} epoch, time {timer.stop():.2f} sec')
#Concise Implementation of Linear Regression
#the concise implementation have lower accuracy than from scratch
if __name__ == '__main2__':# 設置學習率、訓練輪數、數據處理方式和批量大小  lr = 0.03num_epochs = 15# 使用不進行任何處理的數據處理方式  data_process = noProcess# 將數據分成50個批次,計算每批數據的數量  batch_size = simple_num // 50# 提取特征和標簽 label = simples[:,-1]feature=data_process(simples[:, :-1])# 加載數據并創建數據迭代器  data_iter = load_array((feature, label), batch_size)# 構建神經網絡模型,這里是一個簡單的線性回歸模型  net = nn.Sequential(nn.Linear(feature.shape[1], 1))# 我們的模型只包含一個層,因此實際上不需要Sequential# 不使用Sequential時,后面的net[0]需要改為net# net = nn.Linear(feature.shape[1], 1)# 初始化網絡權重和偏置 net[0].weight.data.normal_(0, 0.01)net[0].bias.data.fill_(0)# 使用均方誤差損失函數loss = nn.MSELoss()# 使用隨機梯度下降優化器  trainer = torch.optim.SGD(net.parameters(), lr=lr)# 開始訓練  for epoch in range(num_epochs):# 通過data_iter遍歷數據進行一輪訓練for X, y in data_iter:# 前向傳播計算預測值y_hat = net(X)# 計算損失 l = loss(y_hat, y.reshape(y_hat.shape))# 梯度清零,為下一輪迭代計算做準備trainer.zero_grad()# 反向傳播計算梯度   l.backward()# 使用隨機梯度下降更新參數trainer.step()# 在每個epoch結束后,對整個數據集進行前向傳播并計算損失,用于監控訓練過程 label_hat = net(feature)epoch_loss = loss(label_hat, label.reshape(label_hat.shape))if epoch%5 == 0:print(f'in epoch{epoch+1}, loss is {epoch_loss.mean()}')# 在訓練完成后,計算測試集上的預測值和損失  # 提取測試集的特征和標簽 test_feature = data_process(tests[:, :-1])test_label = tests[:, -1]# 計算測試集上的預測值和損失 test_label_hat = net(test_feature)label_loss = loss(test_label_hat, test_label.reshape(test_label_hat.shape))print(f'in test epoch, loss is {label_loss.mean():f}')print(f'tw={tw}, tb={tb}, w={net[0].weight.data}, b={net[0].bias.data}')

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/14376.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/14376.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/14376.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

什么是網關,網關有哪些作用?

網關(Gateway)是在計算機網絡中用于連接兩個獨立的網絡的設備&#xff0c;它能夠在兩個不同協議的網絡之間傳遞數據。在互聯網中&#xff0c;網關是一個可以連接不同協議的網絡的設備&#xff0c;比如說可以連接局域網和互聯網&#xff0c;它可以把局域網 的內部網絡地址轉換成…

論文閱讀--GLIP

把detection和phrase ground(對于給定的sentence&#xff0c;要定位其中提到的全部物體)這兩個任務合起來變成統一框架&#xff0c;從而擴展數據來源&#xff0c;因為文本圖像對的數據還是很好收集的 目標檢測的loss是分類loss定位loss&#xff0c;它與phrase ground的定位los…

爬蟲學習--11.MySQL數據庫的基本操作(上)

MySQL數據庫的基本操作 創建數據庫 我們可以在登陸 MySQL 服務后&#xff0c;使用命令創建數據庫&#xff0c;語法如下: CREATE DATABASE 數據庫名; 顯示所有的數據庫 show databases; 刪除數據庫 使用普通用戶登陸 MySQL 服務器&#xff0c;你可能需要特定的權限來創建或者刪…

Docker部署Minio小記

概述 因為工作需要搭建對象存儲的測試環境&#xff0c;故而使用Docker部署Minio&#xff0c;測試通過博文記錄用以備忘 步驟 拉取鏡像 docker pull minio/minio啟動容器 docker run -p 9000:9000 -p 9090:9090 \--name minio \-d --restartalways \-e "MINIO_ACCESS_K…

內臟油脂是什么?如何減掉?

真想減的人&#xff0c;減胖是很容易的&#xff0c;但想要形體美又健康&#xff0c;還是得從減內臟油脂開始&#xff0c;那么&#xff0c;問題來了&#xff0c;什么是內臟油脂&#xff1f; 油脂它分部于身體的各個角落&#xff0c;四肢、腹部、腰、臀部、臉、脖子...等&#xf…

VUE3+TS+elementplus創建table,純前端的table

一、前言 開始學習前端&#xff0c;直接從VUE3開始&#xff0c;從簡單的創建表格開始。因為自己不是專業的程序員&#xff0c;編程主要是為了輔助自己的工作&#xff0c;提高工作效率&#xff0c;VUE的基礎知識并不牢固&#xff0c;主要是為了快速上手&#xff0c;能夠做出一些…

Kubernetes中 Requests 和 Limits 的初步理解

1 靈魂拷問 我們在使用 Kubernetes 時是否遇到以下情況&#xff1a; 你會不會部署負載的時候將 CPU requests/limits 設置得過低或過高&#xff1f;你會不會部署負載的時候將 內存 requests/limits 設置得過低或過高&#xff1f;又或者你根本不設置 requests/limits&#xff…

SVN創建項目分支

目錄 背景調整目錄結構常規目錄結構當前現狀目標 調整SVN目錄調整目錄結構創建項目分支 效果展示 背景 當前自己本地做項目的時候發現對SVN創建項目不規范&#xff0c;沒有什么目錄結構&#xff0c;趁著創建目錄分支的契機&#xff0c;順便調整下SVN服務器上的目錄結構 調整目…

Stable Diffusion WebUI使用inpaint anything插件實現圖片局部重繪

Inpaint Anything是一個強大的圖像處理工具,它結合了SAM(Segment Anything Model)、圖像修補模型(如LaMa)和AIGC模型(如Stable Diffusion)等先進技術,以實現圖像中物體的移除、內容的填補以及場景的替換。無論是對圖像中的任何元素進行編輯,還是對圖像整體進行場景轉換…

【Vue】Vue2使用ElementUI

目錄 Element UI介紹特點Vue2使用Element安裝引入ElementUI組件庫 使用ElementUI用戶注冊列表展示其他 mint-ui介紹特點安裝組件引入組件Mint-ui相關組件 Element UI 介紹 官網(基于 Vue 2.x ):https://element.eleme.cn/#/zh-CN ElementUI 是一個基于 Vue.js 的桌面端組件庫…

Vue文本溢出如何自動換行

css新增 word-break: break-all; word-wrap: break-word;

【Linux系統】文件與基礎IO

本篇博客整理了文件與文件系統、文件與IO的相關知識&#xff0c;借由庫函數、系統調用、硬件之間的交互、操作系統管理文件的手段等&#xff0c;旨在讓讀者更深刻地理解“Linux下一切皆文件”。 【Tips】文件的基本認識 文件 內容 屬性。文件在創建時就有基本屬性&#xff0…

網易:一季度營收269億元,連續7季研發強度超15%領跑行業

5月23日&#xff0c;網易發布2024年第一季度財報。財報顯示&#xff0c;網易Q1營收269億元&#xff0c;歸屬于公司股東的凈利潤85億元&#xff08;Non-GAAP&#xff09;&#xff0c;以連續7個季度超15%的研發投入強度領跑行業&#xff0c;首季業績穩健啟航。 一季度&#xff0…

JVM學習-動態鏈接和方法返回地址

動態鏈接–指向運行時常量池的方法引用 每一個棧幀內部包含一個指向運行時常量池中該棧幀所屬方法的引用&#xff0c;包含這個引用的目的為了支持當前方法的代碼能夠實現動態鏈接(Dynamic Linking)&#xff0c;如invokednamic指令。在Java源文件被編譯到字節碼文件中時&#x…

云平臺概要設計文檔 -大綱

1. 引言 1.1 目的 本文檔的目的是提供一份詳細的技術規范,用以指導開發團隊實現云平臺的建設和部署。該文檔旨在確保所有開發人員和相關技術人員對系統的架構、組件、交互流程、數據處理及安全措施有深入的理解,從而能夠高效、一致地開發出符合預期功能和性能要求的系統。 …

JAVA:淺談JSON與JSON轉換

可能有很多人&#xff0c;無論是前端還是后端&#xff0c;無論是JAVA還是Python還是C&#xff0c;都應該跟JSON這種數據格式打過交道&#xff0c;那么有沒有仔細的想過&#xff0c;什么叫JSON&#xff1f; JSON是一種輕量級的數據交換格式。它基于JavaScript語言的對象表示法&a…

初識java——javaSE(6)抽象類與接口【求個關注!】

文章目錄 前言一 抽象類1.1 抽象類的概念1.2 抽象類的語法&#xff1a;1.3 抽象類與普通類的區別&#xff1a; 二 接口2.1 接口的概念2.2 接口的語法2.2.1 接口的各個組成2.2.2 接口之間的繼承 2.3 接口的實現接口不可以實例化對象 2.4 接口實現多態 三 Object類3.1 Object類是…

【shell】腳本練習題

案例&#xff1a; 1. for ping測試指網段的主機 網段由用戶輸入&#xff0c;例如用戶輸入192.168.2 &#xff0c;則ping 192.168.2.10 --- 192.168.2.20 UP&#xff1a; /tmp/host_up.txt Down: /tmp/host_down.txt 2. 使用case實現成績優良差的判斷 1. for ping測試指…

Android異常及解決方式記錄

異常1&#xff1a;Tmp detached view should be removed from RecyclerView before it can be recycled: 解決方法&#xff1a; recycleView.setItemAnimator(null);

第17講:C語言內存函數

目錄 1.memcpy使用和模擬實現2.memmove使用和模擬實現3.memset函數的使用4.memcmp函數的使用 1.memcpy使用和模擬實現 void * memcpy (void * destination, const void * source, size_t num);? 函數memcpy從source的位置開始向后復制num個字節的數據到destination指向的內存…