從0搭建ECG深度學習網絡

本篇博客介紹使用Python語言的深度學習網絡,從零搭建一個ECG深度學習網絡。

任務

本次入門的任務是,篩選出MIT-BIH數據集中注釋為[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]的數據作為本次數據集,然后按照8:2的比例劃分為訓練集,驗證集。最后送入RCNN模型進行訓練。

1. 數據集介紹

本次使用大名鼎鼎的MIT-BIH Arrhythmia Database數據集。下載地址:https://physionet.org/content/mitdb/1.0.0/

MIT系列有很多數據集,都可以在生理網:https://physionet.org/about/database/ 上找到下載地址。本次使用的MT-BIH心律失常數據庫擁有48條心電記錄,且每個記錄的時長是30分鐘。這些記錄來自于47名研究對象。這些研究對象包括25名男性和22名女性,其年齡介于23到89歲(其中記錄201與202來自于同一個人)。信號的采樣率為360赫茲,AD分辨率為11比特。對于每條記錄來說,均包含兩個通道的信號。第一個通道一般為MLⅡ導聯(記錄102和104為V5導聯);第二個通道一般為V1導聯(有些為V2導聯或V5導聯,其中記錄124號為Ⅴ4導聯)。為了保持導聯的一致性,往往在研究中采用MLⅡ導聯。

在生理網:https://physionet.org/about/database/上,我們可以看到數據集更加詳細的說明。比如:

MIT-BIH 數據集每個單獨病人的說明:https://www.physionet.org/physiobank/database/html/mitdbdir/mitdbdir.htm

MIT-BIH 數據集每個單獨病人的整個數據以及注釋的可視化:https://www.physionet.org/physiobank/database/html/mitdbdir/mitdbdir.htm

下載MIT-BIH 數據集之后,我們需要知曉以下幾點:

  1. 從100-234不連續號碼,一共48個病人。每個病人有三個文件(.dat,.atr,*.hea),包含有兩路心電信號,一個注釋。
  2. 有專門庫讀取MIT-BIH 數據集,叫做 wfdb。所以不要擔心文件后綴的陌生感。
  3. 對心電圖的標注樣式如上圖,“A"代表心房早搏,”."代表正常。整個數據集標注有40多種符號,表示40多種心拍狀態。

2. 提取數據集

提取之前,先安裝必要的庫wfdb。wfdb詳細介紹

pip install wfdb

這個庫非常強大,打印數據信息,讀取數據,繪制心電波形圖,都可以靠它完成。
現在我們的劃分步驟是:

  1. 提取出所有心電圖數據點,心電圖注釋點
  2. 篩選出所有心電圖注釋點中僅為[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]某一類的注釋點
  3. 截取心電圖數據中標記為[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]某一類的點,在點周圍長度為300的數據
  4. 將得到的數據進行維度處理,送入DataLoader()函數,完成模型對數據的認可。

3. 定義模型

本次使用的模型是輸入大小為300,3層循環,隱藏層大小50。

'''
模型搭建
'''
class RnnModel(nn.Module):def __init__(self):super(RnnModel, self).__init__()'''參數解釋:(輸入維度,隱藏層維度,網絡層數)'''self.rnn = nn.RNN(300, 50, 3, nonlinearity='tanh')self.linear = nn.Linear(50, 5)def forward(self, x):r_out, h_state = self.rnn(x)output = self.linear(r_out[:,-1,:])  # 將 RNN 層的輸出 r_out 在最后一個時間步上的輸出(隱藏狀態)傳遞給線性層return outputmodel = RnnModel()

4. 全部訓練代碼

'''
導入相關包
'''
import wfdb
import pywt
import seaborn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch
import torch.utils.data as Data
from torch import nn'''
加載數據集
'''# 測試集在數據集中所占的比例
RATIO = 0.2# 小波去噪預處理
def denoise(data):# 小波變換coeffs = pywt.wavedec(data=data, wavelet='db5', level=9)cA9, cD9, cD8, cD7, cD6, cD5, cD4, cD3, cD2, cD1 = coeffs# 閾值去噪threshold = (np.median(np.abs(cD1)) / 0.6745) * (np.sqrt(2 * np.log(len(cD1))))cD1.fill(0)cD2.fill(0)for i in range(1, len(coeffs) - 2):coeffs[i] = pywt.threshold(coeffs[i], threshold)# 小波反變換,獲取去噪后的信號rdata = pywt.waverec(coeffs=coeffs, wavelet='db5')return rdata# 讀取心電數據和對應標簽,并對數據進行小波去噪
def getDataSet(number, X_data, Y_data):ecgClassSet = ['N', 'A', 'V', 'L', 'R']# 讀取心電數據記錄# print("正在讀取 " + number + " 號心電數據...")# 讀取MLII導聯的數據record = wfdb.rdrecord('C:/mycode/dataset_make/mit-bih-arrhythmia-database-1.0.0/' + number, channel_names=['MLII'])data = record.p_signal.flatten()rdata = denoise(data=data)# 獲取心電數據記錄中R波的位置和對應的標簽annotation = wfdb.rdann('C:/mycode/dataset_make/mit-bih-arrhythmia-database-1.0.0/' + number, 'atr')Rlocation = annotation.sampleRclass = annotation.symbol# 去掉前后的不穩定數據start = 10end = 5i = startj = len(annotation.symbol) - end# 因為只選擇NAVLR五種心電類型,所以要選出該條記錄中所需要的那些帶有特定標簽的數據,舍棄其余標簽的點# X_data在R波前后截取長度為300的數據點# Y_data將NAVLR按順序轉換為01234while i < j:try:# Rclass[i] 是標簽lable = ecgClassSet.index(Rclass[i])  # 這一步就是相當于拋棄了不在ecgClassSet的索引# 基于經驗值,基于R峰向前取100個點,向后取200個點x_train = rdata[Rlocation[i] - 100:Rlocation[i] + 200]X_data.append(x_train)Y_data.append(lable)i += 1except ValueError:i += 1return# 加載數據集并進行預處理
def loadData():numberSet = ['100', '101', '103', '105', '106', '107', '108', '109', '111', '112', '113', '114', '115','116', '117', '119', '121', '122', '123', '124', '200', '201', '202', '203', '205', '208','210', '212', '213', '214', '215', '217', '219', '220', '221', '222', '223', '228', '230','231', '232', '233', '234']dataSet = []lableSet = []for n in numberSet:getDataSet(n, dataSet, lableSet)# 轉numpy數組,打亂順序dataSet = np.array(dataSet).reshape(-1, 300)  # 轉化為二維,一行有 300 個,行數需要計算lableSet = np.array(lableSet).reshape(-1, 1)  # 轉化為二維,一行有   1 個,行數需要計算train_ds = np.hstack((dataSet, lableSet))  # 將數據集和標簽集水平堆疊在一起,(92192, 300) (92192, 1) (92192, 301)# print(dataSet.shape, lableSet.shape, train_ds.shape)  # (92192, 300) (92192, 1) (92192, 301)np.random.shuffle(train_ds)# 數據集及其標簽集X = train_ds[:, :300].reshape(-1, 1, 300)  # (92192, 1, 300)Y = train_ds[:, 300]  # (92192)# 測試集及其標簽集shuffle_index = np.random.permutation(len(X))  # 生成0-(X-1)的隨機索引數組# 設定測試集的大小 RATIO是測試集在數據集中所占的比例test_length = int(RATIO * len(shuffle_index))# 測試集的長度test_index = shuffle_index[:test_length]# 訓練集的長度train_index = shuffle_index[test_length:]X_test, Y_test = X[test_index], Y[test_index]X_train, Y_train = X[train_index], Y[train_index]return X_train, Y_train, X_test, Y_testX_train, Y_train, X_test, Y_test = loadData()'''
數據處理
'''
train_Data = Data.TensorDataset(torch.Tensor(X_train), torch.Tensor(Y_train)) # 返回結果為一個個元組,每一個元組存放數據和標簽
train_loader = Data.DataLoader(dataset=train_Data, batch_size=128)
test_Data = Data.TensorDataset(torch.Tensor(X_test), torch.Tensor(Y_test)) # 返回結果為一個個元組,每一個元組存放數據和標簽
test_loader = Data.DataLoader(dataset=test_Data, batch_size=128)'''
模型搭建
'''
class RnnModel(nn.Module):def __init__(self):super(RnnModel, self).__init__()'''參數解釋:(輸入維度,隱藏層維度,網絡層數)'''self.rnn = nn.RNN(300, 50, 3, nonlinearity='tanh')self.linear = nn.Linear(50, 5)def forward(self, x):r_out, h_state = self.rnn(x)output = self.linear(r_out[:,-1,:])  # 將 RNN 層的輸出 r_out 在最后一個時間步上的輸出(隱藏狀態)傳遞給線性層return outputmodel = RnnModel()'''
設置損失函數和參數優化方法
'''
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)'''
模型訓練
'''
EPOCHS = 5
for epoch in range(EPOCHS):running_loss = 0for i, data in enumerate(train_loader):inputs, label = datay_predict = model(inputs)loss = criterion(y_predict, label.long())optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()# 預測correct = 0total = 0with torch.no_grad():for data in test_loader:inputs, label = datay_pred = model(inputs)_, predicted = torch.max(y_pred.data, dim=1)total += label.size(0)correct += (predicted == label).sum().item()print(f'Epoch: {epoch + 1}, ACC on test: {correct / total}')

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/40390.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/40390.shtml
英文地址,請注明出處:http://en.pswp.cn/news/40390.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

什么是DNS服務器的層次化和分布式?

DNS (Domain Name System) 的結構是層次化的&#xff0c;意味著它是由多個級別的服務器組成&#xff0c;每個級別負責不同的部分。以下是 DNS 結構的層次&#xff1a; 根域服務器&#xff08;Root Servers&#xff09;&#xff1a; 這是 DNS 層次結構的最高級別。全球有13組根域…

【云原生】Docker 詳解(二):Docker 架構及工作原理

Docker 詳解&#xff08;二&#xff09;&#xff1a;Docker 架構及工作原理 Docker 在運行時分為 Docker 引擎&#xff08;服務端守護進程&#xff09; 和 客戶端工具&#xff0c;我們日常使用各種 docker 命令&#xff0c;其實就是在使用 客戶端工具 與 Docker 引擎 進行交互。…

[oneAPI] 手寫數字識別-LSTM

[oneAPI] 手寫數字識別-LSTM 手寫數字識別參數與包加載數據模型訓練過程結果 oneAPI 比賽&#xff1a;https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517 Intel DevCloud for oneAPI&#xff1a;https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolk…

Curson 編輯器

Curson 漢化與vacode一樣 Curson 自帶chat功能 1、快捷鍵ctrlk(代碼中編輯) 2、快捷鍵ctrll 右側打開窗口

為什么hive會出現_HIVE_DEFAULT_PARTITION分區

問題&#xff1a; 為什么hive表中出現_HIVE_DEFAULT_PARTITION分區&#xff1f; 解答&#xff1a; 因為在業務sql中使用的是動態分區&#xff0c;并且hive啟用動態分區時&#xff0c;對于指定的分區鍵如果存在空值時&#xff0c;會對空值部分創建一個默認分區用于存儲該部分…

小程序項目組件的基本應用

宿主環境&#xff1a;程序運行必須依賴的環境 小程序的宿主環境 ---->手機微信(定位、掃碼、支付等) 小程序的通信模型&#xff1a; 渲染層和邏輯層之間的通信(微信客戶端轉發)邏輯層和第三方服務器之間的通信(微信客戶端轉發) 小程序的運行機制&#xff1a; 啟動&#xff1…

c#實現工廠模式

可以使用以下代碼實現C#中的工廠模式&#xff1a; 首先&#xff0c;定義一個接口作為產品的抽象&#xff1a; public interface IProduct {void Operation(); }然后&#xff0c;創建具體的產品類&#xff1a; public class ConcreteProductA : IProduct {public void Operat…

vue基礎知識五:請描述下你對vue生命周期的理解?在created和mounted這兩個生命周期中請求數據有什么區別呢?

一、生命周期是什么 生命周期&#xff08;Life Cycle&#xff09;的概念應用很廣泛&#xff0c;特別是在政治、經濟、環境、技術、社會等諸多領域經常出現&#xff0c;其基本涵義可以通俗地理解為“從搖籃到墳墓”&#xff08;Cradle-to-Grave&#xff09;的整個過程在Vue中實…

41 | 京東商家書籍評論數據分析

京東作為中國領先的電子商務平臺,積累了大量商品評論數據,這些數據蘊含了豐富的信息。通過文本數據分析,我們可以了解用戶對產品的態度、評價的關鍵詞、消費者的需求等,從而有助于商家優化產品和服務,以及消費者作出更明智的購買決策。 本文將詳細闡述如何獲取京東商家評…

Python opennsfw/opennsfw2 圖片/視頻 鑒黃 筆記

nsfw&#xff08; Not Suitable for Work&#xff09;直接翻譯就是 工作的時候不適合看&#xff0c;真文雅 nsfw效果&#xff0c;注意底部的分數 大體流程&#xff0c;輸入圖片/視頻&#xff0c;輸出0-1之間的數字&#xff0c;一般情況下&#xff0c;Scores < 0.2 認為是非…

7zip分卷壓縮

前言 有些項目上傳文件大小有限制 壓縮包大了之后傳輸也會比較慢 解決方案 我們可以利用7zip壓縮工具對文件進行分卷壓縮 利用7zip壓縮工具進行分卷壓縮 查看待壓縮文件大小 壓縮完成之后有300多M&#xff0c;我們用100M去進行分卷壓縮 選擇待壓縮的文件夾&#xff0c;右…

網絡安全 Day30-運維安全項目-容器架構上

容器架構上 1. 什么是容器2. 容器 vs 虛擬機(化) :star::star:3. Docker極速上手指南1&#xff09;使用rpm包安裝docker2) docker下載鏡像加速的配置3) 載入鏡像大禮包&#xff08;老師資料包中有&#xff09; 4. Docker使用案例1&#xff09; 案例01&#xff1a;:star::star::…

《內網穿透》無需公網IP,公網SSH遠程訪問家中的樹莓派

文章目錄 前言 如何通過 SSH 連接到樹莓派步驟1. 在 Raspberry Pi 上啟用 SSH步驟2. 查找樹莓派的 IP 地址步驟3. SSH 到你的樹莓派步驟 4. 在任何地點訪問家中的樹莓派4.1 安裝 Cpolar內網穿透4.2 cpolar進行token認證4.3 配置cpolar服務開機自啟動4.4 查看映射到公網的隧道地…

【JavaEE基礎學習打卡02】是時候了解Java EE了!

目錄 前言一、為什么要學習Java EE二、Java EE規范介紹1.什么是規范&#xff1f;2.什么是Java EE規范&#xff1f;3.Java EE版本 三、Java EE應用程序模型1.模型前置說明2.模型具體說明 總結 前言 &#x1f4dc; 本系列教程適用于 Java Web 初學者、愛好者&#xff0c;小白白。…

java接口導出csv

1、背景介紹 項目中需要導出數據質檢結果&#xff0c;本來使用Excel&#xff0c;但是質檢結果數據行數過多&#xff0c;導致用hutool報錯&#xff0c;因此轉為導出csv格式數據。 2、參考文檔 https://blog.csdn.net/ityqing/article/details/127879556 工程環境&#xff1a;…

Redis-分布式鎖!

分布式鎖&#xff0c;顧名思義&#xff0c;分布式鎖就是分布式場景下的鎖&#xff0c;比如多臺不同機器上的進程&#xff0c;去競爭同一項資源&#xff0c;就是分布式鎖。 分布式鎖特性 互斥性:鎖的目的是獲取資源的使用權&#xff0c;所以只讓一個競爭者持有鎖&#xff0c;這…

PyTorch: clamp函數與梯度的關系

本文主要以下探究這一點&#xff1a;梯度反向傳播過程中&#xff0c;測試強行修改后的預測結果是否還會傳遞loss&#xff1f; clamp應用場景&#xff1a;在深度學習計算損失函數的過程中&#xff0c;會有這樣一個問題&#xff0c;如果Label是1.0&#xff0c;而預測結果是0.0&a…

【算法】排序+雙指針——leetcode三數之和、四數之和

三數之和 &#xff08;1&#xff09;排序雙指針 算法思路&#xff1a; 和之前的兩數之和類似&#xff0c;我們對暴力枚舉進行了一些優化&#xff0c;利用了排序雙指針的思路&#xff1a; 我們先排序&#xff0c;然后固定?個數 a &#xff0c;接著我們就可以在這個數后面的區間…

Mybatis Plus Interceptor

Mybatis Plus Interceptor 1 獲取表名2 獲取SQL 1 獲取表名 Component public class MybatisInterceptor implements Interceptor {private static final List<String> EXCLUDE_TABLE new ArrayList<>();static {EXCLUDE_TABLE.add("test");}private s…

OpenCV實例(九)基于深度學習的運動目標檢測(一)YOLO運動目標檢測算法

基于深度學習的運動目標檢測&#xff08;一&#xff09; 1.YOLO算法檢測流程2.YOLO算法網絡架構3.網絡訓練模型3.1 訓練策略3.2 代價函數的設定 2012年&#xff0c;隨著深度學習技術的不斷突破&#xff0c;開始興起基于深度學習的目標檢測算法的研究浪潮。 2014年&#xff0c;…