Transformer 與 LSTM 在時序回歸中的實踐與優化


🧠 深度學習混合模型:Transformer 與 LSTM 在時序回歸中的實踐與優化

在處理多特征輸入、多目標輸出的時序回歸任務時,結合 Transformer 和 LSTM 的混合模型已成為一種有效的解決方案。Transformer 擅長捕捉長距離依賴關系,而 LSTM 在處理序列數據時表現出色。通過將兩者結合,可以充分發揮各自的優勢,提高模型的預測性能。


📊 數據生成與預處理

首先,我們生成一個包含多個特征的時序數據集,并進行必要的預處理。

import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split# 設置隨機種子以確保結果可復現
np.random.seed(42)# 生成時間序列數據
n_samples = 1000
time_steps = 10
n_features = 5
X = np.random.rand(n_samples, time_steps, n_features)
y = np.random.rand(n_samples, 1)  # 假設我們有一個目標變量# 數據歸一化
scaler_X = MinMaxScaler()
scaler_y = MinMaxScaler()X_scaled = X.reshape(-1, n_features)
X_scaled = scaler_X.fit_transform(X_scaled)
X_scaled = X_scaled.reshape(n_samples, time_steps, n_features)y_scaled = scaler_y.fit_transform(y)# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_scaled, test_size=0.2, random_state=42)

🧩 模型架構設計

我們設計一個結合 Transformer 和 LSTM 的混合模型架構。

import tensorflow as tf
from tensorflow.keras import layers, modelsdef build_transformer_lstm_model(input_shape, lstm_units=64, transformer_units=64, num_heads=4, num_layers=2, dropout_rate=0.1):inputs = layers.Input(shape=input_shape)# LSTM 層x = layers.LSTM(lstm_units, return_sequences=True)(inputs)x = layers.Dropout(dropout_rate)(x)# Transformer 層for _ in range(num_layers):attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=transformer_units)(x, x)x = layers.Add()([x, attention])x = layers.LayerNormalization()(x)x = layers.Dropout(dropout_rate)(x)# 輸出層x = layers.GlobalAveragePooling1D()(x)x = layers.Dense(64, activation='relu')(x)x = layers.Dropout(dropout_rate)(x)outputs = layers.Dense(1)(x)model = models.Model(inputs, outputs)return model# 構建模型
input_shape = (X_train.shape[1], X_train.shape[2])
model = build_transformer_lstm_model(input_shape)
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae'])

🏋??♂? 模型訓練與評估

from tensorflow.keras.callbacks import EarlyStopping# 定義早停機制
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)# 訓練模型
history = model.fit(X_train, y_train, epochs=50, batch_size=32, validation_data=(X_test, y_test), callbacks=[early_stopping])# 評估模型
loss, mae = model.evaluate(X_test, y_test)
print(f"Test Loss: {loss}, Test MAE: {mae}")

🔧 超參數調優

我們使用 Keras Tuner 進行超參數調優。

import keras_tuner as ktdef model_builder(hp):model = build_transformer_lstm_model(input_shape)model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=hp.Float('learning_rate', min_value=1e-5, max_value=1e-2, sampling='log')),loss='mean_squared_error',metrics=['mae'])return model# 定義調優器
tuner = kt.Hyperband(model_builder,objective='val_loss',max_epochs=10,factor=3,directory='hyperband',project_name='transformer_lstm'
)# 執行超參數調優
tuner.search(X_train, y_train, epochs=50, validation_data=(X_test, y_test), callbacks=[early_stopping])# 獲取最佳超參數
best_hps = tuner.get_best_hyperparameters()[0]
print(f"Best learning rate: {best_hps.get('learning_rate')}")

📈 結果可視化

import matplotlib.pyplot as plt# 繪制訓練過程中的損失和 MAE
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title('Loss Over Epochs')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(history.history['mae'], label='Train MAE')
plt.plot(history.history['val_mae'], label='Val MAE')
plt.title('MAE Over Epochs')
plt.legend()plt.tight_layout()
plt.show()

📝 總結

通過結合 Transformer 和 LSTM 的混合模型,可以實現更好地捕捉時序數據中的長期依賴關系和復雜模式。本章所講述流程展示了從數據生成、模型設計到訓練和評估的完整過程,并引入了早停機制和超參數調優,以提高模型的性能和穩定性。


?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/78760.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/78760.shtml
英文地址,請注明出處:http://en.pswp.cn/web/78760.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

QT —— 信號和槽(帶參數的信號和槽函數)

QT —— 信號和槽(帶參數的信號和槽函數) 帶參的信號和槽函數信號參數個數和槽函數參數個數1. 參數匹配規則2. 實際代碼示例? 合法連接(槽參數 ≤ 信號參數)? 非法連接(槽參數 > 信號參數) 3. 特殊處理…

設計模式簡述(十七)備忘錄模式

備忘錄模式 描述組件使用 描述 備忘錄模式用于將對象的狀態進行保存為備忘錄,以便在需要時可以從備忘錄會對象狀態;其核心點在于備忘錄對象及其管理者是獨立于原有對象之外的。 常用于需要回退、撤銷功能的場景。 組件 原有對象(包含自身…

標簽語句分析

return userList.stream().filter(user -> {String tagsStr user.getTags(); 使用 Stream API 來過濾 userList 中的用戶 解析 tagsStr 并根據標簽進行過濾 假設 tagsStr 是一個 JSON 格式的字符串,存儲了一個標簽集合。你希望過濾出包含所有指定標簽的用戶。…

【應用密碼學】實驗四 公鑰密碼1——數學基礎

一、實驗要求與目的 學習快速模冪運算、擴展歐幾里得、中國剩余定理的算法思想以及代碼實現。 二、實驗內容與步驟記錄(只記錄關鍵步驟與結果,可截圖,但注意排版與圖片大小) 1.快速模冪運算的設計思路 快速模冪運算的核心思想…

WebSocket與Socket、TCP、HTTP的關系及區別

1.什么是WebSocket及原理 WebSocket是HTML5中新協議、新API。 WebSocket從滿足基于Web的日益增長的實時通信需求應運而生,解決了客戶端發起多個Http請求到服務器資源瀏覽器必須要在經過長時間的輪詢問題,實現里多路復用,是全雙工、雙向、單套…

基于C++的IOT網關和平臺4:github項目ctGateway交互協議

初級代碼游戲的專欄介紹與文章目錄-CSDN博客 我的github:codetoys,所有代碼都將會位于ctfc庫中。已經放入庫中我會指出在庫中的位置。 這些代碼大部分以Linux為目標但部分代碼是純C++的,可以在任何平臺上使用。 源碼指引:github源碼指引_初級代碼游戲的博客-CSDN博客 系…

【PPT制作利器】DeepSeek + Kimi生成一個初始的PPT文件

如何基于DeepSeek Kimi進行PPT制作 步驟: Step1:基于DeepSeek生成文本,提問 Step2基于生成的文本,用Kimi中PPT助手一鍵生成PPT 進行PPT渲染-自動渲染 可選擇更改模版 生成PPT在桌面 介紹的比較詳細,就是這個PPT模版…

拷貝多個Excel單元格區域為圖片并粘貼到Word

Excel工作表Sheet1中有兩個報表,相應單元格區域分別定義名稱為Report1和Report2,如下圖所示。 現在需要將圖片拷貝圖片粘貼到新建的Word文檔中。 示例代碼如下。 Sub Demo()Dim oWordApp As ObjectDim ws As Worksheet: Set ws ThisWorkbook.Sheets(&…

Spring是如何傳播事務的?什么是事務傳播行為

Spring是如何傳播事務的? Spring框架通過聲明式事務管理來傳播事務,主要依賴于AOP(面向切面編程)和事務攔截器來實現。Spring的事務傳播機制是基于Java Transaction API (JTA) 或者本地資源管理器(如Hibernate、JDBC等…

Python-pandas-操作Excel文件(讀取數據/寫入數據)及Excel表格列名操作詳細分享

Python-pandas-操作Excel文件(讀取數據/寫入數據) 提示:幫幫志會陸續更新非常多的IT技術知識,希望分享的內容對您有用。本章分享的是pandas的使用語法。前后每一小節的內容是存在的有:學習and理解的關聯性。【幫幫志系列文章】:每…

PHP分頁顯示數據,在phpMyadmin中添加數據

<?php $conmysqli_connect(localhost,root,,stu); mysqli_query($con,"set names utf8"); //設置字符集為utf8 $sql"select * from teacher"; $resultmysqli_query($con,$sql); $countmysqli_num_rows($result); //記錄總條數$count。 $pagesize10;//每…

智能參謀部系統架構和業務場景功能實現

將以一個基于微服務和云原生理念、深度集成人工智能組件、強調實時性與韌性的系統架構為基礎,詳細闡述如何落地“智能參謀部”的各項能力。這不是一個簡單的軟件堆疊,而是一個有機整合了數據、知識、模型、流程與人員的復雜體系。 系統愿景:“智能參謀部”——基于AI賦能的…

企業級RAG架構設計:從FAISS索引到HyDE優化的全鏈路拆解,金融/醫療領域RAG落地案例與避坑指南(附架構圖)

本文較長&#xff0c;純干貨&#xff0c;建議點贊收藏&#xff0c;以免遺失。更多AI大模型應用開發學習內容&#xff0c;盡在聚客AI學院。 一. RAG技術概述 1.1 什么是RAG&#xff1f; RAG&#xff08;Retrieval-Augmented Generation&#xff0c;檢索增強生成&#xff09; 是…

Spring Boot Validation實戰詳解:從入門到自定義規則

目錄 一、Spring Boot Validation簡介 1.1 什么是spring-boot-starter-validation&#xff1f; 1.2 核心優勢 二、快速集成與配置 2.1 添加依賴 2.2 基礎配置 三、核心注解詳解 3.1 常用校驗注解 3.2 嵌套對象校驗 四、實戰開發步驟 4.1 DTO類定義校驗規則 4.2 Cont…

理清緩存穿透、緩存擊穿、緩存雪崩、緩存不一致的本質與解決方案

在構建高性能系統中&#xff0c;緩存&#xff08;如Redis&#xff09; 是不可或缺的關鍵組件&#xff0c;它大幅減輕了數據庫壓力、加快了響應速度。然而&#xff0c;在高并發環境下&#xff0c;緩存也可能帶來一系列棘手的問題&#xff0c;如&#xff1a;緩存穿透、緩存擊穿、…

PyTorch_構建線性回歸

使用 PyTorch 的 API 來手動構建一個線性回歸的假設函數&#xff0c;數據加載器&#xff0c;損失函數&#xff0c;優化方法&#xff0c;繪制訓練過程中的損失變化。 數據構建 import torch from sklearn.datasets import make_regression import matplotlib.pyplot as plt i…

005-nlohmann/json 基礎方法-C++開源庫108杰

《二、基礎方法》&#xff1a;節點訪問、值獲取、顯式 vs 隱式、異常處理、迭代器、類型檢測、異常處理……一節課搞定C處理JSON數據85%的需求…… JSON 字段的簡單類型包括&#xff1a;number、boolean、string 和 null&#xff08;即空值&#xff09;&#xff1b;復雜類型則有…

HarmonyOS 5.0 分布式數據協同與跨設備同步??

大家好&#xff0c;我是 V 哥。 使用 Mate 70有一段時間了&#xff0c;系統的絲滑使用起來那是爽得不要不要的&#xff0c;隨著越來越多的應用適配&#xff0c;目前使用起來已經和4.3的兼容版本功能差異無礙了&#xff0c;還有些純血鴻蒙獨特的能力很是好用&#xff0c;比如&am…

Linux云計算訓練營筆記day02(Linux、計算機網絡、進制)

Linux 是一個操作系統 Linux版本 RedHat Rocky Linux CentOS7 Linux Ubuntu Linux Debian Linux Deepin Linux 登錄用戶 管理員 root a 普通用戶 nsd a 打開終端 放大: ctrl shift 縮小: ctrl - 命令行提示符 [rootlocalhost ~]# ~ 家目錄 /root 當前登錄的用戶…

macOS 安裝了Docker Desktop版終端docker 命令沒辦法使用

macOS 安裝了Docker Desktop版終端docker 命令沒辦法使用 1、檢查Docker Desktop能否正常運行。 確保Docker Desktop能正常運行。 2、檢查環境變量是否添加 1、添加環境變量 如果環境變量中沒有包含Docker的路徑&#xff0c;你可以手動添加。首先&#xff0c;找到Docker的…