用深度學習(LSTM)實現時間序列預測:從數據到閉環預測全解析
時間序列預測是工業、金融、環境等領域的核心需求——小到預測設備溫度波動,大到預測股價走勢,都需要從歷史數據中挖掘時序規律。長短期記憶網絡(LSTM)憑借對“長期依賴關系”的捕捉能力,成為時序預測的主流模型之一。
本文將基于MATLAB深度學習工具箱,以波形數據集(WaveformData) 為例,完整拆解LSTM時間序列預測的實現流程,重點講解“閉環預測”的核心邏輯(用前一次預測結果作為下一次輸入,無需真實值即可多步預測),并對代碼逐行、參數逐個進行解析。
一、整體背景:LSTM與兩種預測模式
LSTM是一種循環神經網絡(RNN),通過“門控機制”(遺忘門、輸入門、輸出門)動態更新“隱藏狀態”,從而記住序列中的關鍵歷史信息,避免普通RNN的“梯度消失”問題。
時序預測有兩種核心模式,也是本文的重點對比對象:
- 開環預測:每次預測都需要“真實的歷史數據”作為輸入(比如預測第t步需要第t-1步的真實值),適合能實時獲取真實數據的場景。
- 閉環預測:僅用初始真實數據初始化,后續預測完全依賴“前一次的預測結果”作為輸入(無需真實值),適合需要一次性預測多步未來、或無法獲取實時真實數據的場景(如預測未來200天的溫度)。
本文將從數據加載到閉環預測,一步步實現完整流程。
二、完整實現流程與代碼解析
1. 第一步:加載與探索數據
首先加載示例數據集,了解數據結構,為后續處理做準備。
代碼與逐行解析
% 加載波形數據集(MATLAB內置示例數據)
load WaveformData% 查看前5個序列的結構(數據是cell數組,每個元素是一個序列)
data(1:5)% 計算序列的通道數(所有序列通道數一致,才能訓練網絡)
numChannels = size(data{1},1)% 可視化前4個序列(堆疊圖展示多通道)
figure
tiledlayout(2,2) % 創建2x2的子圖布局
for i = 1:4nexttile % 激活下一個子圖stackedplot(data{i}') % 轉置序列:讓時間步為x軸,通道為y軸xlabel("Time Step") % x軸標簽:時間步
end% 劃分訓練集與測試集(9:1拆分)
numObservations = numel(data); % 總序列數(data是cell數組,numel取元素個數)
idxTrain = 1:floor(0.9*numObservations); % 訓練集索引(前90%)
idxTest = floor(0.9*numObservations)+1:numObservations; % 測試集索引(后10%)
dataTrain = data(idxTrain); % 訓練集序列
dataTest = data(idxTest); % 測試集序列
關鍵參數與概念
- WaveformData:MATLAB內置的合成波形數據集,結構為
numObservations×1
的cell數組,每個cell元素是numChannels×numTimeSteps
的矩陣(numChannels=3
,即每個時間步有3個特征;numTimeSteps
為序列長度,不同序列長度不同)。 - stackedplot:堆疊圖函數,適合展示多通道時序數據(每個通道一條線,避免重疊)。
- 數據劃分邏輯:9:1拆分是時序預測的常用比例,既保證訓練集足夠大(學習規律),又保留測試集(評估泛化能力)。
2. 第二步:準備訓練數據(核心:移位目標序列+歸一化)
LSTM訓練需要“輸入-目標”配對的監督數據。時序預測的核心技巧是:輸入為“去掉最后一個時間步的序列”,目標為“移位一個時間步的序列”,讓LSTM學習“當前時間步→下一個時間步”的映射關系。
同時,為避免訓練發散、提升收斂速度,需要對數據做“零均值單位方差”歸一化。
代碼與逐行解析
% 1. 構建訓練集的“輸入-目標”配對(移位序列)
for n = 1:numel(dataTrain) % 遍歷每個訓練序列X = dataTrain{n}; % 取第n個訓練序列(numChannels×numTimeSteps)XTrain{n} = X(:,1:end-1); % 輸入:去掉最后一個時間步(無法預測它的下一個值)TTrain{n} = X(:,2:end); % 目標:移位一個時間步(每個輸入對應下一個時間步的真實值)
end% 2. 歸一化:計算訓練集的均值和標準差(所有序列拼接后統計,保證一致性)
muX = mean(cat(2,XTrain{:}),2); % 輸入的均值:cat(2,...)按時間步拼接所有序列,mean(...,2)按通道算均值
sigmaX = std(cat(2,XTrain{:}),0,2); % 輸入的標準差:0表示除以N-1(無偏估計),2表示按通道算muT = mean(cat(2,TTrain{:}),2); % 目標的均值
sigmaT = std(cat(2,TTrain{:}),0,2); % 目標的標準差% 3. 對輸入和目標進行歸一化(用訓練集的統計量,避免數據泄露)
for n = 1:numel(XTrain)XTrain{n} = (XTrain{n} - muX) ./ sigmaX; % 輸入歸一化:(原始-均值)/標準差TTrain{n} = (TTrain{n} - muT) ./ sigmaT; % 目標歸一化
end
關鍵邏輯解釋
- 移位序列的原因:假設序列為
[t1,t2,t3,t4]
,輸入XTrain
為[t1,t2,t3]
,目標TTrain
為[t2,t3,t4]
,讓LSTM學習“t1→t2”“t2→t3”“t3→t4”的映射,最終能實現“輸入任意序列→預測下一個時間步”。 - 歸一化的必要性:若不同通道的數值范圍差異大(如通道1是0-1,通道2是100-200),訓練時會導致梯度更新失衡,模型難以收斂。用訓練集統計量歸一化,是為了避免“測試集信息泄露到訓練集”(測試集的統計量未知)。
3. 第三步:定義LSTM網絡架構
時序預測的LSTM網絡需要適配“序列輸入→序列輸出”的需求,核心層包括:序列輸入層、LSTM層、全連接層、回歸層。
代碼與逐行解析
layers = [sequenceInputLayer(numChannels) % 序列輸入層:輸入維度=通道數(numChannels=3)lstmLayer(128) % LSTM層:128個隱藏單元(決定學習能力)fullyConnectedLayer(numChannels) % 全連接層:輸出維度=通道數(與輸入通道一致)regressionLayer]; % 回歸層:定義回歸任務的損失函數(默認均方誤差MSE)
各層參數與作用詳解
層名稱 | 參數配置 | 作用說明 |
---|---|---|
sequenceInputLayer | numChannels=3 | 接收“通道數×時間步”的序列輸入,輸入維度必須與數據的通道數一致(否則維度不匹配)。 |
lstmLayer | 128個隱藏單元 | 隱藏單元數量決定LSTM的“記憶容量”:128個單元可捕捉中等復雜度的時序規律;數量越多學習能力越強,但易過擬合。 |
fullyConnectedLayer | numChannels=3 | 將LSTM輸出的128維隱藏狀態“映射”到3維(與輸入通道數一致),確保輸出序列的維度與目標序列匹配。 |
regressionLayer | 無參數(默認) | 回歸任務的輸出層,計算“預測值-真實值”的均方誤差(MSE),作為訓練的損失函數,指導網絡更新權重。 |
4. 第四步:指定訓練選項
訓練選項決定模型的優化策略,需結合數據規模、網絡復雜度調整。
代碼與逐行解析
options = trainingOptions("adam", ... % 優化器:Adam(自適應學習率,適合時序數據)MaxEpochs=200, ... % 最大訓練輪數:200輪(平衡訓練效果與時間)SequencePaddingDirection="left", ...% 序列對齊方式:左側補零(保護右側有效信息)Shuffle="every-epoch", ... % 數據打亂:每輪訓練前打亂訓練集,避免過擬合Plots="training-progress", ... % 可視化:顯示訓練進度(損失曲線、準確率等)Verbose=0); % 日志輸出:0表示不打印詳細訓練日志(僅看進度圖)
關鍵選項解釋
- Adam優化器:比SGD(隨機梯度下降)收斂更快,通過自適應學習率調整不同參數的更新步長,適合LSTM這類復雜網絡。
- MaxEpochs=200:200輪是針對2000個序列的經驗值——輪數太少可能欠擬合(沒學會規律),太多則可能過擬合(記住訓練集噪聲)。
- SequencePaddingDirection=“left”:不同序列長度不同,訓練時需補零對齊。左側補零是為了保護“右側的近期信息”(時序數據中,右側時間步更重要),避免右側補零干擾預測。
5. 第五步:訓練LSTM網絡
調用trainNetwork
函數,用訓練集(XTrain, TTrain)和訓練選項(options)訓練網絡。
代碼與解析
% 訓練網絡:輸入(XTrain)、目標(TTrain)、網絡架構(layers)、訓練選項(options)
net = trainNetwork(XTrain,TTrain,layers,options);
- 輸出:訓練好的LSTM網絡
net
,包含學習到的權重、偏置和網絡結構。 - 訓練過程:運行時會彈出“訓練進度圖”,可觀察訓練損失(Training Loss)的下降趨勢——若損失趨于平穩,說明網絡收斂。
6. 第六步:測試網絡(評估泛化能力)
測試的核心是:用訓練好的網絡預測測試集,計算誤差(RMSE)評估泛化能力。
代碼與逐行解析
% 1. 準備測試數據(與訓練數據處理邏輯一致:移位+歸一化)
for n = 1:size(dataTest,1) % 遍歷每個測試序列X = dataTest{n}; % 取第n個測試序列XTest{n} = (X(:,1:end-1) - muX) ./ sigmaX; % 測試輸入:移位+用訓練集統計量歸一化TTest{n} = (X(:,2:end) - muT) ./ sigmaT; % 測試目標:移位+歸一化
end% 2. 用測試集預測(指定左側補零,與訓練一致)
YTest = predict(net,XTest,SequencePaddingDirection="left");% 3. 計算每個測試序列的RMSE(均方根誤差,評估預測精度)
for i = 1:size(YTest,1)% RMSE = sqrt(平均(預測值-真實值)^2),"all"表示對所有元素計算rmse(i) = sqrt(mean((YTest{i} - TTest{i}).^2,"all"));
end% 4. 可視化RMSE分布(直方圖)
figure
histogram(rmse) % 繪制RMSE的頻率分布
xlabel("RMSE") % x軸:RMSE值(越小精度越高)
ylabel("Frequency") % y軸:頻率(多少個序列的RMSE落在該區間)% 5. 計算所有測試序列的平均RMSE
mean(rmse)
評估邏輯
- RMSE的意義:RMSE越小,預測值與真實值的偏差越小。例如,若平均RMSE=0.1,說明預測值與真實值的平均偏差僅0.1(歸一化后的值,反歸一化后可還原為原始尺度)。
- 為什么用訓練集統計量歸一化:測試時無法獲取“未來數據的統計量”,用訓練集統計量才能模擬真實預測場景(避免數據泄露)。
7. 第七步:預測未來時間步(重點:開環vs閉環)
測試僅驗證“單步預測”能力,實際應用中常需“多步預測”(如預測未來200個時間步)。此時需區分開環與閉環兩種模式,閉環預測是本文核心。
7.1 先理解:開環預測(依賴真實值)
開環預測的邏輯是:每次預測都需要“前一個時間步的真實值”作為輸入,適合能實時獲取真實數據的場景(如實時監測設備數據,用真實值預測下一秒)。
% 選擇一個測試序列(索引=2)
idx = 2;
X = XTest{idx}; % 測試輸入序列
T = TTest{idx}; % 測試目標序列% 1. 初始化網絡狀態(重置隱藏狀態,避免歷史數據干擾)
net = resetState(net);
% 2. 用前75個時間步的真實數據更新網絡狀態(讓網絡“記住”初始上下文)
offset = 75; % 初始真實數據的時間步長度
[net,~] = predictAndUpdateState(net,X(:,1:offset));% 3. 開環預測:用真實值作為輸入,預測剩余時間步
numTimeSteps = size(X,2); % 測試序列總時間步
numPredictionTimeSteps = numTimeSteps - offset; % 需預測的時間步數量
Y_open = zeros(numChannels,numPredictionTimeSteps); % 存儲開環預測結果for t = 1:numPredictionTimeStepsXt = X(:,offset+t); % 輸入:第offset+t步的真實值(開環的核心:依賴真實值)[net,Y_open(:,t)] = predictAndUpdateState(net,Xt); % 預測+更新網絡狀態
end% 4. 可視化開環預測結果
figure
t = tiledlayout(numChannels,1); % 按通道堆疊子圖
title(t,"Open Loop Forecasting")
for i = 1:numChannelsnexttileplot(T(i,:)) % 真實值(目標序列)hold on% 預測值:從offset步開始,拼接offset步的真實值+預測值plot(offset:numTimeSteps,[T(i,offset) Y_open(i,:)],'--')ylabel("Channel " + i)
end
xlabel("Time Step")
nexttile(1)
legend(["True Value" "Forecasted Value"])
- 開環的局限性:必須獲取每個時間步的真實值才能繼續預測,無法一次性預測多步未來(如無法直接預測未來200步,需等每一步真實值產生)。
7.2 核心:閉環預測(無需真實值,用前一次預測當輸入)
閉環預測的邏輯是:僅用初始真實數據初始化,后續預測完全依賴“前一次的預測結果”作為輸入,可一次性預測任意多步未來,適合無法獲取實時真實數據的場景(如預測未來一個月的銷量)。
代碼與逐行解析
% 1. 重置網絡狀態(關鍵!清除歷史隱藏狀態,確保從干凈的初始狀態開始)
net = resetState(net);% 2. 用測試序列的所有真實數據初始化網絡狀態(讓網絡“記住”完整的初始上下文)
offset = size(X,2); % offset=測試序列的總時間步(用全部真實數據初始化)
[net,Z] = predictAndUpdateState(net,X); % Z是初始預測結果(與測試序列長度一致)% 3. 閉環預測:預測未來200個時間步(可自定義數量)
numPredictionTimeSteps = 200; % 需預測的未來時間步數量
Xt = Z(:,end); % 初始輸入:最后一個時間步的預測值(閉環的核心:用預測值當輸入)
Y_closed = zeros(numChannels,numPredictionTimeSteps); % 存儲閉環預測結果% 循環預測:每一步用前一次的預測值作為輸入
for t = 1:numPredictionTimeSteps% 預測當前時間步+更新網絡狀態[net,Y_closed(:,t)] = predictAndUpdateState(net,Xt);% 更新輸入:下一次預測用當前的預測值Xt = Y_closed(:,t);
end% 4. 可視化閉環預測結果
numTimeSteps = offset + numPredictionTimeSteps; % 總時間步=初始真實數據+預測數據
figure
t = tiledlayout(numChannels,1);
title(t,"Closed Loop Forecasting") % 標題:閉環預測for i = 1:numChannelsnexttileplot(T(i,1:offset)) % 初始真實數據(前offset步)hold on% 預測數據:從offset步開始,拼接offset步的真實值+未來200步的預測值plot(offset:numTimeSteps,[T(i,offset) Y_closed(i,:)],'--')ylabel("Channel " + i)
endxlabel("Time Step")
nexttile(1)
legend(["Input (True Value)" "Forecasted Value"])
閉環預測的核心細節
-
為什么要
resetState
?
LSTM的隱藏狀態會“記憶”歷史數據,若不重置,網絡會攜帶上一次預測的殘留信息(如之前預測過的其他序列),導致當前預測的初始狀態錯誤,誤差被不斷放大。resetState
能將隱藏狀態清零,確保從“干凈的初始狀態”開始學習當前序列的上下文。 -
predictAndUpdateState
的作用?
該函數是閉環預測的核心工具,同時完成兩個任務:- 基于當前輸入(真實值或預測值)計算預測結果;
- 更新網絡的隱藏狀態(讓網絡“記住”當前輸入的信息,為下一次預測做準備)。
-
循環邏輯的關鍵?
每次循環中,Xt = Y_closed(:,t)
將“當前預測值”作為“下一次預測的輸入”,形成“預測→輸入→再預測”的閉環,無需任何真實值即可持續預測多步未來。
三、閉環預測的優缺點與適用場景
特點 | 優點 | 缺點 | 適用場景 |
---|---|---|---|
數據依賴 | 僅需初始真實數據,后續無需真實值 | 誤差會累積(前一步預測不準,后一步偏差更大) | 無法獲取實時真實數據、需一次性預測多步未來(如預測未來1年的季節性波動) |
靈活性 | 可自定義預測步數(如預測200步、500步) | 精度通常低于開環預測 | 長期趨勢預測、資源有限無法實時采集數據的場景 |
計算效率 | 一次性循環完成多步預測,無需等待真實數據 | 需合理初始化網絡狀態(否則初始誤差大) | 批量預測、離線預測任務 |
四、總結
本文通過完整的MATLAB代碼,拆解了LSTM時間序列預測的全流程:從數據加載與移位處理、網絡架構設計、訓練優化,到開環與閉環預測的實現。核心結論如下:
- 數據處理是基礎:移位目標序列讓LSTM學習“當前→下一個”的映射,歸一化避免訓練發散,左側補零保護有效信息。
- 網絡架構需適配任務:sequenceInputLayer匹配通道數,lstmLayer隱藏單元數量平衡學習能力與過擬合,regressionLayer適配回歸任務。
- 閉環預測是核心亮點:通過
resetState
初始化狀態、predictAndUpdateState
預測+更新狀態、循環用前一次預測當輸入,實現無需真實值的多步預測,適合實際應用中的長期預測需求。
掌握這套流程后,你可以將其遷移到自己的時序數據(如溫度、銷量、股價),只需調整通道數、隱藏單元數量、預測步數等參數,即可快速實現定制化的時間序列預測。