【python量化】多種Transformer模型用于股價預測(Autoformer, FEDformer和PatchTST等)_neuralforecast

bb1fee63b7d3f7f1db42af482660a610.png

寫在前面

在本文中,我們利用Nixtla的NeuralForecast框架,實現多種基于Transformer的時序預測模型,包括:Transformer, Informer, Autoformer, FEDformer和PatchTST模型,并且實現將它們應用于股票價格預測的簡單例子

1

NeuralForecast

neuralforecast 是一個旨在為時間序列預測提供一個豐富的、高度可用和魯棒的神經網絡模型集合的工具庫。這個庫集成了從傳統的多層感知器(MLP)和遞歸神經網絡(RNN)到最新的模型如N-BEATS、N-HiTS、TFT,以及其他高級架構,以適應多樣化的預測需求。它的關鍵功能包括對靜態、歷史和未來的外生變量的支持,提高了模型在實際應用中的靈活性。庫中的模型提供了良好的預測可解釋性,允許用戶繪制趨勢、季節性以及外生預測組件。neuralforecast 還實現了概率預測,通過簡單的適配器支持量化損失和參數分布,增加了預測結果的置信度。此外,它提供了自動模型選擇功能,通過并行自動超參數調整來高效確定最優的模型配置。庫的簡潔接口設計與SKLearn兼容,確保了易用性,并且訓練和評估損失的計算能夠適應不同的比例,這為不同規模的數據集提供了靈活性。最后,neuralforecast 包含了一個廣泛的模型集合,包括但不限于LSTM、RNN、TCN、N-BEATS、N-HiTS、ESRNN以及各種基于Transformer的預測模型等,都是以即插即用的方式實現,方便用戶直接應用于各種時間序列預測場景。這些特性使得neuralforecast 成為那些尋求高效、精確且可解釋時間序列預測模型的研究人員和實踐者的有力工具。本文將利用neuralforecast 實現各種Transformer模型,并展示將它們應用于股票價格預測的簡單例子。

2

環境配置

本地環境:

Python 3.8
IDE:Pycharm

庫版本:

Pandas version: 2.0.3
Matplotlib version: 3.7.1
Neuralforecast version: 1.6.4

為了使用最新的其他模型,也可以直接fork neuralforecast的源碼:

git clone https://github.com/Nixtla/neuralforecast.git
cd neuralforecast
pip install -e .

3

代碼實現

步驟 1: 導入所需的庫
  • 導入庫:首先,導入處理數據所需的 pandas 庫,繪圖所需的 matplotlib.pyplot 庫,以及 neuralforecast 中的多個模塊。這些模塊包括各種預測模型和評估指標函數。
import pandas as pd
from neuralforecast.models import VanillaTransformer, Informer, Autoformer, FEDformer, PatchTST
from neuralforecast.core import NeuralForecast
import matplotlib.pyplot as plt
from neuralforecast.losses.numpy import mae, rmse, mse
步驟 2: 數據準備
  • 讀取數據:使用 pandas從 CSV 文件加載數據。這個數據集包含股票的每日收盤價。

  • 數據預處理:重命名列以符合模型的輸入要求(例如,將日期列重命名為 ‘ds’,將收盤價列重命名為 ‘y’)。此外,將日期列轉換為日期時間格式,并為數據集添加一個唯一標識符,這對于使用neuralforecast進行時間序列預測是必要的。

df = pd.read_csv('./000001_Daily_Close.csv')
df['unique_id'] = 1
df = df.rename(columns={'date': 'ds', 'Close': 'y'})
df['ds'] = pd.to_datetime(df['ds'])
步驟 3: 定義預測模型
  • 初始化模型:定義一個模型列表,每個模型都是 neuralforecast 庫中的一個類的實例。對于每個模型,指定預測范圍(horizon)、輸入窗口大小(input_size)以及其他訓練參數(如 max_steps, val_check_steps)。

  • 模型配置:這些參數決定了模型的訓練方式,包括訓練持續時間、評估頻率和早停機制等。每個模型都有一些公共的參數以及它們自身的參數可以調整,這里均使用它們默認的參數進行模型初始化。

models = [VanillaTransformer(h=horizon,input_size=input_size,max_steps=train_steps,val_check_steps=check_steps,early_stop_patience_steps=3,scaler_type='standard'),Informer(h=horizon,  # Forecasting horizoninput_size=input_size,  # Input sizemax_steps=train_steps,  # Number of training iterationsval_check_steps=check_steps,  # Compute validation loss every 100 stepsearly_stop_patience_steps=3,  # Number of validation iterations before early stoppingscaler_type='standard'),  # Stop training if validation loss does not improveFEDformer(h=horizon,input_size=input_size,max_steps=train_steps,val_check_steps=check_steps,early_stop_patience_steps=3),Autoformer(h=horizon,input_size=input_size,max_steps=train_steps,val_check_steps=check_steps,early_stop_patience_steps=3),PatchTST(h=horizon,input_size=input_size,max_steps=train_steps,val_check_steps=check_steps,early_stop_patience_steps=3),]
步驟 4: 模型訓練與交叉驗證
  • 創建 NeuralForecast 實例:使用 NeuralForecast 類整合所有的模型。這個類提供了一個統一的接口來訓練和評估多個模型。

  • 執行交叉驗證:使用 cross_validation 方法對每個模型進行訓練和評估。這個方法自動進行時間序列的交叉驗證,分割數據集并評估模型在不同時間窗口上的性能。

nf = NeuralForecast(models=models,freq='B')Y_hat_df = nf.cross_validation(df=df,val_size=100,test_size=100,n_windows=None)
步驟 5: 數據篩選
  • 篩選數據點:通過選擇特定的“cutoff”點來過濾 Y_hat_df 中的預測。這種篩選基于預測范圍 horizon,確保評估是在均勻間隔的時間點上進行。
Y_plot = Y_hat_df
cutoffs = Y_hat_df['cutoff'].unique()[::horizon]
Y_plot = Y_plot[Y_hat_df['cutoff'].isin(cutoffs)]
步驟 6: 繪圖與性能評估
  • 繪制預測結果:使用 matplotlib 繪制真實數據與每個模型的預測結果。這有助于直觀地比較不同模型的預測準確性。

  • 計算評估指標:對每個模型,計算和打印均方根誤差(RMSE)、平均絕對誤差(MAE)和均方誤差(MSE)等性能指標。這些指標提供了量化模型性能的方式。

plt.figure(figsize=(20, 5))
plt.plot(Y_plot['ds'], Y_plot['y'], label='True')
for model in models:plt.plot(Y_plot['ds'], Y_plot[model], label=model)rmse_value = rmse(Y_hat_df['y'], Y_hat_df[model])mae_value = mae(Y_hat_df['y'], Y_hat_df[model])mse_value = mse(Y_hat_df['y'], Y_hat_df[model])print(f'{model}: rmse {rmse_value:.4f} mae {mae_value:.4f} mse {mse_value:.4f}')plt.xlabel('Datestamp')
plt.ylabel('Close')
plt.grid()
plt.legend()
plt.show()
步驟 7: 結果展示
  • 展示圖表:最后,顯示繪制的圖表。圖表展示了不同模型在整個時間序列上的預測表現,允許直觀地評估和比較模型。

5d185d6c7ec0781a5971ebf64ad56ad5.png

VanillaTransformer: rmse 56.5187 mae 38.8573 mse 3194.3650
Informer: rmse 52.2324 mae 39.1110 mse 2728.2239
FEDformer: rmse 48.9400 mae 35.9884 mse 2395.1237
Autoformer: rmse 58.5010 mae 45.7157 mse 3422.3614
PatchTST: rmse 48.5870 mae 36.1392 mse 2360.6968

在對比基于 Transformer 的各種模型在股票價格預測任務上的表現時,從可視化以及評估結果中,我們發現 FEDformer 和 PatchTST 在所有評估指標(RMSE、MAE、MSE)上表現最為出色,這可能歸因于它們在處理長期依賴關系和捕獲時間序列數據中的復雜模式方面的優勢。相較之下,雖然 Informer 顯示了合理的性能,但其表現略遜于 FEDformer 和 PatchTST。VanillaTransformer 和 Autoformer 的性能相對較差。這些結果強調了根據特定任務的需求選擇合適的模型架構的重要性,同時也表明了在實際應用中進行模型選擇時需要考慮到模型的特定優勢和潛在的局限性。

4

總結

本文展示了如何使用 neuralforecast 實現多種 Transformer 模型(包括 Informer, Autoformer, FEDformer 和 PatchTST),并將它們應用于股票價格預測的簡單示例。通過這個演示,我們可以看到 Transformer 模型在處理時間序列數據方面的潛力和靈活性。雖然我們的實驗是初步的,但它為進一步的研究和應用提供了一個基礎。讀者可以在此基礎上進行更深入的模型調優、特征工程和超參數實驗,以提升預測性能。此外,這些模型的應用不限于股票價格預測,還可以擴展到其他領域的時間序列分析。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/715273.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/715273.shtml
英文地址,請注明出處:http://en.pswp.cn/news/715273.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Libero集成開發環境中Identify應用與提高

Libero集成開發環境中Identify應用與提高 Identify的安裝

小米手機相冊閃退

環境: HyperOS 1.0 小米手機分身 處理步驟: 1)清理相冊緩存:設置->應用設置->相冊->清理數據->清除緩存(注意:別點清理全部數據;這個方法對我沒用)。 2)卸…

操作系統原理與實驗——實驗三優先級進程調度

實驗指南 運行環境: Dev c 算法思想: 本實驗是模擬進程調度中的優先級算法,在先來先服務算法的基礎上,只需對就緒隊列到達時間進行一次排序。第一個到達的進程首先進入CPU,將其從就緒隊列中出隊后。若此后隊首的進程的…

多租戶 TransmittableThreadLocal 線程安全問題

在一個多租戶項目中,用戶登錄時,會在自定義請求頭攔截器AsyncHandlerInterceptor將該用戶的userId,cstNo等用戶信息設置到TransmittableThreadLocal中,在后續代碼中使用.代碼如下: HeaderInterceptor 請求頭攔截器 public class HeaderInterceptor implements Asyn…

阿里云國際云服務器全局流量分析功能詳細介紹

進行全局流量分析時,內網DNS解析會作為一個整體模塊,其他模塊的邊緣虛框顏色會置灰,示意作為一個整體進行全局分析,左側Region可以展開/匯總,也可以單獨選中某個Region模塊進行分析(這時其他Region的流量線…

【Java面試題】Redis的用途

以下是一些常見的用途 1.緩存 Redis 可以用作緩存系統,,將頻繁訪問的數據存儲在內存中,從而加快數據訪問速度,減少對數據庫的訪問壓力。 2.消息隊列 Redis 支持發布/訂閱模式和列表數據結構,可以用作消息隊列系統的…

道可云元宇宙每日資訊|廈門首個元宇宙辦稅大廳啟用

道可云元宇宙每日簡報(2024年3月1日)訊,今日元宇宙新鮮事有: 中國軍號元宇宙發布會即將舉行 近日,解放軍新聞傳播中心中國軍號即將正式上線。中國軍號元宇宙發布會也將在“云端”與您見面。全方位展現解放軍新聞傳播…

加密與安全_探索簽名算法

文章目錄 概述應用常用數字簽名算法CodeDSA簽名ECDSA簽名小結 概述 在非對稱加密中,使用私鑰加密、公鑰解密確實是可行的,而且有著特定的應用場景,即數字簽名。 數字簽名的主要目的是確保消息的完整性、真實性和不可否認性。通過使用私鑰加…

云服務器購買教程

在購買云服務器之前,建議仔細評估自身需求和預算,并與多個云服務提供商進行比較,以確保選擇到最適合的解決方案。購買云服務器的具體步驟可能因所選云服務提供商而異。以下以實際操作的方式介紹如何購買一款云服務器。 云服務器購買常見問題…

【數倉】zookeeper軟件安裝及集群配置

相關文章 【數倉】基本概念、知識普及、核心技術【數倉】數據分層概念以及相關邏輯【數倉】Hadoop軟件安裝及使用(集群配置)【數倉】Hadoop集群配置常用參數說明 一、環境準備 準備3臺虛擬機 Hadoop131:192.168.56.131Hadoop132&#xff…

【Spring連載】使用Spring Data訪問 MongoDB----對象映射之基于類型的轉換器

【Spring連載】使用Spring Data訪問 MongoDB----對象映射之基于類型的轉換器 一、自定義轉換二、轉換器消歧(Disambiguation)三、基于類型的轉換器3.1 寫轉換3.2 讀轉換3.3 注冊轉換器 一、自定義轉換 下面的Spring Converter實現示例將String對象轉換為自定義Email值對象: R…

藍橋杯_定時器的綜合應用實例

一 工程 代碼 在單片機訓練平臺上,利用定時器T0,數碼管模塊和2個獨立按鍵(J5的2,3短接),設計一個秒表,具有清零,暫停,啟動功能。 顯示模式:分-秒-0.05秒&…

Linux進程——信號詳解(上)

文章目錄 信號入門生活角度的信號技術應用角度的信號用kill -l命令可以察看系統定義的信號列表信號處理常見方式概述 產生信號通過鍵盤進行信號的產生,ctrlc向前臺發送2號信號通過系統調用異常軟件條件 信號入門 生活角度的信號 你在網上買了很多件商品&#xff0…

前端面試練習24.3.2-3.3

HTMLCSS部分 一.說一說HTML的語義化 在我看來,它的語義化其實是為了便于機器來看的,當然,程序員在使用語義化標簽時也可以使得代碼更加易讀,對于用戶來說,這樣有利于構建良好的網頁結構,可以在優化用戶體…

vue3項目中如何一個vue組件中的一個div里面的圖片鋪滿整個屏幕樣式如何設置

在Vue 3項目中,要使一個div內的圖片鋪滿整個屏幕,你需要確保幾個關鍵點:div元素和圖片元素的樣式設置正確,以及確保它們能夠覆蓋整個視口(viewport)。以下是一個簡單的步驟和代碼示例,幫助你實現…

代碼隨想錄算法訓練營第四八天 | 買股票

目錄 只買賣一次可買賣多次 LeetCode 121. 買賣股票的最佳時機 LeetCode 122. 買賣股票的最佳時機II 只買賣一次 給定一個數組 prices ,它的第 i 個元素 prices[i] 表示一支給定股票第 i 天的價格。 你只能選擇 某一天 買入這只股票,并選擇在 未來的某…

瀏覽器輸入URL到頁面渲染經歷了哪些過程?

瀏覽器輸入URL到頁面渲染的過程可以分為以下幾個步驟: 解析URL:當用戶在瀏覽器的地址欄輸入URL后,瀏覽器會首先解析這個URL,判斷其是否合法。查找緩存:瀏覽器會查看自己的緩存,判斷是否有之前訪問過的這個U…

論文閱讀--Diffusion Models for Reinforcement Learning: A Survey

一、論文概述 本文主要內容是關于在強化學習中應用擴散模型的綜述。文章首先介紹了強化學習面臨的挑戰,以及擴散模型如何解決這些挑戰。接著介紹了擴散模型的基礎知識和在強化學習中的應用方法。然后討論了擴散模型在強化學習中的不同角色,并對其在多個…

【JavaSE】實用類——String、日期等

目錄 String類常用方法String類的equals()方法String中equals()源碼展示 “”和equals()有什么區別呢? StringBuffer類常用構造方法常用方法代碼示例 面試題:String類、StringBuffer類和StringBuilder類的區別?日期類Date類Calendar類代碼示例…

leetcode169. 多數元素的四種解法

leetcode169. 多數元素 題目描述 給定一個大小為 n 的數組 nums ,返回其中的多數元素。多數元素是指在數組中出現次數 大于? n/2 ? 的元素。 你可以假設數組是非空的,并且給定的數組總是存在多數元素。 1.哈希 class Solution { public:int majority…