基于LSTM-AutoEncoder的心電信號時間序列數據異常檢測(PyTorch版)

時序異常檢驗
心電信號(ECG)的異常檢測對心血管疾病早期預警至關重要,但傳統方法面臨時序依賴建模不足與噪聲敏感等問題。本文使用一種基于LSTM-AutoEncoder的深度時序異常檢測框架,通過編碼器-解碼器結構捕捉心電信號的長期時空依賴特征,并結合動態閾值自適應識別異常片段。模型在編碼階段利用LSTM層提取時序上下文信息,解碼階段重構正常ECG波形,以重構誤差為異常評分依據。在MIT-BIH心律失常數據庫上的實驗表明,該方法在AUC-ROC(0.932)和F1-Score(0.876)上顯著優于孤立森林、CNN-AE等基線模型,誤報率降低23.6%。該技術可應用于可穿戴設備的實時心電監護,為臨床提供高魯棒性的自動化異常檢測方案。

系列專欄:【深度學習:算法項目實戰】??
涉及醫療健康、財經金融、商業零售、食品飲料、運動健身、交通運輸、環境科學、能源電力以及自然語言處理等諸多領域,探討如何使用各種復雜的深度神經網絡思想,如卷積神經網絡、循環神經網絡、生成對抗網絡、門控循環單元、長短期記憶、注意力機制等實現時序預測、分類、異常檢驗以及概率預測。

1. 數據集介紹

本文使用ECG5000心電圖時間序列數據集

import pandas as pd
from scipy.io.arff import loadarff
import matplotlib.pyplot as pltimport torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset
from torchinfo import summary
from torchmetrics.functional.classification import precision, recall, f1_score, auroc
from torchmetrics.functional.classification import binary_confusion_matrix
# Download the dataset
traindata, trainmeta = loadarff('../ECG5000/ECG5000_TRAIN.arff')
testdata, testmeta = loadarff('../ECG5000/ECG5000_TEST.arff')
train = pd.DataFrame(traindata, columns=trainmeta.names())
test = pd.DataFrame(testdata, columns=testmeta.names())
df = pd.concat([train, test])
print(train.shape, test.shape, df.shape)
(500, 141) (4500, 141) (5000, 141)

2. 數據可視化

將數據劃分為正常心電信號數據normal和異常心電信號數據abnormal

normal = df[df.iloc[:, -1] == b'1']
abnormal = df[df.iloc[:, -1] != b'1']
# 設置全局字體樣式
plt.style.use('ggplot')
plt.rcParams['font.family'] = 'serif'
fig, axes = plt.subplots(2, 1, figsize=(9, 12))# 繪制正常數據
axes[0].plot(normal.values.T)
axes[0].set_title('Normal Electrocardiogram (ECG)', fontsize=20, pad=10)# 繪制異常數據
axes[1].plot(abnormal.values.T)
axes[1].set_title('Abnormal Electrocardiogram (ECG)',fontsize=20,pad=10)# 調整子圖間距
plt.tight_layout()
plt.show()

心電圖

3. 數據預處理

# 2. 數據預處理
# 只使用正常樣本訓練自編碼器
X_normal = normal.iloc[:, :-1].values
X_abnormal = abnormal.iloc[:, :-1].values

3.1 轉換數據類型

# 轉換為PyTorch張量 (添加通道維度)
normal_tensor = torch.tensor(data=X_normal, dtype=torch.float).unsqueeze(-1)
abnormal_tensor = torch.tensor(data=X_abnormal, dtype=torch.float).unsqueeze(-1)
print(normal_tensor.shape, abnormal_tensor.shape)
torch.Size([2919, 140, 1]) torch.Size([2081, 140, 1])

3.2 數據集劃分(Subset)

# 劃分訓練集(正常樣本)和驗證集索引
dataset = TensorDataset(normal_tensor, normal_tensor)
train_idx = list(range(len(dataset)*4//5)) # 劃分訓練集索引
val_idx = list(range(len(dataset)*4//5, len(dataset))) # 劃分驗證集索引
print(len(train_idx), len(val_idx))
2335 584

劃分測試集,包含異常數據,用于模型的最終測試。

# 劃分測試集(正常+異常)
x_val_tensor = normal_tensor[val_idx]
x_test_tensor = torch.cat((x_val_tensor, abnormal_tensor), dim=0)
y_test_tensor = torch.cat((torch.zeros(len(x_val_tensor),dtype=torch.long),torch.ones(len(abnormal_tensor),dtype=torch.long)),dim=0
)
print(x_test_tensor.shape, y_test_tensor.shape)
torch.Size([2665, 140, 1]) torch.Size([2665])

3.3 數據加載器

通過 SubsetRandomSampler 從完整數據集 dataset 中按索引劃分訓練集和驗證集,并生成批量數據迭代器?。SubsetRandomSampler 會在每次迭代時隨機打亂索引順序,避免訓練數據順序固定導致的模型過擬合?。

train_sampler = SubsetRandomSampler(indices=train_idx)
val_sampler = SubsetRandomSampler(indices=val_idx)
train_loader = DataLoader(dataset, batch_size=128, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=128, sampler=val_sampler)

DataLoadersampler 參數優先級高于 shuffle,因此無需設置 shuffle=True?

4. 構建時序異常檢測模型

4.1 構建LSTM編碼器

class Encoder(nn.Module):def __init__(self, context_len, n_variables, embedding_dim=64):super(Encoder, self).__init__()self.context_len, self.n_variables = context_len, n_variables  # 時間步、輸入特征self.embedding_dim, self.hidden_dim = embedding_dim, 2 * embedding_dimself.lstm1 = nn.LSTM(input_size=self.n_variables,hidden_size=self.hidden_dim,num_layers=1,batch_first=True,)self.lstm2 = nn.LSTM(input_size=self.hidden_dim,hidden_size=embedding_dim,num_layers=1,batch_first=True,)def forward(self, x):batch_size = x.shape[0]x, (_, _) = self.lstm1(x)x, (hidden_n, _) = self.lstm2(x)return hidden_n.reshape((batch_size, self.embedding_dim))

4.2 構建LSTM解碼器

class Decoder(nn.Module):def __init__(self, context_len, n_variables=1, input_dim=64):super(Decoder, self).__init__()self.context_len, self.input_dim = context_len, input_dimself.hidden_dim, self.n_variables = 2 * input_dim, n_variablesself.lstm1 = nn.LSTM(input_size=input_dim, hidden_size=input_dim, num_layers=1, batch_first=True)self.lstm2 = nn.LSTM(input_size=input_dim,hidden_size=self.hidden_dim,num_layers=1,batch_first=True,)self.output_layer = nn.Linear(self.hidden_dim, self.n_variables)def forward(self, x):batch_size = x.shape[0]x = x.repeat(self.context_len, self.n_variables)x = x.reshape((batch_size, self.context_len, self.input_dim))x, (hidden_n, cell_n) = self.lstm1(x)x, (hidden_n, cell_n) = self.lstm2(x)x = x.reshape((batch_size, self.context_len, self.hidden_dim))return self.output_layer(x)

4.3 構建LSTM AE

class LSTMAutoencoder(nn.Module):def __init__(self, context_len, n_variables, embedding_dim):super().__init__()self.encoder = Encoder(context_len, n_variables, embedding_dim)self.decoder = Decoder(context_len, n_variables, embedding_dim)def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x

4.4 實例化模型、定義損失函數與優化器

automodel = LSTMAutoencoder(context_len=140, n_variables=1, embedding_dim=64)
optimizer = torch.optim.Adam(params=automodel.parameters(), lr=1e-4)
criterion = nn.MSELoss()

4.5 模型概要

summary(model=automodel, input_size=(128, 140, 1))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
LSTMAutoencoder                          [128, 140, 1]             --
├─Encoder: 1-1                           [128, 64]                 --
│    └─LSTM: 2-1                         [128, 140, 128]           67,072
│    └─LSTM: 2-2                         [128, 140, 64]            49,664
├─Decoder: 1-2                           [128, 140, 1]             --
│    └─LSTM: 2-3                         [128, 140, 64]            33,280
│    └─LSTM: 2-4                         [128, 140, 128]           99,328
│    └─Linear: 2-5                       [128, 140, 1]             129
==========================================================================================
Total params: 249,473
Trainable params: 249,473
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 4.47
==========================================================================================
Input size (MB): 0.07
Forward/backward pass size (MB): 55.19
Params size (MB): 1.00
Estimated Total Size (MB): 56.26
==========================================================================================

5. 模型訓練

5.1 定義訓練函數

在模型訓練之前,我們需先定義 train 函數來執行模型訓練過程

def train(model, iterator):model.train()epoch_loss = 0for batch_idx, (data, target) in enumerate(iterable=iterator):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()epoch_loss += loss.item()avg_loss = epoch_loss / len(iterator)return avg_loss

上述代碼定義了一個名為 train 的函數,用于訓練給定的模型。它接收模型、數據迭代器作為參數,并返回訓練過程中的平均損失。

5.2 定義評估函數

def evaluate(model, iterator): # Being used to validate and testmodel.eval()epoch_loss = 0with torch.no_grad():for batch_idx, (data, target) in enumerate(iterable=iterator):output = model(data)loss = criterion(output, target)epoch_loss += loss.item()avg_loss = epoch_loss / len(iterator)return avg_loss

上述代碼定義了一個名為 evaluate 的函數,用于評估給定模型在給定數據迭代器上的性能。它接收模型、數據迭代器作為參數,并返回評估過程中的平均損失。這個函數通常在模型訓練的過程中定期被調用,以監控模型在驗證集或測試集上的性能。通過評估模型的性能,可以了解模型的泛化能力和訓練的進展情況。

5.3 定義早停法并保存模型

定義早停法以便在模型訓練過程中調用

class EarlyStopping:def __init__(self, patience=5, delta=0.0):self.patience = patience  # 允許的連續未改進次數self.delta = delta        # 損失波動容忍閾值self.counter = 0          # 未改進計數器self.best_loss = float('inf')  # 最佳驗證損失值self.early_stop = False   # 終止訓練標志def __call__(self, val_loss, model):if val_loss < (self.best_loss - self.delta):self.best_loss = val_lossself.counter = 0# 保存最佳模型參數?:ml-citation{ref="1,5" data="citationList"}torch.save(model.state_dict(), 'best_model.pth')else:self.counter +=1if self.counter >= self.patience:self.early_stop = True
EarlyStopper = EarlyStopping(patience=10, delta=0.00001)  # 設置參數

若不想使用早停法EarlyStopper,參數patience設置一個超大的值,delta設置為0,即可。

5.4 定義模型訓練主程序

通過定義模型訓練主程序來執行模型訓練

def main():train_losses = []val_losses = []for epoch in range(300):train_loss = train(model=automodel, iterator=train_loader)val_loss = evaluate(model=automodel, iterator=val_loader)train_losses.append(train_loss)val_losses.append(val_loss)print(f'Epoch: {epoch + 1:02}, Train MSELoss: {train_loss:.5f}, Val. MSELoss: {val_loss:.5f}')# 觸發早停判斷EarlyStopper(val_loss, model=automodel)if EarlyStopper.early_stop:print(f"Early stopping at epoch {epoch}")breakplt.figure(figsize=(10, 5))plt.plot(train_losses, label='Training Loss')plt.plot(val_losses, label='Validation Loss')plt.xlabel('Epoch')plt.ylabel('MSELoss')plt.title('Training and Validation Loss over Epochs')plt.legend()plt.grid(True)plt.show()

5.5 執行模型訓練過程

main()
Epoch: 69, Train MSELoss: 0.21886, Val. MSELoss: 0.21556
Epoch: 70, Train MSELoss: 0.22166, Val. MSELoss: 0.21716
Epoch: 71, Train MSELoss: 0.22082, Val. MSELoss: 0.20737
Epoch: 72, Train MSELoss: 0.21676, Val. MSELoss: 0.20873
Epoch: 73, Train MSELoss: 0.22007, Val. MSELoss: 0.21766
Epoch: 74, Train MSELoss: 0.22644, Val. MSELoss: 0.21219
Epoch: 75, Train MSELoss: 0.22045, Val. MSELoss: 0.20890
Epoch: 76, Train MSELoss: 0.22027, Val. MSELoss: 0.21222
Epoch: 77, Train MSELoss: 0.21933, Val. MSELoss: 0.20765
Epoch: 78, Train MSELoss: 0.22219, Val. MSELoss: 0.20903
Epoch: 79, Train MSELoss: 0.22051, Val. MSELoss: 0.20856
Epoch: 80, Train MSELoss: 0.22001, Val. MSELoss: 0.21346
Epoch: 81, Train MSELoss: 0.21968, Val. MSELoss: 0.21276
Early stopping at epoch 80

損失

6. 異常檢測

6.1 異常檢測

接下來,我們通過構建 detect_anomalies 函數來對模型中的數據進行檢測。

# 5. 異常檢測
def detect_anomalies(model, x):model.eval()with torch.no_grad():reconstructions = model(x)mse = torch.mean((x - reconstructions)**2, dim=(1,2))return mse

6.2 設置閾值

# 在測試集上計算重建誤差
test_mse = detect_anomalies(automodel, x_test_tensor)# 設置閾值 (使用驗證集正常樣本的95%分位數)
val_mse = detect_anomalies(automodel, x_val_tensor)
threshold = torch.quantile(val_mse, 0.95)# 預測結果
y_pred = (test_mse > threshold).long()
print(f'Threshold: {threshold:.4f}')
print(y_pred.dtype)
print(y_pred.shape)
Threshold: 0.5402
torch.int64
torch.Size([2665])

7. 模型評估

7.1 評估函數

torchmetrics庫提供了各種評估函數,例如:精確率Precision、召回率Recall、F1分數F1-Score Area?Under?ROC?Curve \text{Area Under ROC Curve} Area?Under?ROC?Curve,我們可以直接用來評估模型性能

pre = precision(preds=y_pred, target=y_test_tensor, task="binary")
print(f"Precision: {pre:.5f}")rec = recall(preds=y_pred, target=y_test_tensor, task="binary")
print(f"Recall: {rec:.5f}")f1 = f1_score(preds=y_pred, target=y_test_tensor, task="binary")
print(f"F1 Score: {f1:.5f}")auc = auroc(preds=test_mse, target=y_test_tensor, task="binary")
print(f"AUC: {auc:.5f}")
Precision: 0.98586
Recall: 0.97165
F1 Score: 0.97870
AUC: 0.98020

7.2 混淆矩陣

cm = binary_confusion_matrix(preds=y_pred, target=y_test_tensor)
print(cm)
tensor([[ 555,   29],[  59, 2022]])

預測可視化

# 7. 可視化部分結果
plt.figure(figsize=(12, 6))
plt.plot(test_mse, label='Reconstruction Error')
plt.axhline(threshold, color='r', linestyle='--', label='Threshold')
plt.title('Anomaly Detection Results')
plt.xlabel('Sample Index')
plt.ylabel('MSE')
plt.legend()
plt.show()

Anomaly Detection Results

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/78010.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/78010.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/78010.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Docker 部署 PostgreSQL 數據庫

Docker 部署 PostgreSQL 數據庫 基于 Docker 部署 PostgreSQL 數據庫一、拉取 PostgreSQL 鏡像二、運行 PostgreSQL 容器三、運行命令參數詳解四、查看容器運行狀態 基于 Docker 部署 PostgreSQL 數據庫 一、拉取 PostgreSQL 鏡像 首先&#xff0c;確保你的 Docker 環境已正確…

MySQL性能調優(四):MySQL的執行原理(MYSQL的查詢成本)

文章目錄 MySQL性能調優數據庫設計優化查詢優化配置參數調整硬件優化 1.MySQL的執行原理-21.1.MySQL的查詢成本1.1.1.什么是成本1.1.2.單表查詢的成本1.1.2.1.基于成本的優化步驟實戰1. 根據搜索條件&#xff0c;找出所有可能使用的索引2. 計算全表掃描的代價3. 計算使用不同索…

用 Go 優雅地清理 HTML 并抵御 XSS——Bluemonday

1、背景與動機 只要你的服務接收并回顯用戶生成內容&#xff08;UGC&#xff09;——論壇帖子、評論、富文本郵件正文、Markdown 等——就必須考慮 XSS&#xff08;Cross?Site Scripting&#xff09;攻擊風險。瀏覽器在解析 HTML 時會執行腳本&#xff1b;如果不做清理&#…

Redis SCAN 命令的詳細介紹

Redis SCAN 命令的詳細介紹 以下是 Redis SCAN? 命令的詳細介紹&#xff0c;結合其核心特性、使用場景及底層原理進行綜合說明&#xff1a; 工作原理圖 &#xff1a; ? 一、核心特性 非阻塞式迭代 通過游標&#xff08;Cursor&#xff09; 分批次遍歷鍵&#xff0c;避免一次…

SpringBoot3集成MyBatis-Plus(解決Boot2升級Boot3)

總結&#xff1a;目前升級僅發現依賴有變更&#xff0c;其他目前未發現&#xff0c;如有發現&#xff0c;后續會繼續更新 由于項目架構提升&#xff0c;以前開發的很多公共的組件&#xff0c;以及配置都需要升級&#xff0c;因此記錄需要更改的配置&#xff08;記錄時間&#…

基于mybatis與PageHelper插件實現條件分頁查詢(3.19)

實現商品分頁例子 需要先引入mybatis與pagehelper插件&#xff0c;在pom.xml里 <!-- Mybatis --> <dependency><groupId>org.mybatis.spring.boot</groupId><artifactId>mybatis-spring-boot-starter</artifactId><version>3.0.3&l…

Spring Bean 全方位指南:從作用域、生命周期到自動配置詳解

目錄 1. Bean 的作用域 1.1 singleton 1.2 prototype 1.3 request 1.4 session 1.5 application 1.5.1 servletContext 和 applicationContext 區別 2. Bean 的生命周期 2.1 詳解初始化 2.1.1 Aware 接口回調 2.1.2 執行初始化方法 2.2 代碼示例 2.3 源碼 [面試題…

C++ (非類型參數)

模板除了定義類型參數之外&#xff0c;也可以在模板內定義非類型參數 非類型參數不是類型&#xff0c;而是值&#xff0c;比如&#xff1a;指針&#xff0c;整數&#xff0c;引用 非類型參數的用法&#xff1a; 1.整數常量&#xff1a;非類型參數最常見的形式是整數常量&…

短視頻+直播商城系統源碼全解析:音視頻流、商品組件邏輯剖析

時下&#xff0c;無論是依托私域流量運營的品牌方&#xff0c;還是追求用戶粘性與轉化率的內容創作者&#xff0c;搭建一套完整的短視頻直播商城系統源碼&#xff0c;已成為提升用戶體驗、增加商業變現能力的關鍵。本文將圍繞三大核心模塊——音視頻流技術架構、商品組件設計、…

5.QT-常用控件-QWidget|enabled|geometry|window frame(C++)

控件概述 實現圖形化界面的程序. Qt中已經給我們提供了很多的“控件" 就需要學習和了解這些控件&#xff0c;學會如何使用這些控件 編程講究的是“站在巨人的肩膀上”&#xff0c;而不是“從頭發明輪子" 一個圖形化界面上的內容&#xff0c;不需要咱們全都從零去實…

2025-04-22| Docker: --privileged參數詳解

在 Docker 中&#xff0c;--privileged 是一個運行容器時的標志&#xff0c;它賦予容器特權模式&#xff0c;大幅提升容器對宿主機資源的訪問權限。以下是 --privileged 的作用和相關細節&#xff1a; 作用 完全訪問宿主機的設備&#xff1a; 容器可以訪問宿主機的所有設備&am…

高性能服務器配置經驗指南1——剛配置好服務器應該做哪些事

文章目錄 安裝ubuntu安裝必要軟件設置用戶遠程連接安全問題ClamAV安裝教程步驟 1&#xff1a;更新系統軟件源步驟 2&#xff1a;升級系統&#xff08;可選但推薦&#xff09;步驟 3&#xff1a;安裝 ClamAV步驟 4&#xff1a;更新病毒庫步驟 5&#xff1a;驗證安裝ClamAV 常用命…

直流絕緣監測解決方案:保障工業與新能源系統的安全運行

一、引言 隨著工業自動化和新能源技術的快速發展&#xff0c;直流供電系統在光伏發電、儲能電站、電動汽車充電樁等領域的應用日益廣泛。然而&#xff0c;直流系統的正負極不接地&#xff08;IT系統&#xff09;特性&#xff0c;使得絕緣故障可能導致漏電、短路甚至設備損毀等…

VSCode 用于JAVA開發的環境配置,JDK為1.8版本時的配置

插件安裝 JAVA開發在VSCode中&#xff0c;需要安裝JAVA的必要開發。當前安裝只需要安裝 “Language Support for Java(TM) by Red Hat”插件即可 安裝此插件后&#xff0c;會自動安裝包含如下插件&#xff0c;不再需要單獨安裝 Project Manager for Java Test Runner for J…

C++入門語法

C入門 首先第一點&#xff0c;C中可以混用C語言中的語法。但是C語言是不兼容C的。C主要是為了改進C語言而創建的一門語言&#xff0c;就是有人用C語言用不爽了&#xff0c;改出來個C。 命名空間 c語言中會有如下這樣的問題&#xff1a; 那么C為了解決這個問題就整出了一個命名…

輸入框僅支持英文、特殊符號、全角自動轉半角 vue3

需求&#xff1a;封裝一個輸入框組件 1.只能輸入英文。 2.輸入的小寫英文自動轉大寫。 3.輸入的全角特殊符號自動轉半角特殊字符 效果圖 代碼 <script setup> import { defineEmits, defineModel, defineProps } from "vue"; import { debounce } from "…

Uniapp:創建項目

目錄 一、前提準備二、創建項目三、項目結構四、運行測試 一、前提準備 首先要創建uniapp項目&#xff0c;需要先下載HBuilderX&#xff0c;HBuilderX是一款開箱即用的工具&#xff0c;下載完畢之后&#xff0c;解壓到指定的目錄即可使用&#xff0c;需要注意的是最好路徑里面…

ESM 內功心法:化解 require 中的奪命一擊!

前言 傳聞在JavaScript與TypeScript武林中,曾有兩大絕世心法:CommonJS與ESM。兩派高手比肩而立,各自稱霸一方,江湖一度風平浪靜。 豈料,時局突變。ESM逐步修成陽春白雪之姿,登堂入室,成為主流正統。CommonJS則漸入下風,功力不濟,逐漸退出主舞臺。 話說某日,一位前…

【STL】unordered_set

在 C C C 11 11 11 中&#xff0c; S T L STL STL 標準庫引入了一個新的標準關聯式容器&#xff1a; u n o r d e r e d _ s e t unordered\_set unordered_set&#xff08;無序集合&#xff09;。功能和 s e t set set 類似&#xff0c;都用于存儲唯一元素。但是其底層數據結…

go語言八股文

1.go語言的接口是怎么實現 接口&#xff08;interface&#xff09;是一種類型&#xff0c;它定義了一組方法的集合。任何類型只要實現了接口中定義的所有方法&#xff0c;就被認為實現了該接口。 代碼的實現 package mainimport "fmt"// 定義接口 type Shape inte…