DAY 59 經典時序預測模型3
知識點回顧:
- SARIMA模型的參數和用法:SARIMA(p, d, q)(P, D, Q)m
- 模型結果的檢驗可視化(昨天說的是摘要表怎么看,今天是對這個內容可視化)
- 多變量數據的理解:內生變量和外部變量
- 多變量模型
- 統計模型:SARIMA(單向因果)、VAR(考慮雙向依賴)
- 機器學習模型:通過滑動窗口實現,往往需要借助arima等作為特征提取器來捕捉線性部分(趨勢、季節性),再利用自己的優勢捕捉非線性的殘差
- 深度學習模型:獨特的設計天然為時序數據而生
作業:由于篇幅問題,無法實戰SARIMAX了,可以自己借助AI嘗試嘗試,相信大家已經有這個能力了。
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.stattools import adfuller
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.tsa.statespace.sarimax import SARIMAX
import warnings
import itertools
warnings.filterwarnings('ignore')
# 顯示中文
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 1. 加載數據
url = 'https://raw.githubusercontent.com/jbrownlee/Datasets/master/airline-passengers.csv'
df = pd.read_csv(url, header=0, index_col=0, parse_dates=True)
df.columns = ['Passengers']# 2. 劃分訓練集和測試集(保留最后12個月作為測試)
train_data = df.iloc[:-12]
test_data = df.iloc[-12:]print("--- 訓練集 ---")
print(train_data.tail()) # 觀察訓練集最后5行
print("\n--- 測試集 ---")
print(test_data.head()) # 觀察測試集前5行
# 3. 可視化原始數據
plt.figure(figsize=(12, 6))
plt.plot(train_data['Passengers'], label='訓練集')
plt.plot(test_data['Passengers'], label='測試集', color='orange')
plt.title('國際航空乘客數量 (1949-1960)')
plt.xlabel('年份')
plt.ylabel('乘客數量 (千人)')
plt.legend()
plt.show()
# 進行季節性差分 (D=1, m=12)
seasonal_diff = df['Passengers'].diff(12).dropna()
# 再進行普通差分 (d=1)
seasonal_and_regular_diff = seasonal_diff.diff(1).dropna()# 繪制差分后的數據
plt.figure(figsize=(12, 6))
plt.plot(seasonal_and_regular_diff)
plt.title('經過一次季節性差分和一次普通差分后的數據')
plt.show()# ADF檢驗
result = adfuller(seasonal_and_regular_diff)
print(f'ADF Statistic: {result[0]}')
print(f'p-value: {result[1]}') # p-value越小,越說明數據平穩
?
# 繪制ACF和PACF圖
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
plot_acf(seasonal_and_regular_diff, lags=36, ax=ax1) # 繪制36個時間點
plot_pacf(seasonal_and_regular_diff, lags=36, ax=ax2)
plt.show()
?
手動的超參數搜索:
# 固定已知參數
d = 1 # 非季節性差分階數
D = 1 # 季節性差分階數
m = 12 # 季節性周期(月度數據為12)# 定義待搜索的參數范圍
p = q = range(0, 3) # 非季節性參數 p和q取0-2
P = Q = range(0, 2) # 季節性參數 P和Q取0-1# 生成所有可能的參數組合
pdq = list(itertools.product(p, [d], q)) # d固定為1
seasonal_pdq = [(x[0], D, x[2], m) for x in list(itertools.product(P, [D], Q))] # D固定為1# 修正列名引用(假設數據列名為'Passengers')
train_column = 'Passengers' # 請根據實際數據列名調整# 初始化最佳參數和最小AIC
best_aic = float('inf')
best_pdq = None
best_seasonal_pdq = None
best_model = Noneprint("開始網格搜索最佳SARIMA參數...")# 網格搜索最佳參數
for param in pdq:for param_seasonal in seasonal_pdq:try:# 擬合SARIMA模型model = SARIMAX(train_data[train_column],order=param,seasonal_order=param_seasonal,enforce_stationarity=False, # 放寬平穩性約束enforce_invertibility=False, # 放寬可逆性約束disp=False)# 使用優化的擬合方法results = model.fit(method='bfgs', # 使用BFGS優化算法maxiter=200, # 增加最大迭代次數disp=False)# 打印當前參數組合及AICprint(f'SARIMA{param}x{param_seasonal} - AIC: {results.aic:.2f}')# 更新最佳參數if results.aic < best_aic:best_aic = results.aicbest_pdq = parambest_seasonal_pdq = param_seasonalbest_model = resultsexcept Exception as e:print(f'SARIMA{param}x{param_seasonal} 擬合失敗: {str(e)}')continue
# 輸出最佳模型
if best_pdq:print(f"\n最佳模型: SARIMA{best_pdq}x{best_seasonal_pdq} - AIC: {best_aic:.2f}")final_model = SARIMAX(train_data[train_column],order=best_pdq,seasonal_order=best_seasonal_pdq,enforce_stationarity=False,enforce_invertibility=False)final_results = final_model.fit(disp=False)
?
# 檢查是否找到有效模型
if best_model is not None:print(f'\n最佳模型: SARIMA{best_pdq}x{best_seasonal_pdq} - AIC: {best_aic:.2f}')# 打印最佳模型摘要print(best_model.summary())# 繪制模型診斷圖best_model.plot_diagnostics(figsize=(15, 10))plt.tight_layout()plt.show()else:print("\n未能找到合適的SARIMA模型。請檢查:")print("1. 數據列名是否正確(當前使用:", train_column, ")")print("2. 數據是否包含缺失值或異常值")print("3. 嘗試進一步調整參數范圍")print("4. 考慮使用其他時間序列模型")
?
# 1. 預測測試集
forecast = final_results.get_forecast(steps=len(test_data))
forecast_mean = forecast.predicted_mean
forecast_ci = forecast.conf_int()# 2. 評估模型
from sklearn.metrics import mean_squared_error
import numpy as npmse = mean_squared_error(test_data[train_column], forecast_mean)
rmse = np.sqrt(mse)
print(f'測試集 MSE: {mse:.2f}')
print(f'測試集 RMSE: {rmse:.2f}')# 3. 繪制預測結果
plt.figure(figsize=(12, 6))
plt.plot(train_data.index, train_data[train_column], label='訓練數據')
plt.plot(test_data.index, test_data[train_column], label='真實值', color='orange')
plt.plot(test_data.index, forecast_mean, label='預測值', color='red')
plt.fill_between(forecast_ci.index,forecast_ci.iloc[:, 0],forecast_ci.iloc[:, 1],color='pink', alpha=0.5, label='95%置信區間')
plt.title('SARIMA模型預測 vs. 真實值')
plt.xlabel('日期')
plt.ylabel('乘客數量 (千人)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
?
@浙大疏錦行?