本篇博客介紹使用Python語言的深度學習網絡,從零搭建一個ECG深度學習網絡。
任務
本次入門的任務是,篩選出MIT-BIH數據集中注釋為[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]的數據作為本次數據集,然后按照8:2的比例劃分為訓練集,驗證集。最后送入RCNN模型進行訓練。
1. 數據集介紹
本次使用大名鼎鼎的MIT-BIH Arrhythmia Database數據集。下載地址:https://physionet.org/content/mitdb/1.0.0/
MIT系列有很多數據集,都可以在生理網:https://physionet.org/about/database/ 上找到下載地址。本次使用的MT-BIH心律失常數據庫擁有48條心電記錄,且每個記錄的時長是30分鐘。這些記錄來自于47名研究對象。這些研究對象包括25名男性和22名女性,其年齡介于23到89歲(其中記錄201與202來自于同一個人)。信號的采樣率為360赫茲,AD分辨率為11比特。對于每條記錄來說,均包含兩個通道的信號。第一個通道一般為MLⅡ導聯(記錄102和104為V5導聯);第二個通道一般為V1導聯(有些為V2導聯或V5導聯,其中記錄124號為Ⅴ4導聯)。為了保持導聯的一致性,往往在研究中采用MLⅡ導聯。
在生理網:https://physionet.org/about/database/上,我們可以看到數據集更加詳細的說明。比如:
MIT-BIH 數據集每個單獨病人的說明:https://www.physionet.org/physiobank/database/html/mitdbdir/mitdbdir.htm
MIT-BIH 數據集每個單獨病人的整個數據以及注釋的可視化:https://www.physionet.org/physiobank/database/html/mitdbdir/mitdbdir.htm
下載MIT-BIH 數據集之后,我們需要知曉以下幾點:
- 從100-234不連續號碼,一共48個病人。每個病人有三個文件(.dat,.atr,*.hea),包含有兩路心電信號,一個注釋。
- 有專門庫讀取MIT-BIH 數據集,叫做 wfdb。所以不要擔心文件后綴的陌生感。
- 對心電圖的標注樣式如上圖,“A"代表心房早搏,”."代表正常。整個數據集標注有40多種符號,表示40多種心拍狀態。
2. 提取數據集
提取之前,先安裝必要的庫wfdb。wfdb詳細介紹
pip install wfdb
這個庫非常強大,打印數據信息,讀取數據,繪制心電波形圖,都可以靠它完成。
現在我們的劃分步驟是:
- 提取出所有心電圖數據點,心電圖注釋點
- 篩選出所有心電圖注釋點中僅為[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]某一類的注釋點
- 截取心電圖數據中標記為[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]某一類的點,在點周圍長度為300的數據
- 將得到的數據進行維度處理,送入DataLoader()函數,完成模型對數據的認可。
3. 定義模型
本次使用的模型是輸入大小為300,3層循環,隱藏層大小50。
'''
模型搭建
'''
class RnnModel(nn.Module):def __init__(self):super(RnnModel, self).__init__()'''參數解釋:(輸入維度,隱藏層維度,網絡層數)'''self.rnn = nn.RNN(300, 50, 3, nonlinearity='tanh')self.linear = nn.Linear(50, 5)def forward(self, x):r_out, h_state = self.rnn(x)output = self.linear(r_out[:,-1,:]) # 將 RNN 層的輸出 r_out 在最后一個時間步上的輸出(隱藏狀態)傳遞給線性層return outputmodel = RnnModel()
4. 全部訓練代碼
'''
導入相關包
'''
import wfdb
import pywt
import seaborn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch
import torch.utils.data as Data
from torch import nn'''
加載數據集
'''# 測試集在數據集中所占的比例
RATIO = 0.2# 小波去噪預處理
def denoise(data):# 小波變換coeffs = pywt.wavedec(data=data, wavelet='db5', level=9)cA9, cD9, cD8, cD7, cD6, cD5, cD4, cD3, cD2, cD1 = coeffs# 閾值去噪threshold = (np.median(np.abs(cD1)) / 0.6745) * (np.sqrt(2 * np.log(len(cD1))))cD1.fill(0)cD2.fill(0)for i in range(1, len(coeffs) - 2):coeffs[i] = pywt.threshold(coeffs[i], threshold)# 小波反變換,獲取去噪后的信號rdata = pywt.waverec(coeffs=coeffs, wavelet='db5')return rdata# 讀取心電數據和對應標簽,并對數據進行小波去噪
def getDataSet(number, X_data, Y_data):ecgClassSet = ['N', 'A', 'V', 'L', 'R']# 讀取心電數據記錄# print("正在讀取 " + number + " 號心電數據...")# 讀取MLII導聯的數據record = wfdb.rdrecord('C:/mycode/dataset_make/mit-bih-arrhythmia-database-1.0.0/' + number, channel_names=['MLII'])data = record.p_signal.flatten()rdata = denoise(data=data)# 獲取心電數據記錄中R波的位置和對應的標簽annotation = wfdb.rdann('C:/mycode/dataset_make/mit-bih-arrhythmia-database-1.0.0/' + number, 'atr')Rlocation = annotation.sampleRclass = annotation.symbol# 去掉前后的不穩定數據start = 10end = 5i = startj = len(annotation.symbol) - end# 因為只選擇NAVLR五種心電類型,所以要選出該條記錄中所需要的那些帶有特定標簽的數據,舍棄其余標簽的點# X_data在R波前后截取長度為300的數據點# Y_data將NAVLR按順序轉換為01234while i < j:try:# Rclass[i] 是標簽lable = ecgClassSet.index(Rclass[i]) # 這一步就是相當于拋棄了不在ecgClassSet的索引# 基于經驗值,基于R峰向前取100個點,向后取200個點x_train = rdata[Rlocation[i] - 100:Rlocation[i] + 200]X_data.append(x_train)Y_data.append(lable)i += 1except ValueError:i += 1return# 加載數據集并進行預處理
def loadData():numberSet = ['100', '101', '103', '105', '106', '107', '108', '109', '111', '112', '113', '114', '115','116', '117', '119', '121', '122', '123', '124', '200', '201', '202', '203', '205', '208','210', '212', '213', '214', '215', '217', '219', '220', '221', '222', '223', '228', '230','231', '232', '233', '234']dataSet = []lableSet = []for n in numberSet:getDataSet(n, dataSet, lableSet)# 轉numpy數組,打亂順序dataSet = np.array(dataSet).reshape(-1, 300) # 轉化為二維,一行有 300 個,行數需要計算lableSet = np.array(lableSet).reshape(-1, 1) # 轉化為二維,一行有 1 個,行數需要計算train_ds = np.hstack((dataSet, lableSet)) # 將數據集和標簽集水平堆疊在一起,(92192, 300) (92192, 1) (92192, 301)# print(dataSet.shape, lableSet.shape, train_ds.shape) # (92192, 300) (92192, 1) (92192, 301)np.random.shuffle(train_ds)# 數據集及其標簽集X = train_ds[:, :300].reshape(-1, 1, 300) # (92192, 1, 300)Y = train_ds[:, 300] # (92192)# 測試集及其標簽集shuffle_index = np.random.permutation(len(X)) # 生成0-(X-1)的隨機索引數組# 設定測試集的大小 RATIO是測試集在數據集中所占的比例test_length = int(RATIO * len(shuffle_index))# 測試集的長度test_index = shuffle_index[:test_length]# 訓練集的長度train_index = shuffle_index[test_length:]X_test, Y_test = X[test_index], Y[test_index]X_train, Y_train = X[train_index], Y[train_index]return X_train, Y_train, X_test, Y_testX_train, Y_train, X_test, Y_test = loadData()'''
數據處理
'''
train_Data = Data.TensorDataset(torch.Tensor(X_train), torch.Tensor(Y_train)) # 返回結果為一個個元組,每一個元組存放數據和標簽
train_loader = Data.DataLoader(dataset=train_Data, batch_size=128)
test_Data = Data.TensorDataset(torch.Tensor(X_test), torch.Tensor(Y_test)) # 返回結果為一個個元組,每一個元組存放數據和標簽
test_loader = Data.DataLoader(dataset=test_Data, batch_size=128)'''
模型搭建
'''
class RnnModel(nn.Module):def __init__(self):super(RnnModel, self).__init__()'''參數解釋:(輸入維度,隱藏層維度,網絡層數)'''self.rnn = nn.RNN(300, 50, 3, nonlinearity='tanh')self.linear = nn.Linear(50, 5)def forward(self, x):r_out, h_state = self.rnn(x)output = self.linear(r_out[:,-1,:]) # 將 RNN 層的輸出 r_out 在最后一個時間步上的輸出(隱藏狀態)傳遞給線性層return outputmodel = RnnModel()'''
設置損失函數和參數優化方法
'''
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)'''
模型訓練
'''
EPOCHS = 5
for epoch in range(EPOCHS):running_loss = 0for i, data in enumerate(train_loader):inputs, label = datay_predict = model(inputs)loss = criterion(y_predict, label.long())optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()# 預測correct = 0total = 0with torch.no_grad():for data in test_loader:inputs, label = datay_pred = model(inputs)_, predicted = torch.max(y_pred.data, dim=1)total += label.size(0)correct += (predicted == label).sum().item()print(f'Epoch: {epoch + 1}, ACC on test: {correct / total}')