利用 RNN 預測股票價格:從數據處理到可視化實戰

在金融領域,預測股票價格走勢一直是眾多投資者和研究者關注的焦點。今天,我們將利用深度學習中的循環神經網絡(RNN)來構建一個簡單的股票價格預測模型,并詳細介紹從數據加載、預處理、模型搭建、訓練到最終結果可視化的全過程。

一、項目概述

本項目旨在通過歷史股票價格數據,訓練一個 RNN 模型,使其能夠對未來股票價格進行一定程度的預測。我們將使用 Python 作為主要編程語言,結合 NumPy、PyTorch 以及 Scikit-learn 等強大的庫來實現這一目標。

二、數據準備

  1. 加載數據:首先,我們使用?np.loadtxt?函數從 CSV 文件(假設名為?data-02-stock_daily.csv)中讀取股票價格數據。這里需要注意指定正確的分隔符,通常股票數據 CSV 文件是以逗號分隔的,所以我們傳入?delimiter=','。讀取到的數據是一個二維數組,每一行代表一天的股票相關信息,如開盤價、收盤價、最高價、最低價等。為了讓數據按照時間順序排列,方便后續處理,我們使用切片操作?data = data[::-1]?將數據反轉。
  2. 歸一化處理:不同特征的數值范圍可能差異很大,這會影響模型訓練的效率和效果。因此,我們引入?MinMaxScaler?類進行歸一化處理。它會將數據的每一個特征都映射到 0 到 1 的區間內,具體操作是通過?data = MinMaxScaler().fit_transform(data)?實現。經過這一步,數據的分布更加規整,有助于模型更快更好地收斂。
  3. 構建輸入輸出序列:為了讓 RNN 模型能夠學習到股票價格的時間序列特征,我們需要設置一個時間步長?c(這里設為 7)。通過循環遍歷歸一化后的數據,構建輸入序列?x?和對應的輸出序列?y。對于輸入序列,我們將連續?c?天的數據作為一個樣本,即?x.append(data[i:i + c]);而輸出序列則是第?c + 1?天的股票價格,也就是?y.append(data[i + c][-1])。最后,將?x?和?y?轉換為 PyTorch 張量,方便后續在深度學習框架中使用,使用?x = torch.tensor(x, dtype=torch.float)?和?y = torch.tensor(y, dtype=torch.float)?完成轉換。
  4. 劃分數據集:使用?sklearn?的?train_test_split?函數將數據集劃分為訓練集和測試集。為了保證實驗的可重復性,我們指定?test_size=0.2,表示測試集占總數據集的 20%,以及?random_state=42?作為隨機種子。通過?x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)?得到劃分后的數據集,并打印出訓練集的形狀,以便了解數據的維度信息,用于后續模型參數的設置。

三、模型搭建

我們定義了一個自定義的 RNN 模型類,繼承自?torch.nn.Module。在?__init__?方法中:

  1. 首先調用父類的初始化方法?super().__init__(),確保模型的基礎結構正確初始化。
  2. 接著創建兩個 RNN 層,第一個?self.rnn1?的輸入大小根據訓練數據的特征維度確定,即?input_size=x_train.shape[2],這里?x_train.shape[2]?表示輸入數據的特征數量,隱藏層大小設為 128,并且設置?batch_first=True,使輸入張量的批次維度在第一維,方便與后續的數據加載器等組件配合;第二個?self.rnn2?的輸入大小為第一個 RNN 層的隱藏層大小 128,隱藏層大小設為 256,同樣設置?batch_first=True
  3. 最后定義一個線性層?self.linear,將第二個 RNN 層的輸出映射到預測的股票價格維度,其輸入特征數量為 256,輸出特征數量為 1。

在?forward?方法中:

  1. 輸入數據?x?首先經過第一個 RNN 層?self.rnn1,得到輸出?x?和隱藏狀態?y,由于在這個預測場景中我們不需要關注隱藏狀態,所以直接忽略?y,即?x, _ = self.rnn1(x)
  2. 接著?x?再經過第二個 RNN 層?self.rnn2,同樣忽略隱藏狀態,x, _ = self.rnn2(x)
  3. 最后將經過兩層 RNN 處理后的?x?的最后一個時間步的輸出(也就是?x[:, -1, :])傳入線性層?self.linear,得到最終的預測結果并返回。

四、模型訓練

  1. 實例化模型:創建?RNN?模型的實例,即?model = RNN()
  2. 定義損失函數:選用均方誤差損失函數(MSELoss)來衡量模型預測值與真實值之間的差異,loss_fn = torch.nn.MSELoss()。這是因為在預測股票價格這種連續值的任務中,均方誤差能夠很好地反映預測的準確性。
  3. 定義優化器:使用 Adam 優化器來更新模型的參數,指定學習率為 0.01,通過?optimizer = torch.optim.Adam(model.parameters(), lr=0.01)?完成定義。Adam 優化器在實際應用中表現出良好的收斂性能,能夠自適應地調整學習率,使得模型訓練更加高效。
  4. 訓練循環:設置訓練的輪數為 1000,在每一輪訓練中:
    • 首先使用?optimizer.zero_grad()?清空上一輪訓練的梯度信息,確保每一輪的梯度計算都是基于當前輪次的輸入數據。
    • 然后將訓練數據?x_train?傳入模型,得到預測輸出?h = model(x_train),并使用?loss_fn?計算預測值與真實值?y_train?之間的損失。
    • 接著調用?loss.backward()?進行反向傳播,計算模型參數的梯度。
    • 最后使用?optimizer.step()?根據計算得到的梯度更新模型參數,并將當前輪次的損失值添加到損失列表?loss_list?中。為了便于觀察訓練過程,每 100 個輪次打印一次損失值,如?if (epoch + 1) % 100 == 0: print(f'Epoch [{epoch + 1}/{num_epoch00}, Loss: {loss.item():.4f}')

五、模型預測與可視化

  1. 預測測試集:訓練完成后,將測試集數據?x_test?傳入模型,得到預測結果?predictions = model(x_test).squeeze(),這里的?squeeze?操作是為了去除可能存在的多余維度,使預測結果的維度與真實值?y_test?相匹配。
  2. 繪制預測結果:使用?matplotlib?庫繪制預測結果和真實結果的對比圖。首先創建一個新的繪圖窗口,設置合適的圖幅大小,如?plt.figure(figsize=(10, 6))。然后分別繪制預測值和真實值的折線圖,用紅色表示預測值?plt.plot(predictions.detach().numpy(), c='r', label='Prediction'),綠色表示真實值?plt.plot(y_test.detach().numpy(), c='g', label='Actual'),并添加標題、坐標軸標簽以及圖例,最后通過?plt.show()?展示繪圖結果。這使得我們能夠直觀地看到模型預測的股票價格與實際價格的接近程度,評估模型的性能。
  3. 繪制損失曲線:為了進一步了解模型訓練過程中的收斂情況,我們還繪制了訓練損失隨輪次變化的曲線。同樣創建一個新的繪圖窗口,繪制損失列表?loss_list?中的值,用藍色表示訓練損失?plt.plot(loss_list, c='b', label='Training Loss'),添加相應的標題、坐標軸標簽和圖例,最后展示繪圖結果。通過觀察損失曲線,我們可以判斷模型是否收斂,以及收斂的速度如何,為后續模型的優化提供參考。

通過以上完整的步驟,我們成功地利用 RNN 模型對股票價格進行了預測,并通過可視化手段直觀地展示了預測結果和訓練過程。當然,這只是一個簡單的示例,在實際應用中,還可以進一步優化模型結構、調整參數、增加更多的數據特征等,以提高預測的準確性。希望這個項目能夠為你在深度學習應用于金融領域的探索中提供一些幫助!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/901045.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/901045.shtml
英文地址,請注明出處:http://en.pswp.cn/news/901045.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

LangGraph 架構詳解

核心架構組件 LangGraph 的架構建立在一個靈活的基于圖的系統上,使開發者能夠定義和執行復雜的工作流。以下是主要架構組件: 1. 狀態管理系統 LangGraph 的核心是其強大的狀態管理系統,它允許應用程序在整個執行過程中維護一致的狀態&…

Python 深度學習實戰 第1章 什么是深度學習代碼示例

第1章:什么是深度學習 內容概要 第1章介紹了深度學習的背景、發展歷史及其在人工智能(AI)和機器學習(ML)中的地位。本章探討了深度學習的定義、其與其他機器學習方法的關系,以及深度學習在近年來取得的成…

swift菜鳥教程1-5(語法,變量,類型,常量,字面量)

一個樸實無華的目錄 今日學習內容:1.基本語法引入空格規范輸入輸出 2.變量聲明變量變量輸出加反斜杠括號 \\( ) 3.可選(Optionals)類型可選類型強制解析可選綁定 4.常量常量聲明常量命名 5.字面量整數 and 浮點數 實例字符串 實例 今日學習內容: 1.基本…

GAT-GRAPH ATTENTION NETWORKS(論文筆記)

CCF等級:A 發布時間:2018年 代碼位置 25年4月21日交 目錄 一、簡介 二、原理 1.注意力系數 2.歸一化 3.特征組合與非線性變換 4.多頭注意力 4.1特征拼接操作 4.2平均池化操作 三、實驗性能 四、結論和未來工作 一、簡介 圖注意力網絡&…

XML、JSON 和 Protocol Buffers (protobuf) 對比

目錄 1. XML (eXtensible Markup Language) 1)xml的特點: 2)xml的適用場景: 2. JSON (JavaScript Object Notation) 1)JSOM的特點: 2)JSON的適用場景: 3. Protocol Buffers (…

如何通過簡單步驟保護您的網站安全

在如今的數字化時代,網站安全已經成為每個網站管理者都不能忽視的重點。未授權用戶入侵、數據泄露和惡意軟件等威脅越來越多,網站安全對于保護企業、用戶和客戶的數據非常重要。為了幫助您提升網站的安全性,本文介紹了一些簡單且有效的措施&a…

【后端開發】初識Spring IoC與SpringDI、圖書管理系統

文章目錄 圖書管理系統用戶登錄需求分析接口定義前端頁面代碼服務器代碼 圖書列表展示需求分析接口定義前端頁面部分代碼服務器代碼Controller層service層Dao層modle層 Spring IoC定義傳統程序開發解決方案IoC優勢 Spring DIIoC &DI使用主要注解 Spring IoC詳解bean的存儲五…

通付盾風控智能體(RiskAgent): 神煩狗(DOGE)

在數字化業務高速發展的今天,風控系統已成為企業抵御黑產、欺詐、保障交易安全的核心防線。然而傳統風控面臨人力依賴高與策略滯后性等挑戰,數據分析師需每日從海量數據中手動提煉風險特征、設計防護規則,耗時費力;新策略從發現到…

大模型論文:Language Models are Unsupervised Multitask Learners(GPT2)

大模型論文:Language Models are Unsupervised Multitask Learners(GPT2) 文章地址:https://storage.prod.researchhub.com/uploads/papers/2020/06/01/language-models.pdf 摘要 自然語言處理任務,例如問答、機器翻譯、閱讀理解和摘要&am…

分布式ID生成方案的深度解析與Java實現

在分布式系統中,生成全局唯一的ID是一項核心需求,廣泛應用于訂單編號、用戶信息、日志追蹤等場景。分布式ID不僅需要保證全局唯一性,還要滿足高性能、高可用性以及一定的可讀性要求。本文將深入探討分布式ID的概念、設計要點、常見生成方案&a…

記 etcd 無法在docker-compose.yml啟動后無法映射數據庫目錄的問題

1、將etcd 單獨提取 Dockerfile,指定配置文件和數據目錄 #鏡像 FROM bitnami/etcd:3.5.11 #名稱 ENV name"etcd" #重啟 ENV restart"always" #運行無權限 ENV ALLOW_NONE_AUTHENTICATION"yes" #端口 EXPOSE 2379 2380 #管理員權限才…

怎樣才不算干擾球·棒球1號位

在棒球運動中,"干擾球"(Interference)是指球員或場外人員非法影響了比賽的正常進行。以下情況通常 不構成干擾,屬于合法行為或無需判罰: 1. 擊跑員(Batter-Runner)合法跑壘 跑壘限制…

PyTorch實現多輸入輸出通道的卷積操作

本文通過代碼示例詳細講解如何在PyTorch中實現多輸入通道和多輸出通道的卷積運算,并對比傳統卷積與1x1卷積的實現差異。 1. 多輸入通道互相關運算 當輸入包含多個通道時,卷積核需要對每個通道分別進行互相關運算,最后將結果相加。以下是實現…

深入解析 MySQL 中的日期時間函數:DATE_FORMAT 與時間查詢優化、DATE_ADD、CONCAT

深入解析 MySQL 中的日期時間函數:DATE_FORMAT 與時間查詢優化 在數據庫管理和應用開發中,日期和時間的處理是不可或缺的一部分。MySQL 提供了多種日期和時間函數來滿足不同的需求,其中DATE_FORMAT函數以其強大的日期格式化能力,…

SSH配置優化:提升本地內網Linux服務器遠程連接速度與穩定性

文章目錄 引言一. 理解SSH連接過程與影響因素二. 服務器端SSH配置優化三. 客戶端SSH配置優化四. 高級技巧五. 內網穿透突破公網IP限制總結 引言 SSH (Secure Shell) 是一種網絡協議,用于加密的網絡服務,常用于遠程登錄和管理Linux服務器。對于本地內網的…

BERT - MLM 和 NSP

本節代碼將實現BERT模型的兩個主要預訓練任務:掩碼語言模型(Masked Language Model, MLM) 和 下一句預測(Next Sentence Prediction, NSP)。 1. create_nsp_dataset 函數 這個函數用于生成NSP任務的數據集。 def cr…

“實時滾動”插件:一個簡單的基于vue.js的無縫滾動

1、參考連接: 安裝 | vue-seamless-scroll 2、使用步驟: 第一步:安裝 yarn add vue-seamless-scroll 第二步:引入 import vueSeamlessScroll from vue-seamless-scroll/src 第三步:注冊 components: { vueSeamless…

【藍橋杯】賽前練習

1. 排序 import os import sysn=int(input()) data=list(map(int,input().split(" "))) data.sort() for d in data:print(d,end=" ") print() for d in data[::-1]:print(d,end=" ")2. 走迷宮BFS import os import sys from collections import…

pyTorch-遷移學習-學習率衰減-四種天氣圖片多分類問題

目錄 1.導包 2.加載數據、拼接訓練、測試數據的文件夾路徑 3.數據預處理 3.1 transforms.Compose數據轉化 3.2分類存儲的圖片數據創建dataloader torchvision.datasets.ImageFolder torch.utils.data.DataLoader 4.加載預訓練好的模型(遷移學習) 4.1固定、修改預訓練…

第十四屆藍橋杯大賽軟件賽國賽Python大學B組題解

文章目錄 彈珠堆放劃分偶串交易賬本背包問題翻轉最大階梯最長回文前后綴貿易航線困局 彈珠堆放 遞推式 a i a i ? 1 i a_ia_{i-1}i ai?ai?1?i, n 20230610 n20230610 n20230610非常小,直接模擬 答案等于 494 494 494 劃分 因為總和為 1 e 6 1e6…