1.手動LogisticRegression模型的訓練和預測

通過這個示例,可以了解邏輯回歸模型的基本原理和訓練過程,同時可以通過修改和優化代碼來進一步探索機器學習模型的訓練和調優方法。

過程:

  1. 生成了一個模擬的二分類數據集:通過隨機生成包含兩個特征的數據data_x,并基于一定規則生成對應的二分類標簽數據data_y
  2. 創建了一個手動實現的邏輯回歸模型LogisticRegressionManually,其中包括:
    • 初始化函數__init__:初始化模型的權重參數w和偏置參數b
    • 前向傳播函數forward:計算給定輸入數據的預測值。
    • 損失函數loss_func:定義了交叉熵損失函數,用于評估模型的預測性能。
    • 訓練函數train:在每個epoch中,遍歷數據集的每個樣本,計算預測值、損失值、梯度,并利用梯度下降法更新模型參數。
  3. 實例化LogisticRegressionManually類,然后調用train方法對模型進行訓練。
  4. 在訓練過程中,打印每個epoch的損失值。

演示:

# 生成模擬的二分類數據集,其中X數據是隨機生成的,Y數據根據一定規則生成。import torch
import torch.nn.functional as Fn_items = 1000
n_features = 2
learning_rate = 0.001
epochs = 100# 置了隨機種子,以確保每次運行代碼時生成的隨機數相同,從而使結果具有可重現性。
torch.manual_seed(123) 
# 生成了一個大小為(1000, 2)的張量data_x,其中包含1000個樣本,每個樣本具有2個特征。這里使用torch.randn生成標準正態分布的隨機數作為數據,并將數據類型轉換為float。
data_x = torch.randn(size=(n_items, n_features)).float()
# 成了標簽數據data_y,通過對第一個特征乘以0.5和第二個特征乘以1.5的差值進行判斷,如果差值大于0就將標簽設為1,否則為0。這樣生成了一個二分類標簽數據集,同樣將數據類型轉換為float。
data_y = torch.where(torch.subtract(data_x[:, 0]*0.5, data_x[:, 1]*1.5) > 0, 1., 0.).float()# print(data_x)
# print(data_y)

# 在每個epoch中,遍歷數據集的每個樣本,計算預測值、損失值、梯度,利用梯度下降法更新模型參數。通過這種方式訓練模型可以逐漸優化模型參數,以達到更好的預測效果。class LogisticRegressionManually(object):# 初始化函數__init__def __init__(self):# w是一個大小為(n_features, 1)的張量,用于存儲權重參數,并且設置了requires_grad=True表示需要計算梯度;self.w = torch.randn(size=(n_features, 1), requires_grad=True)# b是一個大小為(1, 1)的張量,用于存儲偏置參數,并且設置了requires_grad=Trueself.b = torch.zeros(size=(1, 1), requires_grad=True)# 前向傳播函數forwarddef forward(self, x):# 過矩陣乘法計算預測值y_hat:將參數w轉置后與輸入數據x相乘,并加上偏置b后通過F.sigmoid函數進行激活,最終返回激活后的預測值。y_hat = F.sigmoid(torch.matmul(self.w.transpose(0, 1), x) + self.b)return y_hat# 損失函數loss_func@staticmethoddef loss_func(y_hat, y):# 定義了交叉熵損失函數。通過計算實際標簽y和預測值y_hat之間的交叉熵損失來評估模型的預測性能。return -(torch.log(y_hat)*y + (1-y)*torch.log(1-y_hat))# 訓練函數traindef train(self):# 在每個epoch中,遍歷數據集中的每個樣本for epoch in range(epochs):for step in range(n_items):# 利用模型的前向傳播函數forward計算當前樣本的預測值y_hat。y_hat = self.forward(data_x[step])# 獲取當前樣本的真實標簽yy = data_y[step]# 調用損失函數loss_func計算預測值與真實標簽之間的損失。loss = self.loss_func(y_hat, y)# 利用反向傳播計算損失對模型參數的梯度loss.backward()# 進入torch.no_grad()上下文管理器,保證在該范圍內的操作不會被記錄用于自動微分。with torch.no_grad():# 更新權重參數w和偏置參數b,通過梯度下降法更新參數,learning_rate是學習率。self.w.data -= learning_rate * self.w.gradself.b.data -= learning_rate * self.b.grad# 清零梯度,以便進行下一次參數更新時重新計算梯度。self.w.grad.data.zero_()self.b.grad.data.zero_()print("Epoch: %03d, Loss: %.3f" % (epoch, loss.item()))

lrm = LogisticRegressionManually()
lrm.train()

結果:

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/16045.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/16045.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/16045.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

秋招突擊——算法打卡——5/25、5/26——尋找兩個正序數組的中位數

題目描述 自我嘗試 首先,就是兩個有序的數組進行遍歷,遍歷到一半即可。然后求出均值,下述是我的代碼。但這明顯是有問題的,具體錯誤的代碼如下。計算復雜度太高了,O(n),所以會超時&…

數據結構--《二叉樹》

二叉樹 1、什么是二叉樹 二叉樹(Binar Tree)是n(n>0)個結點的優先集合,該集合或者為空集(稱為空二叉樹),或者由一個根結點和兩顆互不相交的、分別稱為根結點的左子樹和右子樹的二叉樹構成。 這里給張圖,能更直觀的感受二叉樹&#xff1…

GDPU JavaWeb mvc模式

搭建一個mvc框架的小實例。 簡易計算器 有一個名為inputNumber.jsp的頁面提供一個表單,用戶可以通過表單輸入兩個數和運算符號提交給Servlet控制器;由名為ComputerBean.java生成的JavaBean負責存儲運算數、運算符號和運算結果,由名為handleCo…

C#中獲取FTP服務器文件

1、從ftp下載pdf的方法 public static void DownloadPdfFileFromFtp(string ftpUrl,string user,string password string localPath) { // 創建FtpWebRequest對象 FtpWebRequest request (FtpWebRequest)WebRequest.Create(ftpUrl); request.Method WebRequestMethods.Ftp…

簡單好用的文本識別方法--付費的好用,免費的更有性價比-記筆記

文章目錄 先說付費的進入真題,免費的來喏!PixPin微信 先說付費的 直達網址!!! 進入真題,免費的來喏! PixPin 商店里就有 使用示例: 可以看到:貼在桌面上的圖片可以復制圖片中的文字,真的很…

深入了解ASPICE標準:提升汽車軟件開發與質量管理的利器

隨著汽車行業的快速發展和技術創新,汽車軟件的開發和質量管理的重視程度不斷提升。ASPICE(Automotive Software Process Improvement and Capability Determination)標準作為一種專門針對汽車軟件開發過程的改進和能力評定的框架,…

Springboot+Vue+ElementUI開發前后端分離的員工管理系統01--系統介紹

項目介紹 springboot_vue_emp是一個基于SpringbootVueElementUI實現的前后端分離的員工管理系統 功能涵蓋: 系統管理:用戶管理、角色管理、菜單管理、字典管理、部門管理出勤管理:請假管理、考勤統計、工資發放、工資統計、離職申請、個人資…

8.Redis之hash類型

1.hash類型的基本介紹 哈希表[之前學過的所有數據結構中,最最重要的] 1.日常開發中,出場頻率非常高. 2.面試中,非常重要的考點, Redis 自身已經是鍵值對結構了Redis 自身的鍵值對就是通過 哈希 的方式來組織的 把 key 這一層組織完成之后, 到了 value 這一層~~ value 的其中…

最重要的時間表示,柯橋外貿俄語小班課

в第四格 1、與表示“鐘點”的數詞詞組連用 例: в шесть часов утра 在早上六點 в пять тридцать 在五點半 2、與表示“星期”的名詞連用 例: в пятницу 在周五 в следующий понедельник …

包和依賴管理:Python的pip和conda使用指南

包和依賴管理:Python的pip和conda使用指南 對于Python新手來說,包和依賴管理可能是一個令人困惑的概念。但不用擔心,本文將用淺顯易懂的語言,詳細介紹如何使用Python的兩個主要包管理工具:pip和conda。我們還會探討在安…

為 AWS 子賬戶添加安全組修改權限

文章目錄 步驟 1:創建 IAM 策略步驟 2:附加策略到子賬戶步驟 3:驗證權限 本文檔將操作如何為 AWS 子賬戶(IAM 用戶或角色)添加修改安全組的權限,包括 AuthorizeSecurityGroupIngress 和 RevokeSecurityGr…

解決uniApp 中不能直接使用 Axios 的問題

最近在使用 uniapp 進行小程序開發的時候,發現 uniapp 不能直接使用 axios,需要自己進行封裝一個 http 庫使用,于是有了這個項目。 項目地址:https://www.npmjs.com/package/uni-app-wxnetwork-tool 該包的功能介紹: u…

String類為什么設計成不可變的?

目錄 緩存 安全性 線程安全 hashCode緩存 性能 其實這個問題我們可以通過緩存、安全性、線程安全和性能幾個維度去解析。 緩存 字符串是Java最常用的數據結構,我們都知道字符串大量創建是非常耗費資源的,所以Java中就將String設計為帶有緩存的功能…

軟考 系統架構設計師之考試感悟2

接前一篇文章:軟考 系統架構設計師之考試感悟 今天是2024年5月25號,是個人第二次參加軟考系統架構師考試的正日子。和上次一樣,考了一天,身心俱疲。天是陰的,心是沉的,感覺比上一次更加沉重。仍然有諸多感悟…

express框架下后端獲取req.body報錯undefined

express框架下后端獲取req.body報錯undefined_express服務器post中data為undefine-CSDN博客 /*** 特殊說明:Express是一個單線程服務器器程序【必須存在指定的順序調用,否則無法達到預期的效果】*//*** 第一步:創建一個Express實例對象,并且在匹配路由之…

【python】python tkinter 計算器GUI版本(模仿windows計算器 源碼)【獨一無二】

👉博__主👈:米碼收割機 👉技__能👈:C/Python語言 👉公眾號👈:測試開發自動化【獲取源碼商業合作】 👉榮__譽👈:阿里云博客專家博主、5…

17.分類問題

機器學習分類問題詳解與實戰 介紹 在機器學習中,分類問題是一類常見的監督學習任務,其目標是根據輸入特征將數據樣本劃分為預先定義的類別之一。分類問題廣泛應用于各個領域,如圖像識別、自然語言處理、金融風險評估等。本文將詳細介紹機器…

Spring Cloud 項目中使用 Swagger

Spring Cloud 項目中使用 Swagger 關于方案的選擇 在 Spring Cloud 項目中使用 Swagger 有以下 4 種方式: 方式一 :在網關處引入 Swagger ,去聚合各個微服務的 Swagger。未來是訪問網關的 Swagger 原生界面。 方式二 :在網關處引…

RedHat9 | DNS剖析-配置輔助DNS服務器

一、實驗環境 1、輔助域名DNS服務器 DNS通過劃分為若干個區域進行管理,每一個區域由1臺或多臺DNS服務器負責解析,如果僅僅采用1臺DNS服務器,在DNS服務器出現故障后,用戶將無法完成解析。 輔助DNS服務器的優點 容災備份&#x…

區間預測 | Matlab實現DNN-KDE深度神經網絡結合核密度估計多置信區間多變量回歸區間預測

區間預測 | Matlab實現DNN-KDE深度神經網絡結合核密度估計多置信區間多變量回歸區間預測 目錄 區間預測 | Matlab實現DNN-KDE深度神經網絡結合核密度估計多置信區間多變量回歸區間預測效果一覽基本介紹程序設計參考資料 效果一覽 基本介紹 1.Matlab實現DNN-KDE深度神經網絡結合…