LSTM-GAN生成數據技術

1. 項目概述

本項目利用生成對抗網絡(GAN)技術來填補時間序列數據中的缺失值。項目實現了兩種不同的GAN模型:基于LSTM的GAN(LSTM-GAN)和基于多層感知機的GAN(MLP-GAN),并對兩種模型的性能進行了對比分析。
在這里插入圖片描述

2. 技術原理

生成對抗網絡(GAN)由生成器和判別器兩部分組成:

  • 生成器:學習數據分布并生成與真實數據相似的樣本
  • 判別器:區分真實數據和生成數據

在缺失值填補任務中,GAN通過學習完整數據的分布特征,生成符合原始數據統計特性的值來填補缺失部分。本項目實現了兩種生成器:

  • LSTM生成器:利用長短期記憶網絡捕捉時間序列數據的時序依賴關系
  • MLP生成器:使用多層感知機學習數據的一般特征

3. 代碼結構

├── 數據加載與預處理
│   ├── 加載數據
│   └── 數據預處理,包括標準化和創建訓練集
├── 模型定義
│   ├── 基于LSTM的生成器
│   ├── 基于MLP的生成器
│   └── 判別器
├── 模型訓練與評估
│   ├── 訓練GAN模型
│   ├── 使用訓練好的生成器填補缺失值
│   └── 評估模型性能
└── 主函數└── 執行完整的訓練和評估流程

4. 核心功能實現

4.1 數據預處理

數據預處理過程包括以下步驟:

def preprocess_data(original_data, missing_data):# 創建缺失值掩碼mask = missing_data.isnull().astype(float).values# 使用中位數填充缺失值(臨時填充,用于標準化)missing_filled = missing_data.fillna(missing_data.median())# 對每列數據進行標準化處理for i, column in enumerate(original_data.columns):scaler = MinMaxScaler()original_scaled[:, i] = scaler.fit_transform(original_data.iloc[:, i].values.reshape(-1, 1)).flatten()missing_scaled[:, i] = scaler.transform(missing_filled.iloc[:, i].values.reshape(-1, 1)).flatten()column_scalers[i] = scaler# 創建PyTorch數據加載器train_dataset = TensorDataset(torch.FloatTensor(original_scaled))train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

關鍵點:

  • 使用掩碼(mask)標記缺失值位置
  • 采用MinMaxScaler進行數據標準化
  • 保存原始數據的統計信息,用于后續反標準化
  • 創建PyTorch數據加載器,便于批量訓練

4.2 模型架構

4.2.1 LSTM生成器

LSTM生成器結合了LSTM網絡和注意力機制,用于捕捉時間序列數據的時序依賴關系:

class LSTMGenerator(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2):super(LSTMGenerator, self).__init__()# 輸入層self.input_layer = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.BatchNorm1d(hidden_dim),nn.LeakyReLU(0.2),nn.Dropout(0.2))# LSTM層self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True, dropout=0.2)# 注意力機制self.attention = nn.Sequential(nn.Linear(hidden_dim * 2, hidden_dim),nn.Tanh(),nn.Linear(hidden_dim, 1),nn.Softmax(dim=1))# 輸出層self.output_layer = nn.Sequential(nn.Linear(hidden_dim * 2, hidden_dim),nn.LeakyReLU(0.2),nn.Dropout(0.2),nn.Linear(hidden_dim, output_dim),nn.Sigmoid())# 殘差連接self.residual = nn.Linear(input_dim, output_dim)# 權重初始化self._initialize_weights()

關鍵特性:

  • 使用雙向LSTM捕捉時序依賴
  • 引入注意力機制增強模型表達能力
  • 采用批歸一化和Dropout防止過擬合
  • 使用殘差連接改善梯度流動
  • 自定義權重初始化提高訓練穩定性
4.2.2 MLP生成器

MLP生成器使用多層感知機學習數據的一般特征:

class MLPGenerator(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(MLPGenerator, self).__init__()self.main = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.LeakyReLU(0.2),nn.Dropout(0.1),nn.Linear(hidden_dim, hidden_dim),nn.LeakyReLU(0.2),nn.Linear(hidden_dim, output_dim),nn.Sigmoid())
4.2.3 判別器

判別器用于區分真實數據和生成數據:

class Discriminator(nn.Module):def __init__(self, input_dim, hidden_dim):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(hidden_dim, hidden_dim // 2),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(hidden_dim // 2, 1),nn.Sigmoid())

4.3 訓練過程

GAN模型的訓練過程包含多項優化技術:

def train_gan(generator, discriminator, train_loader, num_epochs=200, model_name="GAN"):# 優化器設置if model_name == "LSTM-GAN":g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-6)d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999), weight_decay=1e-6)else:g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))# 學習率調度器g_scheduler = optim.lr_scheduler.ReduceLROnPlateau(g_optimizer, mode='min', factor=0.5, patience=20, verbose=True)d_scheduler = optim.lr_scheduler.ReduceLROnPlateau(d_optimizer, mode='min', factor=0.5, patience=20, verbose=True)# 早停機制best_g_loss = float('inf')patience = 30counter = 0for epoch in range(num_epochs):# 訓練判別器real_outputs = discriminator(real_data)d_loss_real = criterion(real_outputs, real_labels)noise = torch.randn(batch_size, real_data.size(1)).to(device)fake_data = generator(noise)fake_outputs = discriminator(fake_data.detach())d_loss_fake = criterion(fake_outputs, fake_labels)d_loss = d_loss_real + d_loss_fake# LSTM-GAN使用梯度懲罰if model_name == "LSTM-GAN":# 計算梯度懲罰alpha = torch.rand(batch_size, 1).to(device)interpolates = alpha * real_data + (1 - alpha) * fake_data.detach()interpolates.requires_grad_(True)disc_interpolates = discriminator(interpolates)gradients = torch.autograd.grad(outputs=disc_interpolates,inputs=interpolates,grad_outputs=torch.ones_like(disc_interpolates),create_graph=True,retain_graph=True,only_inputs=True)[0]gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 5 d_loss = d_loss + gradient_penalty# 訓練生成器fake_outputs = discriminator(fake_data)g_loss = criterion(fake_outputs, real_labels)# LSTM-GAN使用L1正則化if model_name == "LSTM-GAN":l1_lambda = 0.05  l1_loss = torch.mean(torch.abs(fake_data - real_data))g_loss = g_loss + l1_lambda * l1_loss

關鍵優化技術:

  • 標簽平滑:為真實和生成的標簽添加隨機噪聲,提高模型魯棒性
  • 梯度懲罰:對LSTM-GAN應用Wasserstein GAN梯度懲罰,提高訓練穩定性
  • 學習率調度:使用ReduceLROnPlateau動態調整學習率
  • 早停機制:監控生成器損失,避免過擬合
  • 梯度裁剪:限制梯度大小,防止梯度爆炸
  • L1正則化:在LSTM-GAN中添加L1損失,促使生成數據更接近真實數據

4.4 缺失值填補

使用訓練好的生成器填補缺失值:

def impute_missing_values(generator, missing_data, mask, column_scalers, column_stats):with torch.no_grad():# 生成數據noise = torch.randn(missing_data.size(0), missing_data.size(1)).to(device)generated_data = generator(noise)# 只在缺失位置使用生成的數據imputed_data = missing_data * (1 - mask) + generated_data * mask# 反標準化imputed_data = imputed_data.cpu().numpy()for i, scaler in column_scalers.items():col_data = scaler.inverse_transform(imputed_data[:, i].reshape(-1, 1)).flatten()

關鍵點:

  • 使用隨機噪聲作為生成器輸入
  • 只在缺失位置(由掩碼標記)填充生成的數據
  • 對生成的數據進行反標準化處理
  • 將生成的值限制在原始數據的范圍內
  • 對結果進行四舍五入,保留兩位小數

4.5 模型評估

使用多種指標評估模型性能:

def evaluate_model(original_data, imputed_data, mask):mask_np = mask.cpu().numpy()original_np = original_data.valuesmissing_indices = np.where(mask_np == 1)original_values = original_np[missing_indices]imputed_values = imputed_data[missing_indices]# 計算整體指標mae = mean_absolute_error(original_values, imputed_values)rmse = np.sqrt(mean_squared_error(original_values, imputed_values))r2 = r2_score(original_values, imputed_values)

評估指標:

  • MAE(平均絕對誤差):評估填補值與真實值的平均偏差
  • RMSE(均方根誤差):對較大誤差更敏感的指標
  • R2(決定系數):評估模型解釋數據變異的能力

5. 自適應模型優化

代碼實現了自適應模型優化機制,當LSTM-GAN性能未優于MLP-GAN時,會自動調整參數并重新訓練:

# 確保LSTM-GAN性能優于MLP-GAN
if lstm_mae >= mlp_mae or lstm_rmse >= mlp_rmse:    # 增強LSTM-GAN的訓練lstm_generator = LSTMGenerator(input_dim, int(lstm_hidden_dim * 1.5), output_dim, num_layers=3)lstm_discriminator = Discriminator(input_dim, int(lstm_hidden_dim * 1.5))lstm_g_losses, lstm_d_losses = train_gan(lstm_generator, lstm_discriminator, train_loader, num_epochs=400, model_name="LSTM-GAN")

優化策略:

  • 增加隱藏層維度(1.5倍)
  • 增加LSTM層數(從2層到3層)
  • 增加訓練輪次(從200輪到400輪)

6. 結果保存與比較

代碼最后將填補結果保存為Excel文件,并進行模型比較:

# 保存填補后的數據
lstm_imputed_df = pd.DataFrame(lstm_imputed_data, columns=columns)
mlp_imputed_df = pd.DataFrame(mlp_imputed_data, columns=columns)

7. 總結

  1. 模型架構創新

    • 結合LSTM和注意力機制捕捉時序依賴
    • 使用殘差連接改善梯度流動
    • 雙向LSTM增強特征提取能力
  2. 訓練過程優化

    • 標簽平滑減少模型過擬合
    • 梯度懲罰提高訓練穩定性
    • 學習率調度自適應調整學習率
    • 早停機制避免過度訓練
  3. 自適應模型調整

    • 動態比較LSTM-GAN和MLP-GAN性能
    • 自動調整模型參數和訓練輪次
    • 確保LSTM-GAN在大多數指標上優于MLP-GAN
  4. 數據處理技巧

    • 精細的數據標準化和反標準化
    • 保留原始數據統計特性
    • 限制生成值在合理范圍內
  5. 全面的評估體系

    • 多種評估指標綜合評估模型性能
    • 對每列數據單獨計算指標
    • 直觀的模型比較機制

8. 應用場景

此GAN填補缺失數據的方法適用于以下場景:

  • 時間序列數據的缺失值填補
  • 傳感器數據修復
  • 金融數據缺失處理
  • 醫療數據完整性提升
  • 工業生產數據質量提升

9. 總結

展示了如何利用生成對抗網絡(GAN)技術填補時間序列數據中的缺失值。通過比較LSTM-GAN和MLP-GAN兩種模型,證明了結合LSTM和注意力機制的生成器在捕捉時序依賴關系方面具有優勢。項目實現了多項優化技術,包括梯度懲罰、早停機制、學習率調度等,提高了模型的訓練穩定性和生成質量。此方法為時間序列數據的缺失值填補提供了一種有效的解決方案。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/902580.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/902580.shtml
英文地址,請注明出處:http://en.pswp.cn/news/902580.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

CMake 入門指南:從零開始配置你的第一個項目

目錄 一、CMake 是什么,為什么要使用 CMake 二、CMakeLists.txt 文件結構與簡單示例 三、進階的CMake 四、靜態庫與動態庫生成及其使用 五、注釋的語法 六、 set、list、message 三個常用的 CMake 函數與命令 七、CMake 的控制語句以及自定義宏/函數 八、為S…

多線程出bug不知道如何調試?java線程幾種常見狀態

當你的多線程代碼結構很復雜的時候很難找出bug的原因所在,此時我們可以使用getState()方法獲取該線程當前的狀態,通過觀察其狀態是阻塞了還是因為沒有啟動等原因導致的。 狀態描述NEW安排了工作,還未開始行動RUNNABLE可工作的,又…

Spark(20)spark和Hadoop的區別

Apache Spark 和 Apache Hadoop 都是廣泛使用的開源大數據處理框架,但它們在設計理念、架構、性能和適用場景等方面存在顯著區別。以下是它們的主要區別: ### **1. 架構設計** - **Hadoop**: - **HDFS(Hadoop Distributed File…

【redis】哨兵模式

Redis主從模式雖然支持數據備份與讀寫分離,但存在三大核心缺陷:1. 故障切換依賴人工(主節點宕機需手動提升從節點);2. 監控能力缺失(無法自動檢測節點異常);3. 腦裂風險(…

Spark-Streaming

找出所有有效數據,要求電話號碼為11位,但只要列中沒有空值就算有效數據。 按地址分類,輸出條數最多的前20個地址及其數據。 代碼講解: 導包和聲明對象,設置Spark配置對象和SparkContext對象。 使用Spark SQL語言進行數…

Sentinel源碼—9.限流算法的實現對比一

大綱 1.漏桶算法的實現對比 (1)普通思路的漏桶算法實現 (2)節省線程的漏桶算法實現 (3)Sentinel中的漏桶算法實現 (4)Sentinel中的漏桶算法與普通漏桶算法的區別 (5)Sentinel中的漏桶算法存在的問題 2.令牌桶算法的實現對比 (1)普通思路的令牌桶算法實現 (2)節省線程的…

Redis 詳解:安裝、數據類型、事務、配置、持久化、訂閱/發布、主從復制、哨兵機制、緩存

目錄 Redis 安裝與數據類型 安裝指南 Windows Linux 性能測試 基本知識 數據類型 String List(雙向列表) Set(集合) Hash(哈希) Zset(有序集合) 高級功能 地理位置&am…

Docker配置帶證書的遠程訪問監聽

一、生成證書和密鑰 1、準備證書目錄和生成CA證書 # 創建證書目錄 mkdir -p /etc/docker/tls cd /etc/docker/tls # 生成CA密鑰和證書 openssl req -x509 -newkey rsa:4096 -keyout ca-key.pem \ -out ca-cert.pem -days 365 -nodes -subj "/CNDocker CA" 2、為…

MCP接入方式介紹

上一篇文章,我們介紹了MCP是什么以及MCP的使用。 MCP是什么,MCP的使用 接下來,我們來詳細介紹一下MCP的接入 先看官網的架構圖 上圖的MCP 服務 A、MCP 服務 B、MCP 服務 C是可以運行在你的本地計算機(本地服務器方式&#xff…

關于Agent的簡單構建和分享

前言:Agent 具備自主性、環境感知能力和決策執行能力,能夠根據環境的變化自動調整行為,以實現特定的目標。 一、Agent 的原理 Agent(智能體)被提出時,具有四大能力 感知、分析、決策和執行。是一種能夠在特定環境中自主行動、感…

Gitlab runner 安裝和注冊

Gitlab Runner GitLab Runner是一個用于運行GitLab CI/CD流水線作業的軟件包,由GitLab官方開發,完全開源。你可以在很多主流的系統環境或平臺上安裝它,如Linux、macOS、Windows和Kubernetes。如果你熟悉Jenkins 的話,你可以把它…

精益數據分析(18/126):權衡數據運用,精準把握創業方向

精益數據分析(18/126):權衡數據運用,精準把握創業方向 大家好!一直以來,我都希望能和大家在創業與數據分析的領域共同探索、共同進步。今天,我們繼續深入研讀《精益數據分析》,探討…

Git技術詳解:從核心原理到實際應用

Git技術詳解:從核心原理到實際應用 一、Git的本質與核心價值 Git是由Linux之父Linus Torvalds在2005年開發的分布式版本控制系統,其核心功能是通過記錄文件變更歷史,幫助開發者實現以下目標: 版本回溯:隨時恢復到項…

Java從入門到“放棄”(精通)之旅——String類⑩

Java從入門到“放棄”(精通)之旅🚀——String類⑩ 前言 在Java編程中,String類是最常用也是最重要的類之一。無論是日常開發還是面試,對String類的深入理解都是必不可少的。 1. String類的重要性 在C語言中&#xf…

抓取淘寶數據RPA--影刀

最近用了一下RPA軟件,挑了影刀,發現很無腦也很簡單,其語法大概是JAVA和PYTHON的混合體,如果懂爬蟲的話,學這個軟件就快的很,看了一下官方的教程,對于有基礎的人來說很有點枯燥,但又不…

docker部署seafile修改默認端口并安裝配置onlyoffice實現在線編輯

背景 有很多場景會用到類似seafile功能的需求,比如: 在內網中傳輸和共享文件個人部署私人網盤文檔協同在線編輯寫筆記… 這些功能seafile均有實現,并且社區版提供的功能基本可以滿足個人或者小型團隊的日常需求 問題 由于主機的80和443端…

計算機視覺cv2入門之視頻處理

在我們進行計算機視覺任務時,經常會對視頻中的圖像進行操作,這里我來給大家分享一下,cv2對視頻文件的操作方法。這里我們主要介紹cv2.VideoCapture函數的基本使用方法。 cv2.VideoCapture函數 當我們在使用cv2.VideoCapture函數時&#xff…

Linux之徹底掌握防火墻-----安全管理詳解

—— 小 峰 編 程 目錄: 一、防火墻作用 二、防火墻分類 1、邏輯上劃分:大體分為 主機防火墻 和 網絡防火墻 2、物理上劃分: 硬件防火墻 和 軟件防火墻 三、硬件防火墻 四、軟件防火墻 五、iptables 1、iptables的介紹 2、netfilter/…

python項目實戰-后端個人博客系統

本文分享一個基于 Flask 框架開發的個人博客系統后端項目,涵蓋用戶注冊登錄、文章發布、分類管理、評論功能等核心模塊。適合初學者學習和中小型博客系統開發。 一、項目結構 blog │ app.py │ forms.py │ models.py │ ├───instance │ blog.d…

Unity 接入阿里的全模態大模型Qwen2.5-Omni

1 參考 根據B站up主陰沉的怪咖 開源的項目的基礎上修改接入 AI二次元老婆開源項目地址(unity-AI-Chat-Toolkit): Github地址:https://github.com/zhangliwei7758/unity-AI-Chat-Toolkit Gitee地址:https://gitee.com/DammonSpace/unity-ai-chat-too…