循環神經網絡(RNN)實現股票預測

文章目錄

  • 一、前言
  • 二、前期工作
    • 1. 設置GPU(如果使用的是CPU可以忽略這步)
    • 2. 導入數據
  • 四、數據預處理
    • 1.歸一化
    • 2.設置測試集訓練集
  • 五、構建模型
  • 六、激活模型
  • 七、訓練模型
  • 八、結果可視化
    • 1.繪制loss圖
    • 2.預測
    • 3.評估

一、前言

我的環境:

  • 語言環境:Python3.6.5
  • 編譯器:jupyter notebook
  • 深度學習環境:TensorFlow2.4.1

往期精彩內容:

  • 卷積神經網絡(CNN)實現mnist手寫數字識別
  • 卷積神經網絡(CNN)多種圖片分類的實現
  • 卷積神經網絡(CNN)衣服圖像分類的實現
  • 卷積神經網絡(CNN)鮮花識別
  • 卷積神經網絡(CNN)天氣識別
  • 卷積神經網絡(VGG-16)識別海賊王草帽一伙
  • 卷積神經網絡(ResNet-50)鳥類識別

來自專欄:機器學習與深度學習算法推薦

二、前期工作

1. 設置GPU(如果使用的是CPU可以忽略這步)

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #設置GPU顯存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")

2. 導入數據

import os,math
from tensorflow.keras.layers import Dropout, Dense, SimpleRNN
from sklearn.preprocessing   import MinMaxScaler
from sklearn                 import metrics
import numpy             as np
import pandas            as pd
import tensorflow        as tf
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False  # 用來正常顯示負號
data = pd.read_csv('./datasets/SH600519.csv')  # 讀取股票文件data
Unnamed: 0dateopenclosehighlowvolumecode
0742010-04-2688.70287.38189.07287.362107036.13600519
1752010-04-2787.35584.84187.35584.68158234.48600519
2762010-04-2884.23584.31885.12883.59726287.43600519
3772010-04-2984.59285.67186.31584.59234501.20600519
4782010-04-3083.87182.34083.87181.52385566.70600519
242124952020-04-201221.0001227.3001231.5001216.80024239.00600519
242224962020-04-211221.0201200.0001223.9901193.00029224.00600519
242324972020-04-221206.0001244.5001249.5001202.22044035.00600519
242424982020-04-231250.0001252.2601265.6801247.77026899.00600519
242524992020-04-241248.0001250.5601259.8901235.18019122.00600519

2426 rows × 8 columns

training_set = data.iloc[0:2426 - 300, 2:3].values  
test_set = data.iloc[2426 - 300:, 2:3].values  

四、數據預處理

1.歸一化

sc           = MinMaxScaler(feature_range=(0, 1))
training_set = sc.fit_transform(training_set)
test_set     = sc.transform(test_set) 

2.設置測試集訓練集

x_train = []
y_train = []x_test = []
y_test = []"""
使用前60天的開盤價作為輸入特征x_train第61天的開盤價作為輸入標簽y_trainfor循環共構建2426-300-60=2066組訓練數據。共構建300-60=260組測試數據
"""
for i in range(60, len(training_set)):x_train.append(training_set[i - 60:i, 0])y_train.append(training_set[i, 0])for i in range(60, len(test_set)):x_test.append(test_set[i - 60:i, 0])y_test.append(test_set[i, 0])# 對訓練集進行打亂
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
"""
將訓練數據調整為數組(array)調整后的形狀:
x_train:(2066, 60, 1)
y_train:(2066,)
x_test :(240, 60, 1)
y_test :(240,)
"""
x_train, y_train = np.array(x_train), np.array(y_train) # x_train形狀為:(2066, 60, 1)
x_test,  y_test  = np.array(x_test),  np.array(y_test)"""
輸入要求:[送入樣本數, 循環核時間展開步數, 每個時間步輸入特征個數]
"""
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
x_test  = np.reshape(x_test,  (x_test.shape[0], 60, 1))

五、構建模型

model = tf.keras.Sequential([SimpleRNN(80, return_sequences=True), #布爾值。是返回輸出序列中的最后一個輸出,還是全部序列。Dropout(0.2),                         #防止過擬合SimpleRNN(80),Dropout(0.2),Dense(1)
])

六、激活模型

# 該應用只觀測loss數值,不觀測準確率,所以刪去metrics選項,一會在每個epoch迭代顯示時只顯示loss值
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss='mean_squared_error')  # 損失函數用均方誤差

七、訓練模型

history = model.fit(x_train, y_train, batch_size=64, epochs=20, validation_data=(x_test, y_test), validation_freq=1)                  #測試的epoch間隔數model.summary()
Epoch 1/20
33/33 [==============================] - 6s 123ms/step - loss: 0.1809 - val_loss: 0.0310
Epoch 2/20
33/33 [==============================] - 3s 105ms/step - loss: 0.0257 - val_loss: 0.0721
Epoch 3/20
33/33 [==============================] - 3s 85ms/step - loss: 0.0165 - val_loss: 0.0059
Epoch 4/20
33/33 [==============================] - 3s 85ms/step - loss: 0.0097 - val_loss: 0.0111
Epoch 5/20
33/33 [==============================] - 3s 90ms/step - loss: 0.0099 - val_loss: 0.0139
Epoch 6/20
33/33 [==============================] - 3s 105ms/step - loss: 0.0067 - val_loss: 0.0167
Epoch 7/20
33/33 [==============================] - 3s 86ms/step - loss: 0.0067 - val_loss: 0.0095
Epoch 8/20
33/33 [==============================] - 3s 91ms/step - loss: 0.0063 - val_loss: 0.0218
Epoch 9/20
33/33 [==============================] - 3s 99ms/step - loss: 0.0052 - val_loss: 0.0109
Epoch 10/20
33/33 [==============================] - 3s 99ms/step - loss: 0.0043 - val_loss: 0.0120
Epoch 11/20
33/33 [==============================] - 3s 92ms/step - loss: 0.0044 - val_loss: 0.0167
Epoch 12/20
33/33 [==============================] - 3s 89ms/step - loss: 0.0039 - val_loss: 0.0032
Epoch 13/20
33/33 [==============================] - 3s 88ms/step - loss: 0.0041 - val_loss: 0.0052
Epoch 14/20
33/33 [==============================] - 3s 93ms/step - loss: 0.0035 - val_loss: 0.0179
Epoch 15/20
33/33 [==============================] - 4s 110ms/step - loss: 0.0033 - val_loss: 0.0124
Epoch 16/20
33/33 [==============================] - 3s 95ms/step - loss: 0.0035 - val_loss: 0.0149
Epoch 17/20
33/33 [==============================] - 4s 111ms/step - loss: 0.0028 - val_loss: 0.0111
Epoch 18/20
33/33 [==============================] - 4s 110ms/step - loss: 0.0029 - val_loss: 0.0061
Epoch 19/20
33/33 [==============================] - 3s 104ms/step - loss: 0.0027 - val_loss: 0.0110
Epoch 20/20
33/33 [==============================] - 3s 90ms/step - loss: 0.0028 - val_loss: 0.0037
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn (SimpleRNN)       (None, 60, 80)            6560      
_________________________________________________________________
dropout (Dropout)            (None, 60, 80)            0         
_________________________________________________________________
simple_rnn_1 (SimpleRNN)     (None, 80)                12880     
_________________________________________________________________
dropout_1 (Dropout)          (None, 80)                0         
_________________________________________________________________
dense (Dense)                (None, 1)                 81        
=================================================================
Total params: 19,521
Trainable params: 19,521
Non-trainable params: 0
_________________________________________________________________

八、結果可視化

1.繪制loss圖

plt.plot(history.history['loss']    , label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()

2.預測

predicted_stock_price = model.predict(x_test)                       # 測試集輸入模型進行預測
predicted_stock_price = sc.inverse_transform(predicted_stock_price) # 對預測數據還原---從(0,1)反歸一化到原始范圍
real_stock_price = sc.inverse_transform(test_set[60:])              # 對真實數據還原---從(0,1)反歸一化到原始范圍# 畫出真實數據和預測數據的對比曲線
plt.plot(real_stock_price, color='red', label='Stock Price')
plt.plot(predicted_stock_price, color='blue', label='Predicted Stock Price')
plt.title('Stock Price Prediction by K同學啊')
plt.xlabel('Time')
plt.ylabel('Stock Price')
plt.legend()
plt.show()

在這里插入圖片描述

3.評估

MSE   = metrics.mean_squared_error(predicted_stock_price, real_stock_price)
RMSE  = metrics.mean_squared_error(predicted_stock_price, real_stock_price)**0.5
MAE   = metrics.mean_absolute_error(predicted_stock_price, real_stock_price)
R2    = metrics.r2_score(predicted_stock_price, real_stock_price)print('均方誤差: %.5f' % MSE)
print('均方根誤差: %.5f' % RMSE)
print('平均絕對誤差: %.5f' % MAE)
print('R2: %.5f' % R2)
均方誤差: 1833.92534
均方根誤差: 42.82435
平均絕對誤差: 36.23424
R2: 0.72347

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/162204.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/162204.shtml
英文地址,請注明出處:http://en.pswp.cn/news/162204.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Rust】快速教程——一直在單行顯示打印、輸入、文件讀寫

前言 恨不過是七情六欲的一種,再強大的恨也沒法獨占整顆心,總有其它情感隱藏在心底深處,說不定在什么時候就會掀起滔天巨浪。——《死人經》 圖中是Starship扔掉下面的燃料罐,再扔掉頭頂的翅膀后,再翻轉過來著陸火星的…

Andorid : Toast(彈出框)- 簡單應用

Toast Android官方在Android API 30版本(或更高版本)之后即對該方法不生效。 只要SDK版本低于30,Toast.setGravity()方法即可生效 MainActivity.java package com.example.mytoast;import androidx.appcompat.app.AppCompatActivity;import android.content.Cont…

[C++ 從入門到精通] 13.派生類、調用順序、繼承方式、函數遮蔽

📢博客主頁:https://loewen.blog.csdn.net📢歡迎點贊 👍 收藏 ?留言 📝 如有錯誤敬請指正!📢本文由 丶布布原創,首發于 CSDN,轉載注明出處🙉📢現…

SOEM主站開發篇(2):添加SOEM主站APP程序

0 工具準備 1.SOEM-1.4.0源碼(官網:http://openethercatsociety.github.io/) 2.Linux開發板(本文為正點原子I.MX6U ALPHA開發板) 3.交叉編譯工具(arm-linux-gnueabihf-gcc) 4.cmake(版本不得低于3.9,本文為3.9.2) 5.Ubuntu 16.04(用于編譯生成Linux開發板的可執行文…

【Unity細節】Default clip could not be found in attached animations list.(動畫機報錯)

👨?💻個人主頁:元宇宙-秩沅 hallo 歡迎 點贊👍 收藏? 留言📝 加關注?! 本文由 秩沅 原創 😶?🌫?收錄于專欄:unity細節和bug 😶?🌫?優質專欄 ?【…

生產制造業如何謀求數字化轉型?需要哪些信息化系統做支撐?

生產制造業的數字化轉型是將數字系統和各種技術整合到傳統制造流程中的過程,這將導致行業格局的重大變革。工業4.0的到來為制造業開創了一個新時代,制造商可以簡化生產線,提高整體效率。同時,這一技術革命使他們能夠收集到大量的數…

計算機網絡實用工具之tcpdump

簡介 tcpdump是一個運行在命令行下的數據包分析器。能夠獲取到該計算機發送或接收的TCP/IP和其他數據包。 tcpdump 適用于大多數的類Unix操作系統,包括Linux、Solaris、BSD、Mac OS X、HP-UX和AIX 等等。在這些系統中,tcpdump 需要使用libpcap這個捕捉…

Altium Designer學習筆記9

忽視了一個最大的問題,就是元器件的封裝,不應該是根據AD系統的封裝走,而應該是根據立創商城上的規格書,確認每個封裝的大小,畫出封裝圖,然后才是布局和走線。 1、確認電容的封裝采用0805,貼片電…

【css】Google第三方登錄按鈕樣式修改

文章目錄 場景前置準備修改樣式官方屬性修改樣式CSS修改樣式按鈕的高度height和border-radiusLogo和文字布局 場景 需要用到谷歌的第三方登錄,登錄按鈕有自己的樣式。根據官方文檔:概覽 | Authentication | Google for Developers,提供兩種第…

局域網協議:地址解析協議(ARP,Address Resolution Protocol)

地址解析協議(ARP,Address Resolution Protocol)是一種用于在IP網絡中將IP地址映射到物理MAC地址的協議。在IP網絡中,IP是用于尋址,真正將數據包從一個設備發送到另外一個設備,用于通信的是物理MAC地址。 …

40、Flink 的Apache Kafka connector(kafka sink的介紹及使用示例)-2

Flink 系列文章 1、Flink 部署、概念介紹、source、transformation、sink使用示例、四大基石介紹和示例等系列綜合文章鏈接 13、Flink 的table api與sql的基本概念、通用api介紹及入門示例 14、Flink 的table api與sql之數據類型: 內置數據類型以及它們的屬性 15、Flink 的ta…

geemap學習筆記012:如何搜索Earth Engine Python腳本

前言 本節主要是介紹如何查詢Earth Engine中已經集成好的Python腳本案例。 1 導入庫 !pip install geemap #安裝geemap庫 import ee import geemap2 搜索Earth Engine Python腳本 很簡單,只需要一行代碼。 geemap.ee_search()使用方法 后記 大家如果有問題需…

vue截取URL中的參數

url: http://localhost:81/login?redirect%2Findex&access_tokeneyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJvdUV4dGVybmFsSWQiOiI0OTI2MjYzMTIxMDU1NDAxMTM4IiwiYXVkIjpbImVudGVycHJpc2VfbW9iaWxlX3Jlc291cmNlIiwiYmZmX2FwaV9yZXN 截取參數: let…

如何提高圖片轉excel的效果?(軟件選擇篇)

在日常的工作中,我們常常會遇到一些財務報表類的圖片需要轉換成可編輯的excel,但是,受各種條件的限制,常常只能通過手工錄入這種原始的方式來實現,隨著人工智能、深度學習以及網絡技術的發展,這種原始的錄入…

SpringBoot集成七牛云OSS詳細介紹

📑前言 本文主要SpringBoot集成七牛云OSS詳細介紹的文章,如果有什么需要改進的地方還請大佬指出?? 🎬作者簡介:大家好,我是青衿🥇 ??博客首頁:CSDN主頁放風講故事 🌄每日一句&a…

【Java工具篇】Java反編譯工具Bytecode Viewer

💝💝💝歡迎來到我的博客,很高興能夠在這里和您見面!希望您在這里可以感受到一份輕松愉快的氛圍,不僅可以獲得有趣的內容和知識,也可以暢所欲言、分享您的想法和見解。 推薦:kwan 的首頁,持續學…

【C++高階(四)】紅黑樹深度剖析--手撕紅黑樹!

💓博主CSDN主頁:杭電碼農-NEO💓 ? ?專欄分類:C從入門到精通? ? 🚚代碼倉庫:NEO的學習日記🚚 ? 🌹關注我🫵帶你學習C ? 🔝🔝 紅黑樹 1. 前言2. 紅黑樹的概念以及性質3. 紅黑…

計算機網絡之數據鏈路層

一、概述 1.1概述 物理層發出去的信號需要通過數據鏈路層才知道是否到達目的地;才知道比特流的分界線 鏈路(Link):從一個結點到相鄰結點的一段物理線路,中間沒有任何其他交換結點數據鏈路(Data Link):把實現通信協議的硬件和軟件…

電商API接口|電商數據接入|拼多多平臺根據商品ID查商品詳情SKU和商品價格參數

隨著科技的不斷進步,API開發領域也逐漸呈現出蓬勃發展的勢頭。今天我將向大家介紹API接口,電商API接口具備獨特的特點,使得數據獲取變得更加高效便捷。 快速獲取API數據——優化數據訪問速度 傳統的數據獲取方式可能需要經過多個中介環節&…

華大基因認知障礙基因檢測服務,助力認知障礙疾病防控

認知障礙是一種嚴重的神經系統疾病,對人類的腦健康產生了重大影響。據報告顯示,在我國65歲以上的人群中,存在輕度認知障礙的患者約為3,800萬,而中重度癡呆患者則約為1,500萬,患病人口數量龐大。這種疾病不僅會對患者的…