深度學習筆記25-RNN心臟病預測(Pytorch)

  • ?🍨 本文為🔗365天深度學習訓練營中的學習記錄博客
  • 🍖 原作者:K同學啊

?一、前期準備

1.數據處理

import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch
from torch import nn
df=pd.read_csv(r"D:\Pytorch\heart.csv")
df

二、構建數據集

1.標準化

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_splitX= df.iloc[:,:-1]
y= df.iloc[:,-1]
#將每一列特征標準化為標準正太分布,注意,標準化是針對每一列而言
sc=StandardScaler()
X=sc.fit_transform(X)

2.劃分數據集

X=torch.tensor(np.array(X),dtype=torch.float32)
y=torch.tensor(np.array(y),dtype=torch.int64)X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.1,random_state=1)
#維度擴增使其符合RNN模型可接受shape
X_train = X_train.unsqueeze(1)
X_test = X_test.unsqueeze(1)
X_train.shape, y_train.shape

3.構建數據加載器

from torch.utils.data import TensorDataset,DataLoader
train_dl = DataLoader(TensorDataset(X_train, y_train),batch_size=64,shuffle=False)
test_dl=DataLoader(TensorDataset(X_test,y_test),batch_size=64,shuffle=False)

三、模型訓練

1.構建模型

class model_rnn(nn.Module):def __init__(self):super(model_rnn,self).__init__()self.rnn0=nn.RNN(input_size=13,hidden_size=200,num_layers=1,batch_first=True)self.fc0=nn.Linear(200,50)self.fc1=nn.Linear(50,2)def forward(self,x):out,_=self.rnn0(x)out=out[:,-1,:] #只取最后一個時間步的輸出out=self.fc0(out)out=self.fc1(out)return out
model=model_rnn()
model

2.定義訓練函數

# 訓練循環
def train(dataloader,model,loss_fn,optimizer):size=len(dataloader.dataset)num_batches=len(dataloader)train_loss,train_acc=0,0 #初始化訓練損失和正確率for X,y in dataloader:#計算預測誤差pred=model(X)   #網絡輸出loss=loss_fn(pred,y) #計算網絡輸出和真實值之間的差距# 反向傳播optimizer.zero_grad() #grad屬性歸零loss.backward()#反向傳播optimizer.step()#每一步自動更新#記錄acc與losstrain_acc+= (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc/=sizetrain_loss/=num_batchesreturn train_acc,train_loss

3.定義測試函數

def test (dataloader, model,loss_fn):size= len(dataloader.dataset)#測試集的大小num_batches = len(dataloader)#批次數目,(size/batch_size向上取整)test_loss,test_acc = 0,0#當不進行訓練時,停止梯度更新,節省計算內存消耗with torch.no_grad():for imgs, target in dataloader:#計算losstarget_pred = model(imgs)loss= loss_fn(target_pred, target)test_loss += loss.item()test_acc+=(target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc/=sizetest_loss/=num_batchesreturn test_acc,test_loss

4.正式訓練

loss_fn=nn.CrossEntropyLoss()# 創建損失函數
learn_rate=1e-4  #學習率
opt= torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs=50
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model,loss_fn,opt)model.eval()epoch_test_acc,epoch_test_loss = test(test_dl,model,loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)#獲取當前的學習率lr = opt.state_dict()['param_groups'][0]['lr']template=('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f},lr:{:.2E}')print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))print("="*20,'Done','='*20)

四、模型評估

1.loss與accuracy圖

import matplotlib.pyplot as plt
from datetime import datetime
#隱藏警告
import warnings
warnings.filterwarnings("ignore")#忽略警告信息
current_time=datetime.now()#獲取當前時間# 設置 Matplotlib 參數
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False  # 用來正常顯示負號
plt.rcParams['figure.dpi'] = 200  # 分辨率epochs_range = range(epochs)
plt.figure(figsize=(12,3))
plt.subplot(1,2,1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)#打卡請帶上時間戳,否則代碼截圖無效plt.subplot(1,2,2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2.混淆矩陣

print('=======輸入數據shape為=======')
print("x_test.shape:",X_test.shape)
print("y_test.shape:",y_test.shape)
pred = model(X_test).argmax(1).cpu().numpy()
print("\n=====輸出數據Shape為=====")
print("pred.shape: ",pred.shape)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay
#計算混淆矩陣
cm = confusion_matrix(y_test, pred)plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")#修改字體大小
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("Confusion Matrix", fontsize=12)
plt.xlabel("Predicted Label",fontsize=10)
plt.ylabel("True Label",fontsize=10)#顯示圖
plt.tight_layout() # 調整布局防止重疊
plt.show()

3.調用模型進行預測

test_X=X_test[0].unsqueeze(1)  # X_test[0]即我們的輸入數據
pred = model(test_X).argmax(1).item()
print("模型預測結果為:",pred)
print("=="*10)
print("0:不會患心臟病")
print("1:可能患心臟病")

?

五、總結

RNN 的核心特點是它能夠利用序列中的歷史信息來影響當前的輸出。

RNN 的特點
  • 記憶性:RNN 能夠利用歷史信息來影響當前的輸出,這使得它在處理序列數據時非常有效。例如,在自然語言處理中,RNN 可以利用前面的單詞來預測下一個單詞。

  • 靈活性:RNN 的結構可以靈活地處理不同長度的序列數據,適用于各種序列任務,如語言建模、機器翻譯、語音識別等。

  • 動態性:RNN 的狀態是動態變化的,能夠適應序列中的時間依賴性。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/85526.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/85526.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/85526.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Pytorch知識點2

Pytorch知識點 1、官方教程2、張量🧱 0、數組概念🧱 1. 創建張量📐 2. 張量形狀與維度🔢 3. 張量數據類型? 4. 張量的數學與邏輯操作🔄 5. 張量的就地操作📦 6. 復制張量🚀 7. 將張量移動到加速…

池中錦鯉的自我修養,聊聊蓄水池算法

面試如泡池,蓄水似人生 起初你滿懷期待跳進大廠池子,以為自己是天選之子,結果發現池子里早擠滿了和你一樣的“錦鯉候選人”。HR的漁網一撒,撈誰全看概率——這不就是蓄水池算法的精髓嗎? 初入池(i≤k&…

Linux應用開發之網絡套接字編程

套接字(Socket)是計算機網絡數據通信的基本概念和編程接口,允許不同主機上的進程(運行中的程序)通過網絡進行數據交換。它為應用層軟件提供了發送和接收數據的能力,使得開發者可以在不用深入了解底層網絡細…

小白的進階之路系列之六----人工智能從初步到精通pytorch數據集與數據加載器

本文將介紹以下內容: 數據集與數據加載器 數據遷移 如何建立神經網絡 數據集與數據加載器 處理數據樣本的代碼可能會變得混亂且難以維護;理想情況下,我們希望我們的數據集代碼與模型訓練代碼解耦,以獲得更好的可讀性和模塊化。PyTorch提供了兩個數據原語:torch.utils…

深入理解設計模式之中介者模式

深入理解設計模式之:中介者模式(Mediator Pattern) 一、什么是中介者模式? 中介者模式(Mediator Pattern)是一種行為型設計模式。它通過引入一個中介對象,來封裝一組對象之間的交互&#xff0…

基于通義千問的兒童陪伴學習和成長的智能應用架構。

1.整體架構概覽 我們的兒童聊天助手將采用典型的語音交互系統架構,結合大模型能力和外部知識庫: 2. 技術方案分解 2.1. 前端應用/設備 選擇: 移動App(iOS/Android)、Web應用,或者集成到智能音箱/平板等硬件設備中。技術棧: 移動App: React Native / Flutter (跨平臺…

Python Day40

Task: 1.彩色和灰度圖片測試和訓練的規范寫法:封裝在函數中 2.展平操作:除第一個維度batchsize外全部展平 3.dropout操作:訓練階段隨機丟棄神經元,測試階段eval模式關閉dropout 作業:仔細學習下測試和訓練代…

WordPress_suretriggers 權限繞過漏洞復現(CVE-2025-3102)

免責申明: 本文所描述的漏洞及其復現步驟僅供網絡安全研究與教育目的使用。任何人不得將本文提供的信息用于非法目的或未經授權的系統測試。作者不對任何由于使用本文信息而導致的直接或間接損害承擔責任。如涉及侵權,請及時與我們聯系,我們將盡快處理并刪除相關內容。 前…

基于Spring Boot 電商書城平臺系統設計與實現(源碼+文檔+部署講解)

技術范圍:SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬蟲、數據可視化、小程序、安卓app、大數據、物聯網、機器學習等設計與開發。 主要內容:免費功能設計、開題報告、任務書、中期檢查PPT、系統功能實現、代碼編寫、論文編寫和輔導、論文…

LeetCode 39.組合總和:回溯法與剪枝優化的完美結合

一、問題本質與形式化定義 1.1 題目形式化描述 輸入:無重復整數數組candidates、目標值target輸出:所有和為target的組合集合,滿足: 元素可重復使用組合內元素非降序(避免重復解)解集無重復組合 1.2 問…

windows11安裝編譯QtMvvm

windows11安裝編譯QtMvvm 1 從github下載代碼2 官方的Download/Installtion3 自行構建編譯QtMvvm遇到的問題3.1 `qmake`問題執行命令報錯原因分析qmake報錯:找不到編譯器 cl解決方案3.2 `make qmake_all`問題執行命令報錯原因分析make命令未識別解決方案3.3 缺少`perl`問題執行…

unix/linux source 命令,其歷史爭議、兼容性、生態、未來展望

現在把目光投向unix/linux source命令的歷史爭議、兼容性、生態和未來展望,這能讓我們更全面地理解一個技術點在更廣闊的圖景中所處的位置。 一、歷史爭議與設計權衡 雖然 source (或 .) 命令功能強大且不可或缺,但在其發展和使用過程中,也存在一些微妙的爭議或設計上的權衡…

開發時如何通過Service暴露應用?ClusterIP、NodePort和LoadBalancer類型的使用場景分別是什么?

一、Service核心概念 Service通過標簽選擇器(Label Selector)關聯Pod,為動態變化的Pod集合提供穩定的虛擬IP和DNS名稱,主要解決: 服務發現負載均衡流量路由 二、Service類型詳解 1. ClusterIP(默認類型…

從線性代數到線性回歸——機器學習視角

真正不懂數學就能理解機器學習其實是個神話。我認為,AI 在商業世界可以不懂數學甚至不懂編程也能應用,但對于技術人員來說,一些基礎數學是必須的。本文收集了我認為理解學習本質所必需的數學基礎,至少在概念層面要掌握。畢竟&…

華為IP(7)

端口隔離技術 產生的背景 1.以太交換網絡中為了實現報文之間的二層隔離,用戶通常將不同的端口加入不同的VLAN,實現二層廣播域的隔離。 2.大型網絡中,業務需求種類繁多,只通過VLAN實現二層隔離,會浪費有限的VLAN資源…

Docker Desktop無法在windows低版本進行安裝

問題描述 因工作需要,現在一臺低版本的window系統進行Docker Desktop的安裝,但是安裝過程當中出現了報錯信息 系統版本配置 原因分析: 關于本機查看了系統的版本號,版本號如下為1909,但是docker Desktop要求的最低的win10版本…

深入理解 Maven 循環依賴問題及其解決方案

在 Java 開發領域,Maven 作為主流構建工具極大簡化了依賴管理和項目構建。然而**循環依賴(circular dependency)**問題仍是常見挑戰,輕則導致構建失敗,重則引發類加載異常和系統架構混亂。 本文將從根源分析循環依賴的…

Git 全平臺安裝指南:從 Linux 到 Windows 的詳細教程

目錄 一、Git 簡介 二、Linux 系統安裝指南 1、CentOS/RHEL 系統安裝 2、Ubuntu/Debian 系統安裝 3、Windows 系統安裝 四、安裝后配置(后面會詳細講解,現在了解即可) 五、視頻教程參考 一、Git 簡介 Git 是一個開源的分布式版本控制系…

微服務-Sentinel

目錄 背景 Sentinel使用 Sentinel控制臺 Sentinel控制規則 Sentinel整合OpenFeign 背景 在微服務項目架構中,存在多個服務相互調用場景,在某些情況下某個微服務不可用時,上游調用者若一直等待,會產生資源的消耗,極端情…

智慧零工平臺前端開發實戰:從uni-app到跨平臺應用

智慧零工平臺前端開發實戰:從uni-app到跨平臺應用 本文將詳細介紹我如何使用uni-app框架開發一個支持微信小程序和H5的零工平臺前端應用,包含技術選型、架構設計、核心功能實現及部署經驗。 前言 在當今移動互聯網時代,跨平臺開發已成為提高開發效率的重要手段。本次我選擇…