第4天:RNN應用(心臟病預測)

  • 🍨 本文為🔗365天深度學習訓練營 中的學習記錄博客
  • 🍖 原作者:K同學啊

目標

具體實現

(一)環境

語言環境:Python 3.10
編 譯 器: PyCharm
框 架: Pytorch

(二)具體步驟
1. 代碼
import numpy as np  
import pandas as pd  
import torch  
from torch import nn  
import torch.nn.functional as F  
import seaborn as sns  # 設置GPU  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
print("設備:", device)  # 導入數據  
df = pd.read_csv("./data/heart.csv")  
print(df)  # 構建數據集  
# 標準化  
from sklearn.preprocessing import StandardScaler  
from sklearn.model_selection import train_test_split  X = df.iloc[:, :-1]  
y = df.iloc[:, -1]  # 將第一列特征標準化為標準正態分布,注意,標準化是針對第一列而言的。  
sc = StandardScaler()  
X = sc.fit_transform(X)  # 劃分數據集  
X = torch.tensor(np.array(X), dtype=torch.float32)  
y = torch.tensor(np.array(y), dtype=torch.int64)  
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1)  X_train = X_train.unsqueeze(1)  
X_test = X_test.unsqueeze(1)  
print("訓練集大小:", X_train.shape, y_train.shape)  from torch.utils.data import TensorDataset, DataLoader  train_dl = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=False)  
test_dl = DataLoader(TensorDataset(X_test, y_test), batch_size=64, shuffle=False)  # 構建模型  
"""  
RNN模型類,用于心臟病預測  Attributes:  rnn0: RNN層,處理序列數據  fc0: 全連接層,將RNN輸出降維到50維  fc1: 最終全連接層,輸出2分類結果  
"""  
class model_rnn(nn.Module):  def __init__(self):  super(model_rnn, self).__init__()  self.rnn0 = nn.RNN(  input_size=13,  hidden_size=200,  num_layers=1,  batch_first=True,  )  self.fc0 = nn.Linear(200, 50)  self.fc1 = nn.Linear(50, 2)  """  前向傳播函數  Args:        x: 輸入張量,形狀為(batch_size, seq_len, input_size)  Returns:        輸出張量,形狀為(batch_size, 2),表示兩個類別的得分  """    def forward(self, x):  out, _ = self.rnn0(x)  # 取最后一個時間步的輸出作為特征  out = out[:, -1, :]  out = self.fc0(out)  out = self.fc1(out)  return out  model = model_rnn().to(device)  
print(model)  print(model(torch.rand(30, 1, 13).to(device)).shape)  #  訓練  
"""  
訓練函數,執行一個epoch的訓練  Args:  dataloader: 數據加載器  model: 神經網絡模型  loss_fn: 損失函數  optimizer: 優化器  Returns:  平均訓練準確率和損失值  
"""  
def train(dataloader, model, loss_fn, optimizer):  size = len(dataloader.dataset)  num_batches = len(dataloader)  train_loss, train_acc = 0, 0  for X, y in dataloader:  X, y = X.to(device), y.to(device)  pred = model(X)  loss = loss_fn(pred, y)  optimizer.zero_grad()  loss.backward()  optimizer.step()  train_acc +=  (pred.argmax(1) == y).type(torch.float).sum().item()  train_loss += loss.item()  train_acc /= size  train_loss /= num_batches  return  train_acc,train_loss  """  
測試函數,評估模型性能  Args:  dataloader: 數據加載器  model: 神經網絡模型  loss_fn: 損失函數  Returns:  平均測試準確率和損失值  
"""  
def test(dataloader, model, loss_fn):  size = len(dataloader.dataset)  num_batches = len(dataloader)  test_loss, test_acc = 0, 0  with torch.no_grad():  for imgs, target in dataloader:  imgs, target = imgs.to(device), target.to(device)  target_pred = model(imgs)  loss = loss_fn(target_pred, target)  test_loss += loss.item()  test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  test_acc /= size  test_loss /= num_batches  return test_acc, test_loss  loss_fn = nn.CrossEntropyLoss()  
learning_rate = 1e-4  
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)  
epochs = 50  train_loss = []  
train_acc = []  
test_loss = []  
test_acc = []  for epoch in range(epochs):  model.train()  epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)  model.eval()  epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)  train_acc.append(epoch_train_acc)  train_loss.append(epoch_train_loss)  test_acc.append(epoch_test_acc)  test_loss.append(epoch_test_loss)  lr = opt.state_dict()['param_groups'][0]['lr']  template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')  print(template.format(epoch, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss, lr))  print("="*20, 'Done', "="*20)  # 評估模型  
import matplotlib.pyplot as plt  
from datetime import datetime  import warnings  
warnings.filterwarnings("ignore")  current_time = datetime.now()  plt.rcParams['font.sans-serif'] = ['SimHei']  
plt.rcParams['axes.unicode_minus'] = False  
plt.rcParams['figure.dpi'] = 200  epochs_range = range(epochs)  # 可視化訓練過程中的準確率和損失曲線  
# 包含訓練集和測試集的對比曲線  
plt.figure(figsize=(12, 3))  
plt.subplot(1, 2, 1)  
plt.plot(epochs_range, train_acc, label='訓練正確率')  
plt.plot(epochs_range, test_acc, label='測試正確率')  
plt.legend(loc='lower right')  
plt.title('訓練和測試正確率')  
plt.xlabel(current_time)  plt.subplot(1, 2, 2)  
plt.plot(epochs_range, train_loss, label='訓練損失')  
plt.plot(epochs_range, test_loss, label='測試損失')  
plt.legend(loc='upper right')  
plt.title('訓練和測試損失')  
plt.show()  # 混淆矩陣  
print("====================輸入數據shape================")  
print("X_test.shape:", X_test.shape)  
print("y_test.shape:", y_test.shape)  
pred = model(X_test.to(device)).argmax(1).cpu().numpy()  
print("\n===================輸出數據shape===============")  
print("pred.shape:", pred.shape)  from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay  
#計算混淆矩陣  
# 生成混淆矩陣并可視化  
# 顯示分類結果的混淆矩陣熱力圖  
cm = confusion_matrix(y_test, pred)  plt.figure(figsize=(6, 5))  
plt.suptitle("")  
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')  plt.xticks(fontsize=10)  
plt.yticks(fontsize=10)  
plt.title("Confusion Matrix", fontsize=12)  
plt.xlabel("Predicted Label", fontsize=10)  
plt.ylabel("True Label", fontsize=10)  plt.tight_layout()  
plt.show()  # 心臟病預測  
test_X = X_test[0].unsqueeze(1)  
pred = model(test_X.to(device)).argmax(1).item()  
print("預測結果為:", pred)  
print("=="*20)  
print("0:不會患心臟病")  
print("1:會患心臟病")

結果:

設備: cudaage  sex  cp  trestbps  chol  fbs  ...  exang  oldpeak  slope  ca  thal  target
0     63    1   3       145   233    1  ...      0      2.3      0   0     1       1
1     37    1   2       130   250    0  ...      0      3.5      0   0     2       1
2     41    0   1       130   204    0  ...      0      1.4      2   0     2       1
3     56    1   1       120   236    0  ...      0      0.8      2   0     2       1
4     57    0   0       120   354    0  ...      1      0.6      2   0     2       1
..   ...  ...  ..       ...   ...  ...  ...    ...      ...    ...  ..   ...     ...
298   57    0   0       140   241    0  ...      1      0.2      1   0     3       0
299   45    1   3       110   264    0  ...      0      1.2      1   0     3       0
300   68    1   0       144   193    1  ...      0      3.4      1   2     3       0
301   57    1   0       130   131    0  ...      1      1.2      1   1     3       0
302   57    0   1       130   236    0  ...      0      0.0      1   1     2       0[303 rows x 14 columns]
訓練集大小: torch.Size([272, 1, 13]) torch.Size([272])
model_rnn((rnn0): RNN(13, 200, batch_first=True)(fc0): Linear(in_features=200, out_features=50, bias=True)(fc1): Linear(in_features=50, out_features=2, bias=True)
)
torch.Size([30, 2])
Epoch: 0, Train_acc:41.2%, Train_loss:0.700, Test_acc:45.2%, Test_loss:0.700, Lr:1.00E-04
Epoch: 1, Train_acc:55.9%, Train_loss:0.688, Test_acc:67.7%, Test_loss:0.682, Lr:1.00E-04
Epoch: 2, Train_acc:66.9%, Train_loss:0.676, Test_acc:80.6%, Test_loss:0.664, Lr:1.00E-04
Epoch: 3, Train_acc:76.5%, Train_loss:0.664, Test_acc:87.1%, Test_loss:0.648, Lr:1.00E-04
Epoch: 4, Train_acc:77.6%, Train_loss:0.653, Test_acc:90.3%, Test_loss:0.631, Lr:1.00E-04
Epoch: 5, Train_acc:78.7%, Train_loss:0.642, Test_acc:90.3%, Test_loss:0.615, Lr:1.00E-04
Epoch: 6, Train_acc:79.4%, Train_loss:0.631, Test_acc:90.3%, Test_loss:0.599, Lr:1.00E-04
Epoch: 7, Train_acc:80.9%, Train_loss:0.620, Test_acc:90.3%, Test_loss:0.583, Lr:1.00E-04
Epoch: 8, Train_acc:81.6%, Train_loss:0.609, Test_acc:90.3%, Test_loss:0.567, Lr:1.00E-04
Epoch: 9, Train_acc:82.4%, Train_loss:0.598, Test_acc:90.3%, Test_loss:0.552, Lr:1.00E-04
Epoch:10, Train_acc:81.6%, Train_loss:0.587, Test_acc:90.3%, Test_loss:0.536, Lr:1.00E-04
Epoch:11, Train_acc:80.9%, Train_loss:0.576, Test_acc:90.3%, Test_loss:0.520, Lr:1.00E-04
Epoch:12, Train_acc:81.2%, Train_loss:0.565, Test_acc:90.3%, Test_loss:0.504, Lr:1.00E-04
Epoch:13, Train_acc:80.5%, Train_loss:0.554, Test_acc:90.3%, Test_loss:0.489, Lr:1.00E-04
Epoch:14, Train_acc:81.2%, Train_loss:0.542, Test_acc:90.3%, Test_loss:0.473, Lr:1.00E-04
Epoch:15, Train_acc:80.9%, Train_loss:0.531, Test_acc:90.3%, Test_loss:0.458, Lr:1.00E-04
Epoch:16, Train_acc:80.9%, Train_loss:0.520, Test_acc:90.3%, Test_loss:0.443, Lr:1.00E-04
Epoch:17, Train_acc:80.9%, Train_loss:0.509, Test_acc:90.3%, Test_loss:0.428, Lr:1.00E-04
Epoch:18, Train_acc:80.9%, Train_loss:0.498, Test_acc:90.3%, Test_loss:0.414, Lr:1.00E-04
Epoch:19, Train_acc:81.6%, Train_loss:0.488, Test_acc:90.3%, Test_loss:0.401, Lr:1.00E-04
Epoch:20, Train_acc:81.6%, Train_loss:0.477, Test_acc:90.3%, Test_loss:0.388, Lr:1.00E-04
Epoch:21, Train_acc:81.6%, Train_loss:0.468, Test_acc:90.3%, Test_loss:0.376, Lr:1.00E-04
Epoch:22, Train_acc:81.6%, Train_loss:0.458, Test_acc:87.1%, Test_loss:0.365, Lr:1.00E-04
Epoch:23, Train_acc:81.6%, Train_loss:0.449, Test_acc:87.1%, Test_loss:0.355, Lr:1.00E-04
Epoch:24, Train_acc:82.4%, Train_loss:0.441, Test_acc:87.1%, Test_loss:0.346, Lr:1.00E-04
Epoch:25, Train_acc:82.4%, Train_loss:0.433, Test_acc:87.1%, Test_loss:0.337, Lr:1.00E-04
Epoch:26, Train_acc:82.7%, Train_loss:0.426, Test_acc:87.1%, Test_loss:0.329, Lr:1.00E-04
Epoch:27, Train_acc:82.7%, Train_loss:0.419, Test_acc:87.1%, Test_loss:0.322, Lr:1.00E-04
Epoch:28, Train_acc:83.1%, Train_loss:0.413, Test_acc:87.1%, Test_loss:0.316, Lr:1.00E-04
Epoch:29, Train_acc:83.5%, Train_loss:0.407, Test_acc:87.1%, Test_loss:0.311, Lr:1.00E-04
Epoch:30, Train_acc:83.5%, Train_loss:0.402, Test_acc:87.1%, Test_loss:0.306, Lr:1.00E-04
Epoch:31, Train_acc:83.8%, Train_loss:0.397, Test_acc:87.1%, Test_loss:0.302, Lr:1.00E-04
Epoch:32, Train_acc:84.2%, Train_loss:0.392, Test_acc:87.1%, Test_loss:0.299, Lr:1.00E-04
Epoch:33, Train_acc:84.2%, Train_loss:0.388, Test_acc:87.1%, Test_loss:0.296, Lr:1.00E-04
Epoch:34, Train_acc:84.2%, Train_loss:0.384, Test_acc:87.1%, Test_loss:0.294, Lr:1.00E-04
Epoch:35, Train_acc:84.2%, Train_loss:0.381, Test_acc:87.1%, Test_loss:0.292, Lr:1.00E-04
Epoch:36, Train_acc:84.2%, Train_loss:0.378, Test_acc:87.1%, Test_loss:0.290, Lr:1.00E-04
Epoch:37, Train_acc:83.8%, Train_loss:0.375, Test_acc:87.1%, Test_loss:0.289, Lr:1.00E-04
Epoch:38, Train_acc:83.8%, Train_loss:0.373, Test_acc:87.1%, Test_loss:0.288, Lr:1.00E-04
Epoch:39, Train_acc:83.8%, Train_loss:0.370, Test_acc:87.1%, Test_loss:0.287, Lr:1.00E-04
Epoch:40, Train_acc:83.5%, Train_loss:0.368, Test_acc:87.1%, Test_loss:0.287, Lr:1.00E-04
Epoch:41, Train_acc:83.5%, Train_loss:0.366, Test_acc:87.1%, Test_loss:0.286, Lr:1.00E-04
Epoch:42, Train_acc:83.5%, Train_loss:0.364, Test_acc:87.1%, Test_loss:0.286, Lr:1.00E-04
Epoch:43, Train_acc:83.8%, Train_loss:0.363, Test_acc:87.1%, Test_loss:0.287, Lr:1.00E-04
Epoch:44, Train_acc:83.8%, Train_loss:0.361, Test_acc:87.1%, Test_loss:0.287, Lr:1.00E-04
Epoch:45, Train_acc:83.8%, Train_loss:0.360, Test_acc:87.1%, Test_loss:0.287, Lr:1.00E-04
Epoch:46, Train_acc:83.8%, Train_loss:0.358, Test_acc:87.1%, Test_loss:0.287, Lr:1.00E-04
Epoch:47, Train_acc:84.2%, Train_loss:0.357, Test_acc:87.1%, Test_loss:0.288, Lr:1.00E-04
Epoch:48, Train_acc:84.2%, Train_loss:0.356, Test_acc:87.1%, Test_loss:0.289, Lr:1.00E-04
Epoch:49, Train_acc:84.6%, Train_loss:0.355, Test_acc:87.1%, Test_loss:0.289, Lr:1.00E-04
==================== Done ====================
====================輸入數據shape================
X_test.shape: torch.Size([31, 1, 13])
y_test.shape: torch.Size([31])===================輸出數據shape===============
pred.shape: (31,)
預測結果為: 0
========================================
0:不會患心臟病
1:會患心臟病

image.png
image.png

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/86101.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/86101.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/86101.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

STM32學習筆記:外部中斷(EXTI)原理與應用詳解

前言 在嵌入式系統開發中,中斷機制是提高系統實時性和效率的重要手段。相比傳統的51單片機,STM32微控制器提供了更為豐富和靈活的外部中斷資源。本文將全面介紹STM32的外部中斷(EXTI)功能,包括其工作原理、配置方法和實際應用技巧。 一、外…

嵌入式知識篇---Zigbee串口

在 Python 中,serial和pyserial是經常被提及的兩個庫,它們在串口通信方面有著緊密的聯系,但又存在一些差異。下面將對它們進行詳細介紹,并給出各自的適用場景。 1. 基本概念 pyserial:它是 Python 里專門用于串口通信…

vue中的派發事件與廣播事件,及廣播事件應用于哪些場景和一個表單驗證例子

在 Vue 2.X 中,$dispatch 和 $broadcast 方法已經被廢棄。官方認為基于組件樹結構的事件流方式難以理解,并且在組件結構擴展時容易變得脆弱。因此,Vue 2.X 推薦使用其他方式來實現組件間的通信,例如通過 $emit 和 $on 方法&#x…

阿里云事件總線 EventBridge 正式商業化,構建智能化時代的企業級云上事件樞紐

作者:肯夢、稚柳 產品演進歷程:在技術浪潮中的成長之路 早在 2018 年,Gartner 評估報告便將事件驅動模型(Event-Driven Model)列為十大戰略技術趨勢之一,指出事件驅動架構(EDA,Eve…

《前端面試題:BFC(塊級格式化上下文)》

前端BFC完全指南:布局魔法與面試必備 🎋 端午安康! 各位前端探險家,端午節快樂!🥮 愿你的代碼如龍舟競渡般乘風破浪,樣式如香糯粽子般完美包裹!今天我們來解鎖CSS中的布局魔法——B…

dvwa10——XSS(DOM)

XSS攻擊: DOM型XSS 只在瀏覽器前端攻擊觸發:修改url片段代碼不存儲 反射型XSS 經過服務器攻擊觸發:可能通過提交惡意表單,連接觸發代碼不存儲 存儲型XSS 經由服務器攻擊觸發:可能通過提交惡意表單,連…

跨平臺資源下載工具:res-downloader 的使用體驗

一款基于 Go Wails 的跨平臺資源下載工具,簡潔易用,支持多種資源嗅探與下載。res-downloader 一款開源免費的下載軟件(開源無毒、放心使用)!支持Win10、Win11、Mac系統.支持視頻、音頻、圖片、m3u8等網絡資源下載.支持視頻號、小程序、抖音、…

AOSP CachedAppOptimizer中的凍結和內存壓縮功能

AOSP CachedAppOptimizer:應用進程長期處于 Cached 狀態的內存壓縮和凍結優化管控 凍結和內存壓縮兩個功能獨立觸發,可以單獨觸發也可以組合觸發,默認順序:先壓縮,后凍結 public class OomAdjuster { protected b…

相機--相機成像原理和基礎概念

教程 成像原理 基礎概念 焦距(物理焦距) 鏡頭的光學中心到感光元件之間的距離,用f表示,單位:mm;。 像素焦距 相機內參矩陣中的 fx? 和 fy? 是將物理焦距轉換到像素坐標系的產物,可能不同。…

Vue3項目實現WPS文件預覽和內容回填功能

技術方案背景:根據項目需要,要實現在線查看、在線編輯文檔,并且進行內容的快速回填,根據這一項目背景,最終采用WPS的API來實現,接下來我們一起來實現項目功能。 1.首先需要先準備好測試使用的文檔&#xf…

匯編語言學習(三)——DoxBox中debug的使用

目錄 一、安裝DoxBox,并下載匯編工具(MASM文件) 二、debug是什么 三、debug中的命令 一、安裝DoxBox,并下載匯編工具(MASM文件) 鏈接: https://pan.baidu.com/s/1IbyJj-JIkl_oMOJmkKiaGQ?pw…

關于DDOS

DDOS是一門沒什么技術含量的東西,其本質而言是通過大量數據報文,發送到目標受害主機IP地址上,導致目標主機無法繼續服務(俗稱:拒絕服務) DDOS灰產人期望達成的預期目標,幾乎都是只要把對面打到 …

Modbus轉Ethernet IP網關助力羅克韋爾PLC數據交互

在工業自動化領域,Modbus協議是一種廣泛應用的串行通信協議,它定義了主站和從站之間的通信規則和數據格式。羅克韋爾PLC是一種可編程的邏輯控制器,通過Modbus協議實現與其他設備之間的數據交互。然而,隨著以太網技術的普及和發展&…

C# winform教程(二)----button

一、button的使用方法 主要使用方法幾乎都在屬性內,我們操作也在這個界面 二、作用 用戶點擊時觸發事件,事件有很多種,可以根據需要選擇。 三、常用屬性 雖然屬性很多,但是常用的并不多 3.常用屬性 名稱內容含義AutoSize自動調…

【 java 基礎問題 第二篇 】

目錄 1.深拷貝和淺拷貝 1.1.區別 定義 定義 1.2.實現深拷貝的方式 2.泛型 2.1.定義 2.2.作用 3.對象 3.1.創建對象的方式 3.2.對象回收 3.3. 獲取私有成員 4.反射 4.1.定義 4.2.特性 4.3.原理 5.異常 5.1.異常的種類 5.2.處理異常的方法 6.Object 6.1.等于與…

Kafka 入門指南與一鍵部署

Kafka 介紹 想象一下你正在運營一個大型電商平臺,每秒都有成千上萬的用戶瀏覽商品、下單、支付,同時后臺系統還在記錄用戶行為、更新庫存、處理物流信息。這些海量、持續產生的數據就像奔騰不息的河流,你需要一個強大、可靠且實時的系統來接…

湖北理元理律師事務所:企業債務重組的風險控制方法論

一、擔保鏈破解:阻斷債務傳染的核心技術 2023年武漢某建材公司案例: 原始債務結構: A公司(主債務人)欠款200萬 ↓ B公司(擔保人)←連帶責任觸發執行 ↓ C公司(B公司擔…

如何在CloudCompare中打開pcd文件

你只需要將pcd文件的路徑改在全英文路徑下,CloudCompare就可以打開。若含中文,就會報錯:

中醫的十問歌和脈象分類

中醫核心理論框架如下 診斷技術如下 本文主要介紹問診和切診。 十問歌的“十”是虛指,實際包含12個核心問題,脈象28種中常見僅10余種,重點解釋脈診的物理本質(血流動力學觸覺感知) 以下是中醫十問歌的完整內容及脈…

基于智能代理人工智能(Agentic AI)對沖基金模擬系統:模范巴菲特、凱西·伍德的投資策略

股票市場涉及眾多統計數據和模式。股票交易基于研究和數據驅動的決策。人工智能的使用可以實現流程自動化,讓投資者在研究上花費更少的時間,同時提高準確性。這使他們能夠更加專注于監督實際交易和服務客戶。 頂尖對沖基金經理發揮著至關重要的作用&…