基于Transformer的多資產收益預測模型實戰(附PyTorch實現與避坑指南)

基于Transformer的多資產收益預測模型實戰(附PyTorch模型訓練及可視化完整代碼)


一、項目背景與目標

在量化投資領域,利用時間序列數據預測資產收益是核心任務之一。傳統方法如LSTM難以捕捉資產間的復雜依賴關系,而Transformer架構通過自注意力機制能有效建模多資產間的聯動效應。
本文將從零開始構建一個基于PyTorch的多資產收益預測模型,涵蓋數據生成、特征工程、模型設計、訓練及可視化全流程,適合深度學習與量化投資的初學者入門。

二、核心技術棧

  • 數據處理:Pandas/Numpy(數據生成與預處理)
  • 深度學習框架:PyTorch(模型構建與訓練)
  • 可視化:Matplotlib(結果分析)
  • 核心算法:Transformer(自注意力機制)

三、數據生成與預處理

1. 模擬金融數據生成

我們通過以下步驟生成包含5只資產的時間序列數據:

  • 市場基準因子:模擬市場整體趨勢(幾何布朗運動)
  • 行業因子:引入周期性波動區分不同行業(如科技、消費、能源)
  • 特質因子:每只資產的獨立噪聲
def generate_market_data(days=2000, n_assets=5):  np.random.seed(42)  market = np.cumprod(1 + np.random.normal(0.0003, 0.015, days))  # 市場基準  assets = []  sector_map = {0: "Tech", 1: "Tech", 2: "Consume", 3: "Consume", 4: "Energy"}  for i in range(n_assets):  sector_factor = 0.3 * np.sin(i * 0.8 + np.linspace(0, 10 * np.pi, days))  # 行業周期因子  idiosyncratic = np.cumprod(1 + np.random.normal(0.0002, 0.02, days))  # 特質因子  price = market * (1 + sector_factor) * idiosyncratic  # 價格合成  assets.append(price)  dates = pd.date_range("2015-01-01", periods=days)  return pd.DataFrame(np.array(assets).T, index=dates, columns=[f"Asset_{i}" for i in range(n_assets)])  

2. 數據形狀說明

生成的DataFrame形狀為[2000天, 5資產],索引為時間戳,列名為Asset_0到Asset_4。

四、特征工程:從價格到可訓練數據

1. 基礎時間序列特征

為每只資產計算以下特征:

  • 收益率(Return):相鄰日價格變化率
  • 波動率(Volatility):20日滾動標準差年化
  • 移動平均(MA10):10日價格移動平均
  • 行業相對強弱(Sector_RS):資產價格與所屬行業平均價格的比值
def create_features(data, lookback=60):  n_assets = data.shape[1]  sector_map = {0: "Tech", 1: "Tech", 2: "Consume", 3: "Consume", 4: "Energy"}  features = []  for i, asset in enumerate(data.columns):  df = pd.DataFrame()  df["Return"] = data[asset].pct_change()  df["Volatility"] = df["Return"].rolling(20).std() * np.sqrt(252)  # 年化波動率  df["MA10"] = data[asset].rolling(10).mean()  # 計算行業相對強弱  sector = sector_map[i]  sector_cols = [col for col in data.columns if sector_map[int(col.split("_")[1])] == sector]  df["Sector_RS"] = data[asset] / data[sector_cols].mean(axis=1)  features.append(df.dropna())  # 去除NaN  # 對齊時間索引  common_idx = features[0].index  for df in features[1:]:  common_idx = common_idx.intersection(df.index)  features = [df.loc[common_idx] for df in features]  # 構建3D特征張量 [樣本數, 時間步, 資產數, 特征數]  X = np.stack([np.stack([feat.iloc[i-lookback:i] for i in range(lookback, len(feat))], axis=0) for feat in features], axis=2)  # 標簽:未來5日平均收益率  y = np.array([data.loc[common_idx].iloc[i:i+5].pct_change().mean().values for i in range(lookback, len(common_idx))])  return X, y  

2. 輸入輸出形狀

  • 特征張量X形狀:[樣本數, 時間步(60), 資產數(5), 特征數(4)]
  • 標簽y形狀:[樣本數, 資產數(5)](每個樣本對應5只資產的未來5日平均收益率)

五、Transformer模型構建:核心架構解析

1. 模型設計目標

  • 處理多資產時間序列:同時輸入5只資產的歷史數據
  • 捕捉時間依賴資產間依賴:通過位置編碼和自注意力機制
  • 輸出多資產收益預測:回歸問題,使用MSE損失

2. 關鍵組件解析

(1)資產嵌入層(Asset Embedding)

將每個資產的4維特征映射到64維隱空間:

self.asset_embed = nn.Linear(n_features=4, d_model=64)  

輸入形狀:(batch, seq_len, assets, features) → 輸出:(batch, seq_len, assets, d_model)

(2)位置編碼(Positional Embedding)

由于Transformer無內置時序信息,需手動添加位置編碼:

self.time_pos = nn.Parameter(torch.randn(1, lookback=60, 1, d_model=64))  # 時間位置編碼  
self.asset_pos = nn.Parameter(torch.randn(1, 1, n_assets=5, d_model=64))  # 資產位置編碼  
  • 通過廣播機制與資產嵌入相加,分別捕獲時間和資產維度的位置信息。
(3)自定義Transformer編碼器層(Custom Transformer Encoder Layer)

繼承PyTorch原生層,返回注意力權重以可視化:

class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer):  def __init__(self, d_model, nhead, dim_feedforward=256, dropout=0.1):  super(

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/80837.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/80837.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/80837.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

養生:打造健康生活的全方位策略

在生活節奏不斷加快的當下,養生已成為提升生活質量、維護身心平衡的重要方式。從飲食、運動到睡眠,再到心態調節,各個方面的養生之道共同構建起健康生活的堅實基礎。以下為您詳細介紹養生的關鍵要點,助您擁抱健康生活。 飲食養生…

輕型汽車鼓式液壓制動器系統設計

一、設計基礎參數 1.1 整車匹配參數 參數項數值范圍整備質量1200-1500kg最大設計車速160km/h輪胎規格195/65 R15制動法規要求GB 12676-2014 1.2 制動性能指標 制動減速度:≥6.2m/s(0型試驗) 熱衰退率:≤30%(連續10…

無法更新Google Chrome的解決問題

解決問題:原文鏈接:【百分百成功】Window 10 Google Chrome無法啟動更新檢查(錯誤代碼為1:0x80004005) google谷歌chrome瀏覽器無法更新Chrome無法更新至最新版本? 下載了 就是更新Google Chrome了

【AAAI 2025】 Local Conditional Controlling for Text-to-Image Diffusion Models

Local Conditional Controlling for Text-to-Image Diffusion Models(文本到圖像擴散模型的局部條件控制) 文章目錄 內容摘要關鍵詞作者及研究團隊項目主頁01 研究領域待解決問題02 論文解決的核心問題03 關鍵解決方案04 主要貢獻05 相關研究工作06 解決…

Kuka AI音樂AI音樂開發「人聲伴奏分離」 —— 「Kuka Api系列|中文咬字清晰|AI音樂API」第6篇

導讀 今天我們來了解一下 Kuka API 的人聲與伴奏分離功能。 所謂“人聲伴奏分離”,顧名思義,就是將一段完整的音頻拆分為兩個獨立的軌道:一個是人聲部分,另一個是伴奏(樂器)部分。 這個功能在音樂創作和…

Idea 設置編碼UTF-8 Idea中 .properties 配置文件中文亂碼

Idea 設置編碼UTF-8 Idea中 .properties 配置文件中文亂碼 一、設置編碼 1、步驟: File -> Setting -> Editor -> File encodings --> 設置編碼二、配置文件中文亂碼 1、步驟: File -> Setting -> Editor -> File encodings ->…

Xilinx FPGA PCIe | XDMA IP 核 / 應用 / 測試 / 實踐

注:本文為 “Xilinx FPGA 中 PCIe 技術與 XDMA IP 核的應用” 相關文章合輯。 圖片清晰度受引文原圖所限。 略作重排,未整理去重。 如有內容異常,請看原文。 FPGA(基于 Xilinx)中 PCIe 介紹以及 IP 核 XDMA 的使用 N…

sqli—labs第六關——雙引號報錯注入

一:判斷輸入類型 首先測試 ?id1,?id1,?id1",頁面回顯均無變化 所以我們采用簡單的布爾測試,分別測試數字型,單引號,雙引號 然后發現,只有在測試到雙引號注入的時候符合關鍵…

【TroubleShoot】禁用Unity Render Graph API 兼容模式

使用Unity 6時新建了項目,有一個警告提示: The project currently uses the compatibility mode where the Render Graph API is disabled. Support for this mode will be removed in future Unity versions. Migrate existing ScriptableRenderPasses…

圖形學、人機交互、VR/AR、可視化等領域文獻速讀【持續更新中...】

(1)筆者在時間有限的情況下,想要多積累一些自身課題之外的新文獻、新知識,所以開了這一篇文章。 (2)想通過將文獻喂給大模型,并向大模型提問的方式來快速理解文獻的重要信息(如基礎i…

Hadoop-HDFS-Packet含義及作用

在 HDFS(Hadoop Distributed File System)中,Packet 是數據讀寫過程中用于數據傳輸的基本單位。它是 HDFS 客戶端與數據節點(DataNode)之間進行數據交互時的核心概念,尤其在寫入和讀取文件時,Pa…

顯示的圖標跟UI界面對應不上。

圖片跟UI界面不符合。 要找到對應dp的值。UI的dp要跟代碼里的xml文件里的dp要對應起來。 藍湖里設置一個寬度給對應上。然后把對應的值填入xml. 一個屏幕上的圖片到底是用topmarin來設置,還是用bottommarin來設置。 因為第一節,5,7 車廂的…

【taro3 + vue3 + webpack4】在微信小程序中的請求封裝及使用

前言 正在寫一個 以taro3 vue3 webpack4為基礎框架的微信小程序,之前一直沒有記咋寫的,現在總結記錄一下。uniapp vite 的后面出。 文章目錄 前言一、創建環境配置文件二、 配置 Taro 環境變量三、 創建請求封裝四、如何上傳到微信小程序體驗版1.第二…

LeetCode:513、找樹左下角的值

//遞歸法 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* TreeNode(int val) { this.val val; }* TreeNode(int val, TreeNode left, TreeNode right) {* t…

采用均線策略來跟蹤和投資基金

策略來源#睿思量化#小程序 截圖來源#睿思量化#小程序 在基金投資中,趨勢跟蹤策略是一種備受關注的交易方法。本文將基于兩張關于廣發電子信息傳媒股票 A(代碼:005310)的圖片資料,詳細闡述這一策略的應用與效果。 從第…

leetcode刷題---二分查找

力扣題目鏈接 二分查找算法使用前提&#xff1a;有序數組&#xff1b;數組內無重復元素 易錯點&#xff1a; 1.while循環的邊界條件&#xff1a;如到底是 while(left < right) 還是 while(left < right) 2.if條件后right&#xff0c;left的取值&#xff1a;到底是 right …

(leetcode) 力扣100 10.和為K的子數組(前綴和+哈希)

題目 給你一個整數數組 nums 和一個整數 k &#xff0c;請你統計并返回 該數組中和為 k 的子數組的個數 。 子數組是數組中元素的連續非空序列。 數據范圍 1 < nums.length < 2 * 104 -1000 < nums[i] < 1000 -107 < k < 107 樣例 示例 1&#xff1a; 輸…

遨游衛星電話與普通手機有什么區別?

在數字化浪潮席卷全球的今天&#xff0c;通信設備的角色早已超越傳統語音工具&#xff0c;成為連接物理世界與數字世界的核心樞紐。然而&#xff0c;當普通手機在都市叢林中游刃有余時&#xff0c;面對偏遠地區、危險作業場景的應急通信需求&#xff0c;其局限性便顯露無遺。遨…

在Linux中如何使用Kill(),向進程發送發送信號

kill()函數 #include <sys/types.h> #include <signal.h> int kill(pid_t pid, int sig); 函數參數和返回值含義如下: pid:參數 pid 為正數的情況下,用于指定接收此信號的進程 pid;除此之外,參數 pid 也可設置為 0 或-1 以及小于-1 等不同值,稍后給說明。 …

Java SpringMVC 和 MyBatis 整合關鍵配置詳解

目錄 一、數據源配置二、MyBatis 工廠配置三、Mapper 掃描配置四、SpringMVC 配置五、整合示例實體類Mapper 接口Mapper XML 文件Service 類控制器JSP 頁面六、總結在 Java Web 開發中,SpringMVC 和 MyBatis 是兩個常用框架。SpringMVC 負責 Web 層的請求處理和視圖渲染,MyBa…