1-4.時間序列數據建模流程范例

文章最前: 我是Octopus,這個名字來源于我的中文名–章魚;我熱愛編程、熱愛算法、熱愛開源。所有源碼在我的個人github
;這博客是記錄我學習的點點滴滴,如果您對 Python、Java、AI、算法有興趣,可以關注我的動態,一起學習,共同進步。

2020年發生的新冠肺炎疫情災難給各國人民的生活造成了諸多方面的影響。

有的同學是收入上的,有的同學是感情上的,有的同學是心理上的,還有的同學是體重上的。

本文基于中國2020年3月之前的疫情數據,建立時間序列RNN模型,對中國的新冠肺炎疫情結束時間進行預測。

外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳

import torch 
print("torch.__version__ = ", torch.__version__)
torch.__version__ =  2.0.1

公眾號 算法美食屋 回復關鍵詞:pytorch, 獲取本項目源碼和所用數據集百度云盤下載鏈接。

import os#mac系統上pytorch和matplotlib在jupyter中同時跑需要更改環境變量
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 

一,準備數據

本文的數據集取自tushare,獲取該數據集的方法參考了以下文章。

《https://zhuanlan.zhihu.com/p/109556102》

外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'svg'df = pd.read_csv("./eat_pytorch_datasets/covid-19.csv",sep = "\t")
df.plot(x = "date",y = ["confirmed_num","cured_num","dead_num"],figsize=(10,6))
plt.xticks(rotation=60);

外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳

dfdata = df.set_index("date")
dfdiff = dfdata.diff(periods=1).dropna()
dfdiff = dfdiff.reset_index("date")dfdiff.plot(x = "date",y = ["confirmed_num","cured_num","dead_num"],figsize=(10,6))
plt.xticks(rotation=60)
dfdiff = dfdiff.drop("date",axis = 1).astype("float32")

外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳

dfdiff.head()
confirmed_numcured_numdead_num
0457.04.016.0
1688.011.015.0
2769.02.024.0
31771.09.026.0
41459.043.026.0

下面我們通過繼承torch.utils.data.Dataset實現自定義時間序列數據集。

torch.utils.data.Dataset是一個抽象類,用戶想要加載自定義的數據只需要繼承這個類,并且覆寫其中的兩個方法即可:

  • __len__:實現len(dataset)返回整個數據集的大小。
  • __getitem__:用來獲取一些索引的數據,使dataset[i]返回數據集中第i個樣本。

不覆寫這兩個方法會直接返回錯誤。

import torch 
from torch import nn 
from torch.utils.data import Dataset,DataLoader,TensorDataset#用某日前8天窗口數據作為輸入預測該日數據
WINDOW_SIZE = 8class Covid19Dataset(Dataset):def __len__(self):return len(dfdiff) - WINDOW_SIZEdef __getitem__(self,i):x = dfdiff.loc[i:i+WINDOW_SIZE-1,:]feature = torch.tensor(x.values)y = dfdiff.loc[i+WINDOW_SIZE,:]label = torch.tensor(y.values)return (feature,label)ds_train = Covid19Dataset()#數據較小,可以將全部訓練數據放入到一個batch中,提升性能
dl_train = DataLoader(ds_train,batch_size = 38)for features,labels in dl_train:break #dl_train同時作為驗證集
dl_val = dl_train

二,定義模型

使用Pytorch通常有三種方式構建模型:使用nn.Sequential按層順序構建模型,繼承nn.Module基類構建自定義模型,繼承nn.Module基類構建模型并輔助應用模型容器進行封裝。

此處選擇第二種方式構建模型。

import torch
from torch import nn 
import importlib 
import torchkeras torch.random.seed()class Block(nn.Module):def __init__(self):super(Block,self).__init__()def forward(self,x,x_input):x_out = torch.max((1+x)*x_input[:,-1,:],torch.tensor(0.0))return x_outclass Net(nn.Module):def __init__(self):super(Net, self).__init__()# 3層lstmself.lstm = nn.LSTM(input_size = 3,hidden_size = 3,num_layers = 5,batch_first = True)self.linear = nn.Linear(3,3)self.block = Block()def forward(self,x_input):x = self.lstm(x_input)[0][:,-1,:]x = self.linear(x)y = self.block(x,x_input)return ynet = Net()
print(net)
Net((lstm): LSTM(3, 3, num_layers=5, batch_first=True)(linear): Linear(in_features=3, out_features=3, bias=True)(block): Block()
)
Net((lstm): LSTM(3, 3, num_layers=5, batch_first=True)(linear): Linear(in_features=3, out_features=3, bias=True)(block): Block()
)
from torchkeras import summary
summary(net,input_data=features);
--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
==========================================================================
LSTM-1                                    [-1, 8, 3]                  480
Linear-2                                     [-1, 3]                   12
Block-3                                      [-1, 3]                    0
==========================================================================
Total params: 492
Trainable params: 492
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 0.000069
Forward/backward pass size (MB): 0.000229
Params size (MB): 0.001877
Estimated Total Size (MB): 0.002174
--------------------------------------------------------------------------

三,訓練模型

訓練Pytorch通常需要用戶編寫自定義訓練循環,訓練循環的代碼風格因人而異。

有3類典型的訓練循環代碼風格:腳本形式訓練循環,函數形式訓練循環,類形式訓練循環。

此處我們通過引入torchkeras庫中的KerasModel工具來訓練模型,無需編寫自定義循環。

torchkeras詳情: https://github.com/lyhue1991/torchkeras

注:循環神經網絡調試較為困難,需要設置多個不同的學習率多次嘗試,以取得較好的效果。

from torchmetrics.regression import MeanAbsolutePercentageErrordef mspe(y_pred,y_true):err_percent = (y_true - y_pred)**2/(torch.max(y_true**2,torch.tensor(1e-7)))return torch.mean(err_percent)net = Net() 
loss_fn = mspe
metric_dict = {"mape":MeanAbsolutePercentageError()}optimizer = torch.optim.Adam(net.parameters(), lr=0.03)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.0001)
from torchkeras import KerasModel 
model = KerasModel(net,loss_fn = loss_fn,metrics_dict= metric_dict,optimizer = optimizer,lr_scheduler = lr_scheduler) 
dfhistory = model.fit(train_data=dl_train,val_data=dl_val,epochs=100,ckpt_path='checkpoint',patience=10,monitor='val_loss',mode='min',callbacks=None,plot=True,cpu=True)
[0;31m<<<<<< 🐌 cpu is used >>>>>>[0m

外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳

18.00% [18/100] [00:02<00:10]
████████████████████100.00% [1/1] [val_loss=0.4363, val_mape=0.5570]
[0;31m<<<<<< val_loss without improvement in 10 epoch,early stopping >>>>>> 
[0m

四,評估模型

評估模型一般要設置驗證集或者測試集,由于此例數據較少,我們僅僅可視化損失函數在訓練集上的迭代情況。

model.evaluate(dl_val)
100%|█████████████████████████████████| 1/1 [00:00<00:00, 63.91it/s, val_loss=0.384, val_mape=0.505]{'val_loss': 0.38373321294784546, 'val_mape': 0.5048269033432007}

五,使用模型

此處我們使用模型預測疫情結束時間,即 新增確診病例為0 的時間。

#使用dfresult記錄現有數據以及此后預測的疫情數據
dfresult = dfdiff[["confirmed_num","cured_num","dead_num"]].copy()
dfresult.tail()
confirmed_numcured_numdead_num
41143.01681.030.0
4299.01678.028.0
4344.01661.027.0
4440.01535.022.0
4519.01297.017.0
#預測此后1000天的新增走勢,將其結果添加到dfresult中
for i in range(1000):arr_input = torch.unsqueeze(torch.from_numpy(dfresult.values[-38:,:]),axis=0)arr_predict = model.forward(arr_input)dfpredict = pd.DataFrame(torch.floor(arr_predict).data.numpy(),columns = dfresult.columns)dfresult = pd.concat([dfresult,dfpredict],ignore_index=True)
dfresult.query("confirmed_num==0").head()# 第50天開始新增確診降為0,第45天對應3月10日,也就是5天后,即預計3月15日新增確診降為0
# 注:該預測偏樂觀
confirmed_numcured_numdead_num
500.0999.00.0
510.0948.00.0
520.0900.00.0
530.0854.00.0
540.0810.00.0

dfresult.query("cured_num==0").head()
# 第137天開始新增治愈降為0,第45天對應3月10日,也就是大概3個月后,即6月12日左右全部治愈。
# 注: 該預測偏悲觀,并且存在問題,如果將每天新增治愈人數加起來,將超過累計確診人數。
confirmed_numcured_numdead_num
1370.00.00.0
1380.00.00.0
1390.00.00.0
1400.00.00.0
1410.00.00.0

六,保存模型

模型權重保存在了model.ckpt_path路徑。

print(model.ckpt_path)
checkpoint
model.load_ckpt('checkpoint') #可以加載權重

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/38857.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/38857.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/38857.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

信息學奧賽初賽天天練-41-CSP-J2021基礎題-n個數取最大、樹的邊數、遞歸、遞推、深度優先搜索應用

PDF文檔公眾號回復關鍵字:20240701 2021 CSP-J 選擇題 單項選擇題&#xff08;共15題&#xff0c;每題2分&#xff0c;共計30分&#xff1a;每題有且僅有一個正確選項&#xff09; 4.以比較作為基本運算&#xff0c;在N個數中找出最大數&#xff0c;最壞情況下所需要的最少比…

我在中東做MCN,月賺10萬美金

圖片&#xff5c;Photo by Ben Koorengevel on Unsplash ©自象限原創 作者丨程心 在迪拜購物中心和世界最高建筑哈利法塔旁的主街上&#xff0c;徐晉已經“蹲”了三個小時&#xff0c;每當遇到穿著時髦的年輕男女&#xff0c;他都會上前詢問&#xff0c;有沒有意愿成為…

【計算機網絡】常見的網絡通信協議

目錄 1. TCP/IP協議 2. HTTP協議 3. FTP協議 4. SMTP協議 5. POP3協議 6. IMAP協議 7. DNS協議 8. DHCP協議 9. SSH協議 10. SSL/TLS協議 11. SNMP協議 12. NTP協議 13. VoIP協議 14. WebSocket協議 15. BGP協議 16. OSPF協議 17. RIP協議 18. ICMP協議 1…

網頁自動化測試開發中記錄pytest

1切換cmd文件目錄C:\Users\14600>D: D:\>cd D:\worksoftware D:\worksoftware>2單個py文件打包成.exe文件1.pyinstaller -F -c (項目主文件)test_01shouye.py 該路徑下存在文件名&#xff0c;主項目文件 test_01shouye.py 2.執行spec文件&#xff1a; pyinstaller -F …

C語言部分復習筆記

1. 指針和數組 數組指針 和 指針數組 int* p1[10]; // 指針數組int (*p2)[10]; // 數組指針 因為 [] 的優先級比 * 高&#xff0c;p先和 [] 結合說明p是一個數組&#xff0c;p先和*結合說明p是一個指針 括號保證p先和*結合&#xff0c;說明p是一個指針變量&#xff0c;然后指…

Web2Code :網頁理解和代碼生成能力的評估框架

多模態大型語言模型&#xff08;MLLMs&#xff09;在過去幾年中取得了爆炸性的增長。利用大型語言模型&#xff08;LLMs&#xff09;中豐富的常識知識&#xff0c;MLLMs在處理和推理各種模態&#xff08;如圖像、視頻和音頻&#xff09;方面表現出色&#xff0c;涵蓋了識別、推…

系統中非功能性需求的思考

概要 設計系統時不僅要考慮功能性需求&#xff0c;還要考慮一些非功能性需求&#xff0c;比如&#xff1a; 擴展性可靠性和冗余安全和隱私服務依賴SLA要求 下面對這5項需要考慮的事項做個簡單的說明 1. 可擴展性 數據量增長如何擴展&#xff1f; 流量增長如何擴展&#xf…

【LLM教程-llama】如何Fine Tuning大語言模型?

今天給大家帶來了一篇超級詳細的教程,手把手教你如何對大語言模型進行微調(Fine Tuning)&#xff01;&#xff08;代碼和詳細解釋放在后文&#xff09; 目錄 大語言模型進行微調(Fine Tuning)需要哪些步驟&#xff1f; 大語言模型進行微調(Fine Tuning)訓練過程及代碼 大語言…

VuePress介紹

從本文開始&#xff0c;動手搭建自己的博客&#xff01;希望讀者能跟著一起動手&#xff0c;這樣才能真正掌握。 ? VuePress 是什么 VuePress 是由 Vue 作者帶領團隊開發的&#xff0c;非常火&#xff0c;使用的人很多&#xff1b;Vue 框架官網也是用了 VuePress 搭建的。即…

000.二分查找算法題解目錄

000.二分查找算法題解目錄 69. x 的平方根&#xff08;簡單&#xff09;

4PCS點云配準算法實現

4PCS點云配準算法的C實現如下&#xff1a; #include <iostream> #include <pcl/io/pcd_io.h> #include <pcl/point_types.h> #include <pcl/common/common.h> #include <pcl/common/distances.h> #include <pcl/common/transforms.h> #in…

唯一ID:UUID 介紹與 google/uuid 庫生成 UUID

UUID 即通用唯一識別碼&#xff0c;是一種用于計算機系統中以確保全局唯一性的標識符。其標準定義于 RFC 4122 文檔中。標準形式包含 32 個 16 進制數字&#xff0c;以連字符切割為五組&#xff0c;格式為 8-4-4-4-12&#xff0c;總共 36 個字符。&#xff08;形如, d169aa7f-4…

php 通過vendor文件 生成還原最新的composer.json

起因&#xff1a;因為歷史原因&#xff0c;在本項目中composer.json基本算廢了&#xff0c;沒法直接使用composer管理擴展&#xff0c;今天嘗試修復一下composer.json。 歷史文件&#xff0c;可以看出來已經很久沒有維護了&#xff0c;我們主要是恢復require的信息 {"na…

K8s節點維護流程

用途 用于下線異常節點、集群縮容等 操作步驟 1. 查看節點名稱 先確認節點的名稱 kubectl get node -o wide2. 設置節點不可調度 設置節點不可調度狀態&#xff0c;禁止新的pod調度到該節點上 kubectl cordon ${node_name}3. 剔除節點上運行的pod&#xff08;生產環境慎…

Spring Boot中集成Redis實現緩存功能

Spring Boot中集成Redis實現緩存功能 大家好&#xff0c;我是免費搭建查券返利機器人省錢賺傭金就用微賺淘客系統3.0的小編&#xff0c;也是冬天不穿秋褲&#xff0c;天冷也要風度的程序猿&#xff01;今天我們將深入探討如何在Spring Boot應用程序中集成Redis&#xff0c;實現…

AP無法上線原因分析及排障

一、AP未分配到IP地址 如果遇到AP無法上線問題&#xff0c;可以檢查下AP是否分配到IP地址。AP獲取IP地址有兩種方式&#xff1a;靜態方式&#xff1a;登錄到AP設備&#xff0c;手工配置IP地址&#xff0c;該方式操作起來比較麻煩&#xff0c;不推薦使用&#xff1b;DHCP方式&am…

基于CNN的股票預測方法【卷積神經網絡】

基于機器學習方法的股票預測系列文章目錄 一、基于強化學習DQN的股票預測【股票交易】 二、基于CNN的股票預測方法【卷積神經網絡】 文章目錄 基于機器學習方法的股票預測系列文章目錄一、CNN建模原理二、模型搭建三、模型參數的選擇&#xff08;1&#xff09;探究window_size…

下代iPhone或回歸可拆卸電池,蘋果這操作把我看傻了

剛度過一個愉快的周末&#xff0c;蘋果又雙叒叕攤上事兒了。 iPhone13 系列被曝扎堆電池鼓包了。 早在去年&#xff0c;就有 iPhone13 和 iPhone14 用戶反饋過類似的問題&#xff0c;表示在手機僅僅使用了一年多的時間就出現了電池鼓包的情況&#xff0c;而且還把屏幕給撐起來了…

舞會無領導:一種樹形動態規劃的視角

沒有上司的舞會 Ural 大學有 &#x1d441; 名職員&#xff0c;編號為1~&#x1d441;。 他們的關系就像一棵以校長為根的樹&#xff0c;父節點就是子節點的直接上司。 每個職員有一個快樂指數&#xff0c;用整數 &#x1d43b;&#x1d456; 給出&#xff0c;其中1≤&…

校園卡手機卡怎么注銷?

校園手機卡的注銷流程可以根據不同的運營商和具體情況有所不同&#xff0c;但一般來說&#xff0c;以下是注銷校園手機卡的幾種常見方式&#xff0c;我將以分點的方式詳細解釋&#xff1a; 一、線上注銷&#xff08;通過手機APP或官方網站&#xff09; 下載并打開對應運營商的…