使用 Numpy 自定義數據集,使用pytorch框架實現邏輯回歸并保存模型,然后保存模型后再加載模型進行預測,對預測結果計算精確度和召回率及F1分數

1. 導入必要的庫

首先,導入我們需要的庫:Numpy、Pytorch 和相關工具包。

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, recall_score, f1_score
2. 自定義數據集

使用 Numpy 創建一個簡單的線性可分數據集,并將其轉換為 Pytorch 張量。

# 創建數據集
X = np.random.rand(100, 2)  # 100 個樣本,2 個特征
y = (X[:, 0] + X[:, 1] > 1).astype(int)  # 標簽,若特征之和大于1則為 1,否則為 0# 轉換為 PyTorch 張量
X_train = torch.tensor(X, dtype=torch.float32)
y_train = torch.tensor(y, dtype=torch.long)
3. 定義邏輯回歸模型

在 Pytorch 中定義一個簡單的邏輯回歸模型。

class LogisticRegressionModel(nn.Module):def __init__(self, input_dim):super(LogisticRegressionModel, self).__init__()self.linear = nn.Linear(input_dim, 2)  # 二分類問題def forward(self, x):return self.linear(x)
4. 初始化模型、損失函數和優化器
# 初始化模型
model = LogisticRegressionModel(input_dim=2)# 損失函數與優化器
criterion = nn.CrossEntropyLoss()  # 交叉熵損失函數
optimizer = optim.SGD(model.parameters(), lr=0.01)
5. 訓練模型

訓練模型并保存訓練好的權重。

epochs = 100
for epoch in range(epochs):# 前向傳播outputs = model(X_train)loss = criterion(outputs, y_train)# 反向傳播optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 20 == 0:print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")# 保存模型
torch.save(model.state_dict(), 'logistic_regression.pth')
6. 加載模型并進行預測

加載保存的模型并進行預測。

# 加載模型
model = LogisticRegressionModel(input_dim=2)
model.load_state_dict(torch.load('logistic_regression.pth'))
model.eval()  # 設為評估模式# 預測
with torch.no_grad():y_pred = model(X_train)_, predicted = torch.max(y_pred, 1)
7. 計算精確度、召回率和 F1 分數

使用 sklearn 中的評估函數計算精確度、召回率和 F1 分數。

accuracy = accuracy_score(y_train, predicted)
recall = recall_score(y_train, predicted)
f1 = f1_score(y_train, predicted)print(f"Accuracy: {accuracy:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
8. 總結

這篇博客展示了如何使用 Numpy 自定義數據集,利用 Pytorch 框架實現邏輯回歸模型,并進行訓練。訓練后的模型被保存,并在加載后進行預測,最后計算了精確度、召回率和 F1 分數。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/894592.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/894592.shtml
英文地址,請注明出處:http://en.pswp.cn/news/894592.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Unity-編譯構建Android的問題記錄

文章目錄 報錯:AAPT2 aapt2-4.1.2-6503028-osx Daemon #0 Failed to shutdown within timeout報錯信息解讀:原因分析最終處理方法 報錯:AAPT2 aapt2-4.1.2-6503028-osx Daemon #0 Failed to shutdown within timeout 報錯信息解讀&#xff1…

【axios二次封裝】

axios二次封裝 安裝封裝使用 安裝 pnpm add axios封裝 // 進行axios二次封裝:使用請求與響應攔截器 import axios from axios import { ElMessage } from element-plus//創建axios實例 const request axios.create({baseURL: import.meta.env.VITE_APP_BASE_API,…

SQL進階實戰技巧:如何構建用戶行為轉移概率矩陣,深入洞察會話內活動流轉?

目錄 1 場景描述 1.1 用戶行為轉移概率矩陣概念 1.2 用戶行為轉移概率矩陣構建方法 (1) 數據收集

Vue3.0實戰:大數據平臺可視化(附完整項目源碼)

文章目錄 創建vue3.0項目項目初始化項目分辨率響應式設置項目頂部信息條創建頁面主體創建全局引入echarts和axios后臺接口創建express銷售總量圖實現完整項目下載項目任何問題都可在評論區,或者直接私信即可。 創建vue3.0項目 創建項目: vue create vueecharts選擇第三項:…

Java自定義IO密集型和CPU密集型線程池

文章目錄 前言線程池各類場景描述常見場景案例設計思路公共類自定義工廠類-MyThreadFactory自定義拒絕策略-RejectedExecutionHandlerFactory自定義阻塞隊列-TaskQueue(實現 核心線程->最大線程數->隊列) 場景1:CPU密集型場景思路&…

【VM】VirtualBox安裝ubuntu22.04虛擬機

閱讀本文之前,請先根據 安裝virtualbox 教程安裝virtulbox虛擬機軟件。 1.下載Ubuntu系統鏡像 打開阿里云的鏡像站點:https://developer.aliyun.com/mirror/ 找到如圖所示位置,選擇Ubuntu 22.04.3(destop-amd64)系統 Ubuntu 22.04.3(desto…

Pandas基礎08(分箱操作/時間序列/畫圖)

3.8.1 Pandas分箱操作 數據分箱(Binning) 是一種數據預處理方法,用于將連續型變量的數值范圍分割成若干個區間或“箱”(bins),將數據按照這些區間進行分類,從而轉換為離散型變量。這種方法常用…

C#,shell32 + 調用控制面板項(.Cpl)實現“新建快捷方式對話框”(全網首發)

Made By 于子軒,2025.2.2 不管是使用System.IO命名空間下的File類來創建快捷方式文件,或是使用Windows Script Host對象創建快捷方式,亦或是使用Shell32對象創建快捷方式,都對用戶很不友好,今天小編為大家帶來一種全新…

國產編輯器EverEdit - 輸出窗口

1 輸出窗口 1.1 應用場景 輸出窗口可以顯示用戶執行某些操作的結果,主要包括: 查找類:查找全部,篩選等待操作,可以把查找結果打印到輸出窗口中; 程序類:在執行外部程序時(如:命令窗…

Vue-data數據

目錄 一、Vue中的data數據是什么?二、data支持的數據類型有哪些? 一、Vue中的data數據是什么? Vue中用到的數據定義在data中。 二、data支持的數據類型有哪些? data中可以寫復雜類型的數據,渲染復雜類型數據時只要遵…

02.03 遞歸運算

使用遞歸求出 1 1/3 -1/5 1/7 - 1/9 ... 1/n的值。 1>程序代碼 #include <stdio.h> #include <string.h> #include <unistd.h> #include <stdlib.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #inc…

數據分析系列--⑥RapidMiner構建決策樹(泰坦尼克號案例含數據)

一、資源下載 二、數據處理 1.導入數據 2.數據預處理 三、構建模型 1.構建決策樹 2.劃分訓練集和測試集 3.應用模型 4.結果分析 一、資源下載 點擊下載數據集 二、數據處理 1.導入數據 2.數據預處理 三、構建模型 1.構建決策樹 雖然決策樹已經構建,但對于大多數初學者或…

高階開發基礎——快速入門C++并發編程6——大作業:實現一個超級迷你的線程池

目錄 實現一個無返回的線程池 完全代碼實現 Reference 實現一個無返回的線程池 實現一個簡單的線程池非常簡單&#xff0c;我們首先聊一聊線程池的定義&#xff1a; 線程池&#xff08;Thread Pool&#xff09; 是一種并發編程的設計模式&#xff0c;用于管理和復用多個線程…

pytorch實現主成分分析 (PCA):用于數據降維和特征提取

人工智能例子匯總&#xff1a;AI常見的算法和例子-CSDN博客 使用 PyTorch 實現主成分分析&#xff08;PCA&#xff09;可以通過以下步驟進行&#xff1a; 標準化數據&#xff1a;首先&#xff0c;需要對數據進行標準化處理&#xff0c;確保每個特征的均值為 0&#xff0c;方差…

100 ,【8】 buuctf web [藍帽杯 2021]One Pointer PHP(別看)

進入靶場 沒提示&#xff0c;去看源代碼。 user.php <?php // 定義一個名為 User 的類&#xff0c;該類可用于表示用戶相關信息或執行與用戶有關的操作 class User{// 聲明一個公共屬性 $count&#xff0c;可在類的內部和外部直接訪問// 這個屬性可能用于記錄與用戶相關…

巧妙利用數據結構優化部門查詢

目錄 一、出現的問題 部門樹接口超時 二、問題分析 源代碼分析 三、解決方案 具體實現思路 四、優化的效果 一、出現的問題 部門樹接口超時 無論是在A項目還是在B項目中&#xff0c;都存在類似的頁面&#xff0c;其實就是一個部門列表或者叫組織列表。 從頁面的展示形式…

QT簡單實現驗證碼(字符)

0&#xff09; 運行結果 1&#xff09; 生成隨機字符串 Qt主要通過QRandomGenerator類來生成隨機數。在此之前的版本中&#xff0c;qrand()函數也常被使用&#xff0c;但從Qt 5.10起&#xff0c;推薦使用更現代化的QRandomGenerator類。 在頭文件添加void generateRandomNumb…

JavaFX - 3D 形狀

在前面的章節中&#xff0c;我們已經了解了如何在 JavaFX 應用程序中的 XY 平面上繪制 2D 形狀。除了這些 2D 形狀之外&#xff0c;我們還可以使用 JavaFX 繪制其他幾個 3D 形狀。 通常&#xff0c;3D 形狀是可以在 XYZ 平面上繪制的幾何圖形。它們由兩個或多個維度定義&#…

深入理解開放尋址法中的三種探測序列

一、引言 開放尋址法是解決散列表中沖突的一種重要方法&#xff0c;當發生沖突&#xff08;即兩個不同的鍵通過散列函數計算得到相同的散列值&#xff09;時&#xff0c;它會在散列表中尋找下一個可用的存儲位置。而探測序列就是用于確定在發生沖突后&#xff0c;依次嘗試哪些…

【雙指針題目】

雙指針 美麗區間&#xff08;滑動窗口&#xff09;合并數列&#xff08;雙指針的應用&#xff09;等腰三角形全部所有的子序列 美麗區間&#xff08;滑動窗口&#xff09; 美麗區間 滑動窗口模板&#xff1a; int left 0, right 0;while (right < nums.size()) {// 增大…