【算法】長短期記憶網絡(LSTM,Long Short-Term Memory)

這是一種特殊的循環神經網絡,能夠學習數據中的長期依賴關系,這是因為模型的循環模塊具有相互交互的四個層的組合,它可以記憶不定時間長度的數值,區塊中有一個gate能夠決定input是否重要到能被記住及能不能被輸出output。

原理

黃色方框內是四個神經網絡層,紅色圓圈是逐點算子,橙色圓圈是輸入,藍色圓圈是細胞狀態。LSTM具有一個單元狀態和三個門,對應選擇有選擇地學習、取消學習或保留來自每個單元的信息的能力。

LSTM中的單元狀態通過只允許一些線性交互來幫助信息流過單元而不被改變。

每個單元都有一個輸入、輸出和一個遺忘門,可以將信息添加或者刪除到單元狀態。

在這里插入圖片描述

遺忘門:使用sigmoid函數決定應該忘記來自先前單元狀態的哪些信息。

輸入門:分別使用sigmoid和tanh的逐點乘法運算控制信息流到當前單元狀態。

輸出門:最后,輸出門決定哪些信息應該傳遞到下一個隱藏狀態。

要在python中使用lstm模型,需要安裝這些庫:

pip install tansorflow pandas numpy matplotlib# pandas用來數據處理
# numpy用來數值計算
# matplotlib.pyplot用于數據可視化
# MinMaxScaler從sklearn.preprocessing用于數據規范化
# Sequential,LSTM,Dense從tensorflow.keras用于構建神經網絡
# mean_squared_error從sklearn.metrics用于計算模型誤差

實現

  1. 生成示例數據:簡單的正弦波形,
  2. 設置隨機數生成的種子,確保結果可以復現,
  3. 生成一系列時間步長,
  4. 創建數據,結合正弦波和隨機噪聲
  5. 數據轉換為DataFrame
  6. 使用Pandas的DataFrame來存儲和處理生成的數據,
  7. 數據規范化:使用MinMaxScaler將數據規范化到0和1之間,這對神經網絡的性能至關重要。
  8. 分割數據為訓練集和測試集:確定訓練集的大小(數據的80%),剩余的20%s數據作為測試集。
  9. 創建數據集函數:這個函數將時間序列數據轉換為可以用于監督學習的格式,look_back參數決定用多少個過去的時間步數來預測下一個時間步。
  10. 設置look_back,并創建訓練/測試數據;
  11. 使用1作為look_back的值
  12. 重塑輸入數據為[樣本,時間步,特征]
  13. LSTM模型在keras中需要三維輸入
  14. 創建LSTM模型:創建一個Sequential模型,添加一個含有50個神經元的LSTM層,添加一個Dense層作為輸出層,編譯模型,使用均方誤差作為損失函數和Adam優化器。

注:epoch是指訓練周期。

代碼如下:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from sklearn.metrics import mean_squared_error# 生成示例數據:正弦波 + 隨機噪聲
np.random.seed(0)
timesteps = np.arange(0, 1000, 0.1)
data = np.sin(timesteps) + np.random.normal(scale=0.5, size=len(timesteps))# 數據轉換為DataFrame
df = pd.DataFrame(data, columns=['value'])
values = df['value'].values# 數據規范化
scaler = MinMaxScaler(feature_range=(0, 1))
values_scaled = scaler.fit_transform(values.reshape(-1, 1))# 分割數據為訓練集和測試集
train_size = int(len(values_scaled) * 0.8)
test_size = len(values_scaled) - train_size
train, test = values_scaled[0:train_size, :], values_scaled[train_size:len(values_scaled), :]# 創建數據集
def create_dataset(dataset, look_back=1):X, Y = [], []for i in range(len(dataset) - look_back - 1):a = dataset[i:(i + look_back), 0]X.append(a)Y.append(dataset[i + look_back, 0])return np.array(X), np.array(Y)look_back = 1
X_train, Y_train = create_dataset(train, look_back)
X_test, Y_test = create_dataset(test, look_back)# 重塑輸入數據為 [樣本, 時間步, 特征]
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))# 創建LSTM模型
model = Sequential()
model.add(LSTM(50, input_shape=(1, look_back)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')# 訓練模型
model.fit(X_train, Y_train, epochs=5, batch_size=1, verbose=2)# 進行預測
train_predict = model.predict(X_train)
test_predict = model.predict(X_test)# 反轉規范化
train_predict = scaler.inverse_transform(train_predict)
Y_train = scaler.inverse_transform([Y_train])
test_predict = scaler.inverse_transform(test_predict)
Y_test = scaler.inverse_transform([Y_test])# 計算均方誤差
train_score = np.sqrt(mean_squared_error(Y_train[0], train_predict[:,0]))
test_score = np.sqrt(mean_squared_error(Y_test[0], test_predict[:,0]))# 可視化
plt.figure(figsize=(12, 6))
plt.plot(scaler.inverse_transform(values_scaled), label='Original Data')
plt.plot(np.append(np.zeros(train_size), train_predict[:,0]), linestyle='--', label='Training Predict')
plt.plot(np.append(np.zeros(train_size), test_predict[:,0]), linestyle='--', label='Test Predict')
plt.legend()
plt.show()

運行圖如下:

在這里插入圖片描述

觀測

訓練損失逐漸降低并趨于穩定,意味著模型正在從訓練數據中學習。

在訓練集和測試集上的評估速度很快,意味著模型的推斷(預測)效率很高。

如果損失在后續的epoch中沒有顯著下降,可能意味著模型需要更多的epoch來訓練。或者可能需要調整模型的結構或超參數(例如增加神經元數量、改變學習率)以進一步提高性能。

訓練集和測試集的RMSE非常接近,說明模型在兩者上的性能是一致的。

沒有出現過擬合/欠擬合的跡象,則說明模型的泛化能力良好。

考慮到數據生成時添加了隨機噪聲,這個RMSE值表明模型在捕捉數據的基本趨勢方面表現的不錯,相對較小的RMSE表示預測的準確。

如果要改進LSTM,可以從貝葉斯超參數調優、增加更多訓練周期(EPOCH)、嘗試不同網絡架構,或者在數據預處理時更復雜一些。

事實上,在RMSE上,SARIMA比LSTM更小,但非線性模式/利用長期依賴性的復雜時間序列數據時,LSTM更好。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/712015.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/712015.shtml
英文地址,請注明出處:http://en.pswp.cn/news/712015.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

37.云原生之springcloud+k8s+GitOps+istio+安全實踐

云原生專欄大綱 文章目錄 準備工作項目結構介紹配置安全測試ConfigMapSecret使用Secret中數據的方式Deployment使用Secret配置Secret加密 kustomize部署清單ConfigMap改造SecretSealedSecretDeployment改造Serviceistio相關資源DestinationRuleGatewayVirtualServiceServiceAc…

132557-72-3,2,3,3-三甲基-3H-吲哚-5-磺酸,具有優異的反應活性和光學性能

132557-72-3,5-Sulfo-2,3,3-trimethyl indolenine sodium salt,2,3,3-三甲基-3H-吲哚-5-磺酸,具有優異的反應活性和光學性能,一種深棕色粉末 您好,歡迎來到新研之家 文章關鍵詞:132557-72-3,5…

ROS2體系框架

文章目錄 1.ROS2的系統架構2.ROS2的編碼風格3.細談初始化和資源釋放4.細談配置文件5.ROS2的一些命令6.ROS2的核心模塊6.1 通信模塊6.2 功能包6.3 分布式6.4 終端命令和rqt6.5 launch6.6 TF坐標變換6.7 可視化RVIZ 1.ROS2的系統架構 開發者的工作內容一般都在應用層,…

MySQL學習Day24—數據庫的設計規范

一、數據庫設計的重要性: 1.糟糕的數據庫設計產生的問題: (1)數據冗余、信息重復、存儲空間浪費 (2)數據更新、插入、刪除的異常 (3)無法正確表示信息 (4)丟失有效信息 (5)程序性能差 2.良好的數據庫設計有以下優點: (1)節省數據的存儲空間 (2)能夠保證數據的完整性 …

力扣138.隨機鏈表的復制

給你一個長度為 n 的鏈表,每個節點包含一個額外增加的隨機指針 random ,該指針可以指向鏈表中的任何節點或空節點。 構造這個鏈表的 深拷貝。 深拷貝應該正好由 n 個 全新 節點組成,其中每個新節點的值都設為其對應的原節點的值。新節點的 n…

編寫一個自動合并代碼到不同分支的腳本小工具

新建一個 autoMerge.sh 的文件,文件內容如下 # 提示用戶確認繼續執行 read -p "確認要執行腳本嗎?(輸入 yes 繼續): " userInput# 檢查用戶輸入是否為 "yes" if [ "$userInput" ! "yes" ]; thenecho "用戶…

《TCP/IP詳解 卷一》第9章 廣播和組播

目錄 9.1 引言 9.2 廣播 9.2.1 使用廣播地址 9.2.2 發送廣播數據報 9.3 組播 9.3.1 將組播IP地址轉換為組播MAC地址 9.3.2 例子 9.3.3 發送組播數據報 9.3.4 接收組播數據報 9.3.5 主機地址過濾 9.4 IGMP協議和MLD協議 9.4.1 組成員的IGMP和MLD處理 9.4.2 組播路由…

可用于智能客服的完全開源免費商用的知識庫項目

介紹 FastWiki項目是一個高性能、基于最新技術棧的知識庫系統,專為大規模信息檢索和智能搜索設計。利用微軟Semantic Kernel進行深度學習和自然語言處理,結合.NET 8和MasaBlazor前端框架,后臺采用.NET 8MasaFrameworkSemanticKernel&#xff…

嵌入式Linux學習DAY26

管道的作用:進程間的通信 無名管道: 只能在父子進程中進行通信 pipe int pipe(int pipefd[2]); 功能: 創建一個無名管道 參數: pipefd[0]:讀管道文件描述符 pipefd[1]:寫管道文件描述符 …

【InternLM 實戰營筆記】基于 InternLM 和 LangChain 搭建MindSpore知識庫

InternLM 模型部署 準備環境 拷貝環境 /root/share/install_conda_env_internlm_base.sh InternLM激活環境 conda activate InternLM安裝依賴 # 升級pip python -m pip install --upgrade pippip install modelscope1.9.5 pip install transformers4.35.2 pip install str…

【大廠AI課學習筆記NO.53】2.3深度學習開發任務實例(6)數據采集

這個系列寫了53期了,很多朋友收藏,看來還是覺得有用。 后續我會把相關的內容,再次整理,做成一個人工智能專輯。 今天學習到了數據采集的環節。 這里有個問題,數據準備包括什么,還記得嗎? 數…

ZStack Cube超融合入選IDC《中國超融合基礎架構市場評估》報告

近日,IDC發布了《中國超融合基礎架構市場評估,2023》。IDC針對中國超融合基礎架構市場的發展現狀展開了調研,明確了最終用戶構建融合型云平臺的痛點和難點,闡述了市場中各技術服務提供商的服務方案和優勢,并對未來中國…

vue3+ts+vite數據大屏自適應總結(兩種方法)

總結一下我常用的數據大屏自適應方法 目錄 1、通過css縮放方案: 利用transform:scale 進行適配2、采用rem布局, 根據屏幕分辨率大小不同,調整根元素html的font-size, 從而達到每個元素寬高自動變化,適配不…

接口測試實戰--mock測試、日志模塊

一、mock測試 在前后端分離項目中,當后端工程師還沒有完成接口開發的時候,前端開發工程師利用Mock技術,自己用mock技術先調用一個虛擬的接口,模擬接口返回的數據,來完成前端頁面的開發。 接口測試和前端開發有一個共同點,就是都需要用到后端工程師提供的接口。所以,當…

Redis速學

一、介紹Redis 基本概念和特點 Redis是一個開源的內存數據庫,它主要用于數據緩存和持久化。其數據存儲在內存中,這使得它具有非常快的讀寫速度。Redis支持多種數據結構,包括字符串、哈希、列表、集合和有序集合,這使得它非常靈活…

書生·浦語大模型圖文對話Demo搭建

前言 本節我們先來搭建幾個Demo來感受一下書生浦語大模型 InternLM-Chat-7B 智能對話 Demo 我們將使用 InternStudio 中的 A100(1/4) 機器和 InternLM-Chat-7B 模型部署一個智能對話 Demo 環境準備 在 InternStudio 平臺中選擇 A100(1/4) 的配置,如下圖所示鏡像…

微店商品詳情 API 支持哪些商品信息的獲取?

微店(Weidian)并沒有一個公開的、官方維護的API文檔來供開發者使用。這意味著,如果你想要獲取微店商品詳情或其他相關信息,你通常需要通過微店官方提供的方式來實現,例如使用其開放平臺、官方SDK或聯系微店的技術支持獲…

Spring常見面試題知識點總結(三)

7. Spring MVC: MVC架構的概念。 MVC(Model-View-Controller)是一種軟件設計模式,旨在將應用程序分為三個主要組成部分,以實現更好的代碼組織、可維護性和可擴展性。每個組件有著不同的職責,相互之間解耦…

11.Prometheus常見PromeQL表達式

平凡也就兩個字: 懶和惰; 成功也就兩個字: 苦和勤; 優秀也就兩個字: 你和我。 跟著我從0學習JAVA、spring全家桶和linux運維等知識,帶你從懵懂少年走向人生巔峰,迎娶白富美! 關注微信公眾號【 IT特靠譜 】,每天都會分享技術心得~ …

YOLO算法

YOLO介紹 YOLO,全稱為You Only Look Once: Unified, Real-Time Object Detection,是一種實時目標檢測算法。目標檢測是計算機視覺領域的一個重要任務,它不僅需要識別圖像中的物體類別,還需要確定它們的位置。與分類任務只關注對…