第R7周:糖尿病預測模型優化探索

文章目錄

  • 1.數據預處理
    • 1.1 設置GPU
    • 1.2 數據導入
    • 1.3 數據檢查
  • 2. 數據分析
    • 2.1 數據分布分析
    • 2.2 相關性分析
  • 3. LSTM模型
    • 3.1 劃分數據集
    • 3.2 數據集構建
    • 3.3 定義模型
  • 4. 訓練模型
    • 4.1 定義訓練函數
    • 4.2 定義測試函數
    • 4.3 訓練模型
  • 5. 模型評估
    • 5.1 Loss與Accuracy圖
  • 6. 總結

  • 🍨 本文為🔗365天深度學習訓練營 中的學習記錄博客
  • 🍖 原作者:K同學啊

1.數據預處理

1.1 設置GPU

import torch.nn as nn
import torch.nn.functional as F
import torchvision,torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type=‘cpu’)

1.2 數據導入

import numpy             as np
import pandas            as pd
import seaborn           as sns
from sklearn.model_selection   import train_test_split
import matplotlib.pyplot as plt
plt.rcParams['savefig.dpi'] = 500 #圖片像素
plt.rcParams['figure.dpi']  = 500 #分辨率plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標簽import warnings 
warnings.filterwarnings("ignore")DataFrame=pd.read_excel('dia.xls')
DataFrame.head()
卡號性別年齡高密度脂蛋白膽固醇低密度脂蛋白膽固醇極低密度脂蛋白膽固醇甘油三酯總膽固醇脈搏舒張壓高血壓史尿素氮尿酸肌酐體重檢查結果是否糖尿病
0180544210381.252.991.070.645.31838304.99243.35010
1180544220311.151.990.840.503.98856304.72391.04710
2180544230271.292.210.690.604.19736105.87325.75110
3180544240330.932.010.660.843.60836002.40203.24020
4180544250361.172.830.830.734.83856704.09236.84300
DataFrame.shape

(1006, 16)

1.3 數據檢查

# 查看數據是否有缺失值
print('數據缺失值---------------------------------')
print(DataFrame.isnull().sum())

數據缺失值---------------------------------
卡號 0
性別 0
年齡 0
高密度脂蛋白膽固醇 0
低密度脂蛋白膽固醇 0
極低密度脂蛋白膽固醇 0
甘油三酯 0
總膽固醇 0
脈搏 0
舒張壓 0
高血壓史 0
尿素氮 0
尿酸 0
肌酐 0
體重檢查結果 0
是否糖尿病 0
dtype: int64

# 查看數據是否有重復值
print('數據重復值---------------------------------')
print('數據集的重復值為:'f'{DataFrame.duplicated().sum()}')

數據重復值---------------------------------
數據集的重復值為:0

2. 數據分析

2.1 數據分布分析

feature_map = {'年齡': '年齡','高密度脂蛋白膽固醇': '高密度脂蛋白膽固醇','低密度脂蛋白膽固醇': '低密度脂蛋白膽固醇','極低密度脂蛋白膽固醇': '極低密度脂蛋白膽固醇','甘油三酯': '甘油三酯','總膽固醇': '總膽固醇','脈搏': '脈搏','舒張壓':'舒張壓','高血壓史':'高血壓史','尿素氮':'尿素氮','尿酸':'尿酸','肌酐':'肌酐','體重檢查結果':'體重檢查結果'
}
plt.figure(figsize=(15, 10))for i, (col, col_name) in enumerate(feature_map.items(), 1):plt.subplot(3, 5, i)sns.boxplot(x=DataFrame['是否糖尿病'], y=DataFrame[col])plt.title(f'{col_name}的箱線圖', fontsize=14)plt.ylabel('數值', fontsize=12)plt.grid(axis='y', linestyle='--', alpha=0.7)plt.tight_layout()
plt.show()

在這里插入圖片描述

2.2 相關性分析

import plotly
import plotly.express as px# 刪除列 '卡號'
DataFrame.drop(columns=['卡號'], inplace=True)
# 計算各列之間的相關系數
df_corr = DataFrame.corr()# 相關矩陣生成函數
def corr_generate(df):fig = px.imshow(df,text_auto=True,aspect="auto",color_continuous_scale='RdBu_r')fig.show()# 生成相關矩陣
corr_generate(df_corr)

3. LSTM模型

3.1 劃分數據集

from sklearn.preprocessing import StandardScaler# '高密度脂蛋白膽固醇'字段與糖尿病負相關,故而在 X 中去掉該字段
X = DataFrame.drop(['是否糖尿病','高密度脂蛋白膽固醇'],axis=1)
y = DataFrame['是否糖尿病']# 數據集標準化處理
sc_X    = StandardScaler()
X = sc_X.fit_transform(X)X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2,random_state=1)
# 維度擴增使其符合LSTM模型可接受shape
train_X = train_X.unsqueeze(1)
test_X  = test_X.unsqueeze(1)
train_X.shape, train_y.shape

(torch.Size([804, 1, 13]), torch.Size([804]))

3.2 數據集構建

from torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(train_X, train_y),batch_size=64, shuffle=False)test_dl  = DataLoader(TensorDataset(test_X, test_y),batch_size=64, shuffle=False)

3.3 定義模型

class model_lstm(nn.Module):def __init__(self):super(model_lstm, self).__init__()self.lstm0 = nn.LSTM(input_size=13 ,hidden_size=200, num_layers=1, batch_first=True)self.lstm1 = nn.LSTM(input_size=200 ,hidden_size=200, num_layers=1, batch_first=True)self.fc0   = nn.Linear(200, 2)def forward(self, x):out, hidden1 = self.lstm0(x) out, _ = self.lstm1(out, hidden1) out    = out[:, -1, :]  # 只取最后一個時間步的輸出out    = self.fc0(out) return out   model = model_lstm().to(device)
model

model_lstm(
(lstm0): LSTM(13, 200, batch_first=True)
(lstm1): LSTM(200, 200, batch_first=True)
(fc0): Linear(in_features=200, out_features=2, bias=True)
)

4. 訓練模型

4.1 定義訓練函數

# 訓練循環
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 訓練集的大小num_batches = len(dataloader)   # 批次數目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0  # 初始化訓練損失和正確率for X, y in dataloader:  # 獲取圖片及其標簽X, y = X.to(device), y.to(device)# 計算預測誤差pred = model(X)          # 網絡輸出loss = loss_fn(pred, y)  # 計算網絡輸出和真實值之間的差距,targets為真實值,計算二者差值即為損失# 反向傳播optimizer.zero_grad()  # grad屬性歸零loss.backward()        # 反向傳播optimizer.step()       # 每一步自動更新# 記錄acc與losstrain_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc  /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

4.2 定義測試函數

def test (dataloader, model, loss_fn):size        = len(dataloader.dataset)  # 測試集的大小num_batches = len(dataloader)          # 批次數目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 當不進行訓練時,停止梯度更新,節省計算內存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 計算losstarget_pred = model(imgs)loss        = loss_fn(target_pred, target)test_loss += loss.item()test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc  /= sizetest_loss /= num_batchesreturn test_acc, test_loss

4.3 訓練模型

loss_fn    = nn.CrossEntropyLoss() # 創建損失函數
learn_rate = 1e-4   # 學習率
opt        = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs     = 30train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 獲取當前的學習率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))print("="*20, 'Done', "="*20)

Epoch: 1, Train_acc:43.8%, Train_loss:0.693, Test_acc:48.0%, Test_loss:0.682, Lr:1.00E-04
Epoch: 2, Train_acc:52.6%, Train_loss:0.684, Test_acc:62.4%, Test_loss:0.676, Lr:1.00E-04
Epoch: 3, Train_acc:68.4%, Train_loss:0.674, Test_acc:69.8%, Test_loss:0.669, Lr:1.00E-04
Epoch: 4, Train_acc:72.9%, Train_loss:0.662, Test_acc:73.8%, Test_loss:0.661, Lr:1.00E-04
Epoch: 5, Train_acc:76.1%, Train_loss:0.648, Test_acc:74.3%, Test_loss:0.651, Lr:1.00E-04
Epoch: 6, Train_acc:76.4%, Train_loss:0.631, Test_acc:73.8%, Test_loss:0.639, Lr:1.00E-04
Epoch: 7, Train_acc:76.1%, Train_loss:0.611, Test_acc:74.3%, Test_loss:0.625, Lr:1.00E-04
Epoch: 8, Train_acc:76.0%, Train_loss:0.588, Test_acc:75.2%, Test_loss:0.610, Lr:1.00E-04
Epoch: 9, Train_acc:75.0%, Train_loss:0.564, Test_acc:75.2%, Test_loss:0.595, Lr:1.00E-04
Epoch:10, Train_acc:75.0%, Train_loss:0.541, Test_acc:75.2%, Test_loss:0.581, Lr:1.00E-04
Epoch:11, Train_acc:75.2%, Train_loss:0.521, Test_acc:75.2%, Test_loss:0.569, Lr:1.00E-04
Epoch:12, Train_acc:75.7%, Train_loss:0.504, Test_acc:75.7%, Test_loss:0.559, Lr:1.00E-04
Epoch:13, Train_acc:75.7%, Train_loss:0.491, Test_acc:75.7%, Test_loss:0.550, Lr:1.00E-04
Epoch:14, Train_acc:75.6%, Train_loss:0.480, Test_acc:76.7%, Test_loss:0.543, Lr:1.00E-04
Epoch:15, Train_acc:75.7%, Train_loss:0.472, Test_acc:76.2%, Test_loss:0.535, Lr:1.00E-04
Epoch:16, Train_acc:76.7%, Train_loss:0.465, Test_acc:76.2%, Test_loss:0.529, Lr:1.00E-04
Epoch:17, Train_acc:77.4%, Train_loss:0.459, Test_acc:76.7%, Test_loss:0.522, Lr:1.00E-04
Epoch:18, Train_acc:77.9%, Train_loss:0.454, Test_acc:77.2%, Test_loss:0.516, Lr:1.00E-04
Epoch:19, Train_acc:78.4%, Train_loss:0.450, Test_acc:77.7%, Test_loss:0.511, Lr:1.00E-04
Epoch:20, Train_acc:78.2%, Train_loss:0.446, Test_acc:77.2%, Test_loss:0.506, Lr:1.00E-04
Epoch:21, Train_acc:78.2%, Train_loss:0.442, Test_acc:77.2%, Test_loss:0.501, Lr:1.00E-04
Epoch:22, Train_acc:78.6%, Train_loss:0.439, Test_acc:77.2%, Test_loss:0.496, Lr:1.00E-04
Epoch:23, Train_acc:78.9%, Train_loss:0.436, Test_acc:77.2%, Test_loss:0.492, Lr:1.00E-04
Epoch:24, Train_acc:78.9%, Train_loss:0.433, Test_acc:77.7%, Test_loss:0.488, Lr:1.00E-04
Epoch:25, Train_acc:79.2%, Train_loss:0.430, Test_acc:77.7%, Test_loss:0.484, Lr:1.00E-04
Epoch:26, Train_acc:79.2%, Train_loss:0.427, Test_acc:78.2%, Test_loss:0.481, Lr:1.00E-04
Epoch:27, Train_acc:79.4%, Train_loss:0.425, Test_acc:79.2%, Test_loss:0.477, Lr:1.00E-04
Epoch:28, Train_acc:79.4%, Train_loss:0.423, Test_acc:79.2%, Test_loss:0.474, Lr:1.00E-04
Epoch:29, Train_acc:79.5%, Train_loss:0.421, Test_acc:79.2%, Test_loss:0.471, Lr:1.00E-04
Epoch:30, Train_acc:79.6%, Train_loss:0.418, Test_acc:79.7%, Test_loss:0.467, Lr:1.00E-04
==================== Done ====================

5. 模型評估

5.1 Loss與Accuracy圖

import matplotlib.pyplot as plt
#隱藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False      # 用來正常顯示負號
plt.rcParams['figure.dpi']         = 100        #分辨率from datetime import datetime
current_time = datetime.now() # 獲取當前時間epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡請帶上時間戳,否則代碼截圖無效plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在這里插入圖片描述

6. 總結

本周主要實現了實現了對于上一次糖尿病預測模型的優化。通過實踐,更加深入地了解了LSTM模型的相關優化。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/81042.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/81042.shtml
英文地址,請注明出處:http://en.pswp.cn/web/81042.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

一些好用的Chrome 擴展程序

以下是按主要功能分類的 Chrome 擴展程序列表,包括其版本號、中文功能簡述以及指向其主頁或 Chrome 網上應用店頁面的鏈接。 翻譯與語言 沉浸式翻譯 - 網頁翻譯插件 | PDF 翻譯 | 免費 版本: 1.16.12 描述: 【沉浸式翻譯】免費的(原文 / 譯文&#xff0…

貪心算法題目合集2

貪心算法題目合集2 一般排序排隊接水整數區間金銀島尋找平面上的極大點NOIP 2008 普及組 排座椅 推導排序規律NOIP 1998 提高組 拼數排序規則的正確性證明:全序關系證明拼數的貪心策略正確P2878 [USACO07JAN] Protecting the Flowers SP1842 [USACO05NOV] 奶牛玩雜技…

全方位詳解微服務架構中的Service Mesh(服務網格)

一、引言 隨著微服務架構的廣泛應用,微服務之間的通信管理、流量控制、安全保障等問題變得日益復雜。服務網格(Service Mesh)作為一種新興的技術,為解決這些問題提供了有效的方案。它將服務間通信的管理從微服務代碼中分離出來&a…

如何在VSCode中更換默認瀏覽器:完整指南

引言 作為前端開發者,我們經常需要在VSCode中快速預覽HTML文件。默認情況下,VSCode會使用系統默認瀏覽器打開文件,但有時我們可能需要切換到其他瀏覽器進行測試。本文將詳細介紹如何在VSCode中更換默認瀏覽器。 方法一:使用VSCo…

【普及+/提高】洛谷P2613 【模板】有理數取余——快讀+快速冪

題目來源 P2613 【模板】有理數取余 - 洛谷 題目描述 給出一個有理數 cba?,求 cmod19260817 的值。 這個值被定義為 bx≡a(mod19260817) 的解。 輸入格式 一共兩行。 第一行,一個整數 a。 第二行,一個整數 b。 輸出格式 一個整數&a…

從編程助手到AI工程師:Trae插件Builder模式實戰Excel合并工具開發

Trae插件下載鏈接:https://www.trae.com.cn/plugin 引言:AI編程工具的新紀元 在軟件開發領域,AI輔助編程正在經歷一場革命性的變革。Trae插件(原MarsCode編程助手)最新推出的Builder模式,標志著AI編程工具…

Python set集合方法詳解

""" set()函數是個無序的去重集合,可以用來過濾重復元素 Python 提供了 2 種創建 set 集合的方法,分別是使用 {} 創建和使用 set() 函數將列表、元組等類型數據轉換為集合 """# 空集合 s0 set() # 正確方式 →…

各類Agent技術的發展現狀和核心痛點

AI Agent主要分類 Agent(智能體)技術是指具有自主感知、決策與執行能力的軟件系統,能夠在環境中完成特定任務。目前常見的Agent類型主要包括: - 基于大模型的智能體:以GPT-4等大型語言模型為核心,如AutoGP…

單片機-STM32部分:18、WiFi模組

飛書文檔https://x509p6c8to.feishu.cn/wiki/WFmqwImDViDUezkF7ercZuNDnve 一、WiFi模組應用 當設備需要連接網絡,實現遠程控制,狀態監控時,就需要添加通信模組,常見的通信模組WiFi模組、2G模組、4G模組等: 我們的板卡…

探索Qwen2ForCausalLM 架構上進行微調

簡述 試驗參考了mini_qwen 的開源實現 GitHub - qiufengqijun/mini_qwen: 這是一個從頭訓練大語言模型的項目,包括預訓練、微調和直接偏好優化,模型擁有1B參數,支持中英文。這是一個從頭訓練大語言模型的項目,包括預訓練、微調和…

hysAnalyser特色的TS流編輯、剪輯和轉存MP4功能說明

摘要 hysAnalyser 是一款特色的 MPEG-TS 數據分析工具,融合了常規TS文件的剪輯,轉存功能,可用于平常的視頻開發和測試。 本文詳細闡述了對MPEG-TS 流的節目ID,名稱,PID,時間戳,流類型&#xff…

前端[插件化]設計思想_Vue、React、Webpack、Vite、Element Plus、Ant Design

前端插件化設計思想旨在提升應用的可擴展性、可維護性和模塊化程度。這種思想不僅體現在框架(如 Vue、React)中,也廣泛應用于構建工具(如 Webpack、Vite)以及 UI 庫(如 Element Plus、Ant Design&#xff0…

2025年高防IP與游戲盾深度對比:如何選擇最佳防護方案?

2025年,隨著DDoS攻擊規模的指數級增長和混合攻擊的常態化,高防IP與游戲盾成為企業網絡安全的核心選擇。然而,兩者在功能定位、技術實現及適用場景上存在顯著差異。本文結合最新行業實踐與技術趨勢,全面解析兩者的優劣,…

日志根因分析:Elastic Observability 的異常檢測與日志分類功能

作者:來自 Elastic Bahubali Shetti Elastic Observability 不僅提供日志聚合、指標分析、APM 和分布式追蹤,Elastic 的機器學習能力還能幫助分析問題的根因,讓你將時間專注于最重要的任務。 隨著越來越多的應用程序遷移到云端,收…

Linux火墻管理及優化

網絡環境配置 使用3個新的虛擬機【配置好軟件倉庫和網絡的】 F1 192.168.150.133 NAT F2 192.168.150.134 192.168.10.20 NAT HOST-ONLY 網絡適配僅主機 F3 192.168.10.30 HOST-ONLY 網絡適配僅主機 1 ~]# hostnamectl hostname double1.timinglee.org 【更…

java配置webSocket、前端使用uniapp連接

一、這個管理系統是基于若依框架&#xff0c;配置webSocKet的maven依賴 <!--websocket--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId></dependency> 二、配…

基于Yolov8+PyQT的老人摔倒識別系統源碼

概述 ??基于Yolov8PyQT的老人摔倒識別系統??&#xff0c;該系統通過深度學習算法實時檢測人體姿態&#xff0c;精準識別站立、摔倒中等3種狀態&#xff0c;為家庭或養老機構提供及時預警功能。 主要內容 ??完整可運行代碼?? 項目采用Yolov8目標檢測框架結合PyQT5開發…

Oracle 創建外部表

找別人要一下數據&#xff0c;但是他發來一個 xxx.csv 文件&#xff0c;怎么辦&#xff1f; 1、使用視圖化工具導入 使用導入工具導入&#xff0c;如 DBeaver&#xff0c;右擊要導入的表&#xff0c;選擇導入數據。 選擇對應的 csv 文件&#xff0c;下一步就行了&#xff08;如…

【華為OD- B卷 01 - 傳遞悄悄話 100分(python、java、c、c++、js)】

【華為OD- B卷 01 - 傳遞悄悄話 100分(python、java、c、c++、js)】 題目 給定一個二叉樹,每個節點上站一個人,節點數字表示父節點到該節點傳遞悄悄話需要花費的時間。 初始時,根節點所在位置的人有一個悄悄話想要傳遞給其他人,求二叉樹所有節點上的人都接收到悄悄話花…

房貸利率計算前端小程序

利率計算前端小程序 視圖效果展示如下&#xff1a; 在這里插入代碼片 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0&qu…