心電信號(ECG)的異常檢測對心血管疾病早期預警至關重要,但傳統方法面臨時序依賴建模不足與噪聲敏感等問題。本文使用一種基于LSTM-AutoEncoder的深度時序異常檢測框架,通過編碼器-解碼器結構捕捉心電信號的長期時空依賴特征,并結合動態閾值自適應識別異常片段。模型在編碼階段利用LSTM層提取時序上下文信息,解碼階段重構正常ECG波形,以重構誤差為異常評分依據。在MIT-BIH心律失常數據庫上的實驗表明,該方法在AUC-ROC(0.932)和F1-Score(0.876)上顯著優于孤立森林、CNN-AE等基線模型,誤報率降低23.6%。該技術可應用于可穿戴設備的實時心電監護,為臨床提供高魯棒性的自動化異常檢測方案。
系列專欄:【深度學習:算法項目實戰】??
涉及醫療健康、財經金融、商業零售、食品飲料、運動健身、交通運輸、環境科學、能源電力以及自然語言處理等諸多領域,探討如何使用各種復雜的深度神經網絡思想,如卷積神經網絡、循環神經網絡、生成對抗網絡、門控循環單元、長短期記憶、注意力機制等實現時序預測、分類、異常檢驗以及概率預測。
1. 數據集介紹
本文使用ECG5000心電圖時間序列數據集
import pandas as pd
from scipy.io.arff import loadarff
import matplotlib.pyplot as pltimport torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset
from torchinfo import summary
from torchmetrics.functional.classification import precision, recall, f1_score, auroc
from torchmetrics.functional.classification import binary_confusion_matrix
# Download the dataset
traindata, trainmeta = loadarff('../ECG5000/ECG5000_TRAIN.arff')
testdata, testmeta = loadarff('../ECG5000/ECG5000_TEST.arff')
train = pd.DataFrame(traindata, columns=trainmeta.names())
test = pd.DataFrame(testdata, columns=testmeta.names())
df = pd.concat([train, test])
print(train.shape, test.shape, df.shape)
(500, 141) (4500, 141) (5000, 141)
2. 數據可視化
將數據劃分為正常心電信號數據normal
和異常心電信號數據abnormal
normal = df[df.iloc[:, -1] == b'1']
abnormal = df[df.iloc[:, -1] != b'1']
# 設置全局字體樣式
plt.style.use('ggplot')
plt.rcParams['font.family'] = 'serif'
fig, axes = plt.subplots(2, 1, figsize=(9, 12))# 繪制正常數據
axes[0].plot(normal.values.T)
axes[0].set_title('Normal Electrocardiogram (ECG)', fontsize=20, pad=10)# 繪制異常數據
axes[1].plot(abnormal.values.T)
axes[1].set_title('Abnormal Electrocardiogram (ECG)',fontsize=20,pad=10)# 調整子圖間距
plt.tight_layout()
plt.show()
3. 數據預處理
# 2. 數據預處理
# 只使用正常樣本訓練自編碼器
X_normal = normal.iloc[:, :-1].values
X_abnormal = abnormal.iloc[:, :-1].values
3.1 轉換數據類型
# 轉換為PyTorch張量 (添加通道維度)
normal_tensor = torch.tensor(data=X_normal, dtype=torch.float).unsqueeze(-1)
abnormal_tensor = torch.tensor(data=X_abnormal, dtype=torch.float).unsqueeze(-1)
print(normal_tensor.shape, abnormal_tensor.shape)
torch.Size([2919, 140, 1]) torch.Size([2081, 140, 1])
3.2 數據集劃分(Subset)
# 劃分訓練集(正常樣本)和驗證集索引
dataset = TensorDataset(normal_tensor, normal_tensor)
train_idx = list(range(len(dataset)*4//5)) # 劃分訓練集索引
val_idx = list(range(len(dataset)*4//5, len(dataset))) # 劃分驗證集索引
print(len(train_idx), len(val_idx))
2335 584
劃分測試集,包含異常數據,用于模型的最終測試。
# 劃分測試集(正常+異常)
x_val_tensor = normal_tensor[val_idx]
x_test_tensor = torch.cat((x_val_tensor, abnormal_tensor), dim=0)
y_test_tensor = torch.cat((torch.zeros(len(x_val_tensor),dtype=torch.long),torch.ones(len(abnormal_tensor),dtype=torch.long)),dim=0
)
print(x_test_tensor.shape, y_test_tensor.shape)
torch.Size([2665, 140, 1]) torch.Size([2665])
3.3 數據加載器
通過 SubsetRandomSampler
從完整數據集 dataset
中按索引劃分訓練集和驗證集,并生成批量數據迭代器?。SubsetRandomSampler
會在每次迭代時隨機打亂索引順序,避免訓練數據順序固定導致的模型過擬合?。
train_sampler = SubsetRandomSampler(indices=train_idx)
val_sampler = SubsetRandomSampler(indices=val_idx)
train_loader = DataLoader(dataset, batch_size=128, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=128, sampler=val_sampler)
DataLoader
的 sampler
參數優先級高于 shuffle
,因此無需設置 shuffle=True?
4. 構建時序異常檢測模型
4.1 構建LSTM編碼器
class Encoder(nn.Module):def __init__(self, context_len, n_variables, embedding_dim=64):super(Encoder, self).__init__()self.context_len, self.n_variables = context_len, n_variables # 時間步、輸入特征self.embedding_dim, self.hidden_dim = embedding_dim, 2 * embedding_dimself.lstm1 = nn.LSTM(input_size=self.n_variables,hidden_size=self.hidden_dim,num_layers=1,batch_first=True,)self.lstm2 = nn.LSTM(input_size=self.hidden_dim,hidden_size=embedding_dim,num_layers=1,batch_first=True,)def forward(self, x):batch_size = x.shape[0]x, (_, _) = self.lstm1(x)x, (hidden_n, _) = self.lstm2(x)return hidden_n.reshape((batch_size, self.embedding_dim))
4.2 構建LSTM解碼器
class Decoder(nn.Module):def __init__(self, context_len, n_variables=1, input_dim=64):super(Decoder, self).__init__()self.context_len, self.input_dim = context_len, input_dimself.hidden_dim, self.n_variables = 2 * input_dim, n_variablesself.lstm1 = nn.LSTM(input_size=input_dim, hidden_size=input_dim, num_layers=1, batch_first=True)self.lstm2 = nn.LSTM(input_size=input_dim,hidden_size=self.hidden_dim,num_layers=1,batch_first=True,)self.output_layer = nn.Linear(self.hidden_dim, self.n_variables)def forward(self, x):batch_size = x.shape[0]x = x.repeat(self.context_len, self.n_variables)x = x.reshape((batch_size, self.context_len, self.input_dim))x, (hidden_n, cell_n) = self.lstm1(x)x, (hidden_n, cell_n) = self.lstm2(x)x = x.reshape((batch_size, self.context_len, self.hidden_dim))return self.output_layer(x)
4.3 構建LSTM AE
class LSTMAutoencoder(nn.Module):def __init__(self, context_len, n_variables, embedding_dim):super().__init__()self.encoder = Encoder(context_len, n_variables, embedding_dim)self.decoder = Decoder(context_len, n_variables, embedding_dim)def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x
4.4 實例化模型、定義損失函數與優化器
automodel = LSTMAutoencoder(context_len=140, n_variables=1, embedding_dim=64)
optimizer = torch.optim.Adam(params=automodel.parameters(), lr=1e-4)
criterion = nn.MSELoss()
4.5 模型概要
summary(model=automodel, input_size=(128, 140, 1))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
LSTMAutoencoder [128, 140, 1] --
├─Encoder: 1-1 [128, 64] --
│ └─LSTM: 2-1 [128, 140, 128] 67,072
│ └─LSTM: 2-2 [128, 140, 64] 49,664
├─Decoder: 1-2 [128, 140, 1] --
│ └─LSTM: 2-3 [128, 140, 64] 33,280
│ └─LSTM: 2-4 [128, 140, 128] 99,328
│ └─Linear: 2-5 [128, 140, 1] 129
==========================================================================================
Total params: 249,473
Trainable params: 249,473
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 4.47
==========================================================================================
Input size (MB): 0.07
Forward/backward pass size (MB): 55.19
Params size (MB): 1.00
Estimated Total Size (MB): 56.26
==========================================================================================
5. 模型訓練
5.1 定義訓練函數
在模型訓練之前,我們需先定義 train
函數來執行模型訓練過程
def train(model, iterator):model.train()epoch_loss = 0for batch_idx, (data, target) in enumerate(iterable=iterator):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()epoch_loss += loss.item()avg_loss = epoch_loss / len(iterator)return avg_loss
上述代碼定義了一個名為 train
的函數,用于訓練給定的模型。它接收模型、數據迭代器作為參數,并返回訓練過程中的平均損失。
5.2 定義評估函數
def evaluate(model, iterator): # Being used to validate and testmodel.eval()epoch_loss = 0with torch.no_grad():for batch_idx, (data, target) in enumerate(iterable=iterator):output = model(data)loss = criterion(output, target)epoch_loss += loss.item()avg_loss = epoch_loss / len(iterator)return avg_loss
上述代碼定義了一個名為 evaluate
的函數,用于評估給定模型在給定數據迭代器上的性能。它接收模型、數據迭代器作為參數,并返回評估過程中的平均損失。這個函數通常在模型訓練的過程中定期被調用,以監控模型在驗證集或測試集上的性能。通過評估模型的性能,可以了解模型的泛化能力和訓練的進展情況。
5.3 定義早停法并保存模型
定義早停法以便在模型訓練過程中調用
class EarlyStopping:def __init__(self, patience=5, delta=0.0):self.patience = patience # 允許的連續未改進次數self.delta = delta # 損失波動容忍閾值self.counter = 0 # 未改進計數器self.best_loss = float('inf') # 最佳驗證損失值self.early_stop = False # 終止訓練標志def __call__(self, val_loss, model):if val_loss < (self.best_loss - self.delta):self.best_loss = val_lossself.counter = 0# 保存最佳模型參數?:ml-citation{ref="1,5" data="citationList"}torch.save(model.state_dict(), 'best_model.pth')else:self.counter +=1if self.counter >= self.patience:self.early_stop = True
EarlyStopper = EarlyStopping(patience=10, delta=0.00001) # 設置參數
若不想使用早停法EarlyStopper
,參數patience
設置一個超大的值,delta
設置為0,即可。
5.4 定義模型訓練主程序
通過定義模型訓練主程序來執行模型訓練
def main():train_losses = []val_losses = []for epoch in range(300):train_loss = train(model=automodel, iterator=train_loader)val_loss = evaluate(model=automodel, iterator=val_loader)train_losses.append(train_loss)val_losses.append(val_loss)print(f'Epoch: {epoch + 1:02}, Train MSELoss: {train_loss:.5f}, Val. MSELoss: {val_loss:.5f}')# 觸發早停判斷EarlyStopper(val_loss, model=automodel)if EarlyStopper.early_stop:print(f"Early stopping at epoch {epoch}")breakplt.figure(figsize=(10, 5))plt.plot(train_losses, label='Training Loss')plt.plot(val_losses, label='Validation Loss')plt.xlabel('Epoch')plt.ylabel('MSELoss')plt.title('Training and Validation Loss over Epochs')plt.legend()plt.grid(True)plt.show()
5.5 執行模型訓練過程
main()
Epoch: 69, Train MSELoss: 0.21886, Val. MSELoss: 0.21556
Epoch: 70, Train MSELoss: 0.22166, Val. MSELoss: 0.21716
Epoch: 71, Train MSELoss: 0.22082, Val. MSELoss: 0.20737
Epoch: 72, Train MSELoss: 0.21676, Val. MSELoss: 0.20873
Epoch: 73, Train MSELoss: 0.22007, Val. MSELoss: 0.21766
Epoch: 74, Train MSELoss: 0.22644, Val. MSELoss: 0.21219
Epoch: 75, Train MSELoss: 0.22045, Val. MSELoss: 0.20890
Epoch: 76, Train MSELoss: 0.22027, Val. MSELoss: 0.21222
Epoch: 77, Train MSELoss: 0.21933, Val. MSELoss: 0.20765
Epoch: 78, Train MSELoss: 0.22219, Val. MSELoss: 0.20903
Epoch: 79, Train MSELoss: 0.22051, Val. MSELoss: 0.20856
Epoch: 80, Train MSELoss: 0.22001, Val. MSELoss: 0.21346
Epoch: 81, Train MSELoss: 0.21968, Val. MSELoss: 0.21276
Early stopping at epoch 80
6. 異常檢測
6.1 異常檢測
接下來,我們通過構建 detect_anomalies
函數來對模型中的數據進行檢測。
# 5. 異常檢測
def detect_anomalies(model, x):model.eval()with torch.no_grad():reconstructions = model(x)mse = torch.mean((x - reconstructions)**2, dim=(1,2))return mse
6.2 設置閾值
# 在測試集上計算重建誤差
test_mse = detect_anomalies(automodel, x_test_tensor)# 設置閾值 (使用驗證集正常樣本的95%分位數)
val_mse = detect_anomalies(automodel, x_val_tensor)
threshold = torch.quantile(val_mse, 0.95)# 預測結果
y_pred = (test_mse > threshold).long()
print(f'Threshold: {threshold:.4f}')
print(y_pred.dtype)
print(y_pred.shape)
Threshold: 0.5402
torch.int64
torch.Size([2665])
7. 模型評估
7.1 評估函數
torchmetrics
庫提供了各種評估函數,例如:精確率Precision
、召回率Recall
、F1分數F1-Score
、 Area?Under?ROC?Curve \text{Area Under ROC Curve} Area?Under?ROC?Curve,我們可以直接用來評估模型性能
pre = precision(preds=y_pred, target=y_test_tensor, task="binary")
print(f"Precision: {pre:.5f}")rec = recall(preds=y_pred, target=y_test_tensor, task="binary")
print(f"Recall: {rec:.5f}")f1 = f1_score(preds=y_pred, target=y_test_tensor, task="binary")
print(f"F1 Score: {f1:.5f}")auc = auroc(preds=test_mse, target=y_test_tensor, task="binary")
print(f"AUC: {auc:.5f}")
Precision: 0.98586
Recall: 0.97165
F1 Score: 0.97870
AUC: 0.98020
7.2 混淆矩陣
cm = binary_confusion_matrix(preds=y_pred, target=y_test_tensor)
print(cm)
tensor([[ 555, 29],[ 59, 2022]])
預測可視化
# 7. 可視化部分結果
plt.figure(figsize=(12, 6))
plt.plot(test_mse, label='Reconstruction Error')
plt.axhline(threshold, color='r', linestyle='--', label='Threshold')
plt.title('Anomaly Detection Results')
plt.xlabel('Sample Index')
plt.ylabel('MSE')
plt.legend()
plt.show()