train_encoder_decoder.py

train_encoder_decoder.py

from __future__ import print_function #為了確保代碼同時兼容Python 2和Python 3版本中的print函數# 導入標準庫和第三方庫
import os.path #導入了Python的os.path模塊,用于處理文件和目錄路徑
from os import path #從os模塊中導入了path子模塊,可以直接使用path來調用os.path中的函數import sys #導入了sys模塊,用于系統相關的參數和函數
import math #導入了math模塊,提供了數學運算函數
import numpy as np #導入了NumPy庫,并使用np作為別名,NumPy是用于科學計算的基礎庫
import pandas as pd #導入了Pandas庫,并使用pd作為別名,Pandas是用于數據分析的強大庫# 導入深度學習相關庫
import tensorflow as tf #導入了TensorFlow深度學習框架from keras import backend as K #導入了Keras的backend模塊,并使用K作為別名,用于訪問后端引擎的函數
from keras.models import Model #從Keras導入了Model類,用于定義神經網絡模型
from keras.layers import LSTM, GRU, TimeDistributed, Input, Dense, RepeatVector #從Keras導入了LSTM、Input和Dense等神經網絡層
from keras.callbacks import CSVLogger, EarlyStopping, TerminateOnNaN #從Keras導入了CSVLogger、EarlyStopping和TerminateOnNaN等回調函數,用于模型訓練時的控制和記錄
from keras import regularizers #從Keras導入了regularizers模塊,用于正則化
from keras.optimizers import Adam #從Keras導入了Adam優化器,用于編譯模型時指定優化算法# 導入其他功能模塊
from functools import partial, update_wrapper #從Python標準庫functools中導入了partial和update_wrapper函數,用于函數式編程中的功能擴展和包裝# 這個函數的作用是創建一個部分應用(partial application)的函數,并保留原始函數的文檔字符串等信息。
def wrapped_partial(func, *args, **kwargs):partial_func = partial(func, *args, **kwargs)update_wrapper(partial_func, func)return partial_func# 這是一個自定義的損失函數,計算加權的均方誤差(Mean Squared Error),其中y_true是真實值,y_pred是預測值,weights是權重。
def weighted_mse(y_true, y_pred, weights):return K.mean(K.square(y_true - y_pred) * weights, axis=-1)# 這部分代碼用于選擇使用的GPU設備。它從命令行參數中獲取一個整數值gpu,如果gpu小于3,則設置CUDA環境變量以指定使用的GPU設備
import os
gpu = int(sys.argv[-13])
if gpu < 3:os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152os.environ["CUDA_VISIBLE_DEVICES"]= "{}".format(gpu)from tensorflow.python.client import device_libprint(device_lib.list_local_devices())# 這部分代碼獲取了一系列命令行參數,并將它們分別賦值給變量 
# 這些參數可能包括數據集名稱、訓練的批次數量、訓練周期數、學習率、正則化懲罰、Dropout率、耐心(用于Early Stopping)等 
imp = sys.argv[-1]
T = sys.argv[-2]
t0 = sys.argv[-3]
dataname = sys.argv[-4] 
nb_batches = sys.argv[-5]
nb_epochs = sys.argv[-6]
lr = float(sys.argv[-7])
penalty = float(sys.argv[-8])
dr = float(sys.argv[-9])
patience = sys.argv[-10]
n_hidden = int(sys.argv[-11])
hidden_activation = sys.argv[-12]# results_directory 是一個字符串,表示將要創建的結果文件夾路徑。dataname 是之前從命令行參數中獲取的數據集名稱
# 如果這個文件夾路徑不存在,就使用 os.makedirs 函數創建它。這個路徑通常用于存儲訓練模型的結果或者日志
results_directory = 'results/encoder-decoder/{}'.format(dataname)if not os.path.exists(results_directory):os.makedirs(results_directory)# 定義了一個函數 create_model,用于創建、編譯和返回一個循環神經網絡(RNN)模型
def create_model(n_pre, n_post, nb_features, output_dim, lr, penalty, dr, n_hidden, hidden_activation):""" creates, compiles and returns a RNN model @param nb_features: the number of features in the model"""# 這里定義了兩個輸入層:inputs 是一個形狀為 (n_pre, nb_features) 的輸入張量,用于模型的主輸入;weights_tensor 是一個形狀相同的張量,用于傳遞權重或其他需要的信息inputs = Input(shape=(n_pre, nb_features), name="Inputs")  weights_tensor = Input(shape=(n_pre, nb_features), name="Weights") # 這里使用了兩個 LSTM 層:lstm_1 是一個具有 n_hidden 個單元的 LSTM 層,應用了 dropout 和 recurrent_dropout,并且返回整個時間序列的輸出。lstm_2 是一個相同的 LSTM 層,但它只返回最后一個時間步的輸出。lstm_1 = LSTM(n_hidden, dropout=dr, recurrent_dropout=dr, activation=hidden_activation, return_sequences=True, name='LSTM_1')(inputs) # Encoderlstm_2 = LSTM(n_hidden, activation=hidden_activation, return_sequences=False, name='LSTM_2')(lstm_1) # Encoderrepeat = RepeatVector(n_post, name='Repeat')(lstm_2) # get the last output of the LSTM and repeats itgru_1 = GRU(n_hidden, activation=hidden_activation, return_sequences=True, name='Decoder')(repeat)  # Decoderoutput= TimeDistributed(Dense(output_dim, activation='linear', kernel_regularizer=regularizers.l2(penalty), name='Dense'), name='Outputs')(gru_1)model = Model([inputs, weights_tensor], output)# Compilecl = wrapped_partial(weighted_mse, weights=weights_tensor)model.compile(optimizer=Adam(lr=lr), loss=cl)print(model.summary()) return modeldef train_model(model, dataX, dataY, weights, nb_epoches, nb_batches):# Prepare model checkpoints and callbacksstopping = EarlyStopping(monitor='val_loss', patience=int(patience), min_delta=0, verbose=1, mode='min', restore_best_weights=True)csv_logger = CSVLogger('results/encoder-decoder/{}/training_log_{}_{}_{}_{}_{}_{}_{}_{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), separator=',', append=False)terminate = TerminateOnNaN()# Model fithistory = model.fit(x=[dataX,weights], y=dataY, batch_size=nb_batches, verbose=1,epochs=nb_epoches, callbacks=[stopping,csv_logger,terminate],validation_split=0.2)def test_model():n_post = int(1)n_pre =int(t0)-1seq_len = int(T)wx = np.array(pd.read_csv("data/{}-wx-{}.csv".format(dataname,imp)))print('raw wx shape', wx.shape)  wXC = []for i in range(seq_len-n_pre-n_post):wXC.append(wx[i:i+n_pre]) wXC = np.array(wXC)print('wXC shape:', wXC.shape)x = np.array(pd.read_csv("data/{}-x-{}.csv".format(dataname,imp)))print('raw x shape', x.shape) dXC, dYC = [], []for i in range(seq_len-n_pre-n_post):dXC.append(x[i:i+n_pre])dYC.append(x[i+n_pre:i+n_pre+n_post])dataXC = np.array(dXC)dataYC = np.array(dYC)print('dataXC shape:', dataXC.shape)print('dataYC shape:', dataYC.shape)nb_features = dataXC.shape[2]output_dim = dataYC.shape[2]# create and fit the encoder-decoder networkprint('creating model...')model = create_model(n_pre, n_post, nb_features, output_dim, lr, penalty, dr, n_hidden, hidden_activation)train_model(model, dataXC, dataYC, wXC, int(nb_epochs), int(nb_batches))# now testprint('Generate predictions on full training set')preds_train = model.predict([dataXC,wXC], batch_size=int(nb_batches), verbose=1)print('predictions shape =', preds_train.shape)preds_train = np.squeeze(preds_train)print('predictions shape (squeezed)=', preds_train.shape)print('Saving to results/encoder-decoder/{}/encoder-decoder-{}-train-{}-{}-{}-{}-{}-{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches))np.savetxt("results/encoder-decoder/{}/encoder-decoder-{}-train-{}-{}-{}-{}-{}-{}.csv".format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), preds_train, delimiter=",")print('Generate predictions on test set')wy = np.array(pd.read_csv("data/{}-wy-{}.csv".format(dataname,imp)))print('raw wy shape', wy.shape)  wY = []for i in range(seq_len-n_pre-n_post):wY.append(wy[i:i+n_pre]) # weights for outputswXT = np.array(wY)print('wXT shape:', wXT.shape)y = np.array(pd.read_csv("data/{}-y-{}.csv".format(dataname,imp)))print('raw y shape', y.shape)  dXT = []for i in range(seq_len-n_pre-n_post):dXT.append(y[i:i+n_pre]) # treated is inputdataXT = np.array(dXT)print('dataXT shape:', dataXT.shape)preds_test = model.predict([dataXT, wXT], batch_size=int(nb_batches), verbose=1)print('predictions shape =', preds_test.shape)preds_test = np.squeeze(preds_test)print('predictions shape (squeezed)=', preds_test.shape)print('Saving to results/encoder-decoder/{}/encoder-decoder-{}-test-{}-{}-{}-{}-{}-{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches))np.savetxt("results/encoder-decoder/{}/encoder-decoder-{}-test-{}-{}-{}-{}-{}-{}.csv".format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), preds_test, delimiter=",")def main():test_model()return 1if __name__ == "__main__":main()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/38592.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/38592.shtml
英文地址,請注明出處:http://en.pswp.cn/web/38592.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【場景題】數據庫優化和接口優化——異步思想

理解 異步處理&#xff1a; 對于耗時的操作&#xff0c;可以考慮使用異步處理方式來提升接口的響應速度。用戶可以在不阻塞當前操作的情況下&#xff0c;等待異步操作的結果。 異步處理在數據庫優化中的應用 雖然數據庫操作本身&#xff08;如查詢、插入、更新等&#xff09…

Git 安裝

目錄 Git 安裝 Git 安裝 在使用 Git 前我們需要先安裝 Git。Git 目前支持 Linux/Unix、Solaris、Mac 和 Windows 平臺上運行。Git 各平臺安裝包下載地址為&#xff1a;http://git-scm.com/downloads 在 Linux 平臺上安裝&#xff08;包管理工具安裝&#xff09; 首先&#xff0…

IIS在Windows上的搭建

&#x1f4d1;打牌 &#xff1a; da pai ge的個人主頁 &#x1f324;?個人專欄 &#xff1a; da pai ge的博客專欄 ??寶劍鋒從磨礪出&#xff0c;梅花香自苦寒來 目錄 一 概念&#xff1a; 二網絡…

深入理解C++中的鎖

目錄 1.基本互斥鎖&#xff08;std::mutex&#xff09; 2.遞歸互斥鎖&#xff08;std::recursive_mutex&#xff09; 3.帶超時機制的互斥鎖&#xff08;std::timed_mutex&#xff09; 4.帶超時機制的遞歸互斥鎖&#xff08;std::recursive_timed_mutex&#xff09; 5.共享…

【python腳本】批量檢測sql延時注入

文章目錄 前言批量檢測sql延時注入工作原理腳本演示 前言 SQL延時注入是一種在Web應用程序中利用SQL注入漏洞的技術&#xff0c;當傳統的基于錯誤信息或數據回顯的注入方法不可行時&#xff0c;例如當Web應用進行了安全配置&#xff0c;不顯示任何錯誤信息或敏感數據時&#x…

【TS】TypeScript 原始數據類型深度解析

&#x1f308;個人主頁: 鑫寶Code &#x1f525;熱門專欄: 閑話雜談&#xff5c; 炫酷HTML | JavaScript基礎 ?&#x1f4ab;個人格言: "如無必要&#xff0c;勿增實體" 文章目錄 TypeScript 原始數據類型深度解析一、引言二、基礎原始數據類型2.1 boolean2.2 …

蒼穹外賣--sky-take-out(四)10-12

蒼穹外賣--sky-take-out&#xff08;一&#xff09; 蒼穹外賣--sky-take-out&#xff08;一&#xff09;-CSDN博客?編輯https://blog.csdn.net/kussm_/article/details/138614737?spm1001.2014.3001.5501https://blog.csdn.net/kussm_/article/details/138614737?spm1001.2…

Unity動畫系統(2)

6.1 動畫系統基礎2-3_嗶哩嗶哩_bilibili p316 模型添加Animator組件 動畫控制器 AnimatorController AnimatorController 可以通過代碼控制動畫速度 建立動畫間的聯系 bool值的設定 trigger p318 trigger點擊的時候觸發&#xff0c;如喊叫&#xff0c;開槍及換子彈等&#x…

在js中如何Json字符串格式不對,如何處理

如果 JSON 字符串格式不正確&#xff0c;解析它時會拋出異常&#xff0c;但我們可以嘗試盡可能提取有效的信息。以下是一個方法&#xff0c;可以使用正則表達式和字符串操作來提取部分有效的 JSON 內容&#xff0c;即使整個字符串無法被 JSON.parse 完全解析。 示例代碼如下&a…

錯誤 [WinError 10013] 以一種訪問權限不允許的方式做了一個訪問套接字的嘗試 python ping

報錯提示&#xff1a;錯誤 [WinError 10013] 以一種訪問權限不允許的方式做了一個訪問套接字的嘗試 用python做了一個批量ping腳本&#xff0c;在windows專業版上沒問題&#xff0c;但是到了windows服務器就出現這個報錯 解決方法&#xff1a;右鍵 管理員身份運行 這個腳本 …

sql拉鏈表

1、定義&#xff1a;維護歷史狀態以及最新數據的一種表 2、使用場景 1、有一些表的數據量很大&#xff0c;比如一張用戶表&#xff0c;大約1億條記錄&#xff0c;50個字段&#xff0c;這種表 2.表中的部分字段會被update更新操作&#xff0c;如用戶聯系方式&#xff0c;產品的…

compute和computeIfAbsent的區別和用法

compute和computeIfAbsent都是Map接口中的默認方法&#xff0c;用于在映射中進行鍵值對的計算和更新。它們的主要區別在于它們的行為和使用場景。 compute 方法 定義: V compute(K key, BiFunction<? super K, ? super V, ? extends V> remappingFunction);參數: k…

在 WebGPU 與 Vulkan 之間做出正確的選擇(Making the Right Choice between WebGPU vs Vulkan)

在 WebGPU 與 Vulkan 之間做出正確的選擇&#xff08;Making the Right Choice between WebGPU vs Vulkan&#xff09; WebGPU 和 Vulkan 之間的主要區別WebGPU 是什么&#xff1f;它適合誰使用&#xff1f;Vulkan 是什么&#xff1f;它適合誰使用&#xff1f;WebGPU 和 Vulkan…

修改CentOS7 yum源

修改CentOS默認yum源為阿里鏡像源 備份系統自帶yum源配置文件 mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.backup 下載ailiyun的yum源配置文件 CentOS7 yum源如下&#xff1a; wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun…

AI領域最需要掌握的技術是什么?

在AI領域&#xff0c;掌握一系列核心技術和相關知識是非常重要的&#xff0c;以下是AI專業人士最需要掌握的一些關鍵技術&#xff1a; 1. **數學基礎** - 線性代數&#xff1a;用于處理向量和矩陣&#xff0c;是機器學習和深度學習的基石。 - 微積分&#xff1a;用于理解函數的…

SpringBoot項目使用WebSocket提示Error creating bean with name ‘serverEndpointExporter‘

問題描述&#xff1a;WebSocket在Controller中正常工作&#xff0c;但是在之后使用SpringBootTest進行單元測試的時候&#xff0c;突然提示WebSocket的相關錯誤。 錯誤提示&#xff1a; Exception encountered during context initialization - cancelling refresh attempt: …

項目中的代碼記錄日常

項目中的代碼記錄日常 /// <summary> /// 修改任務狀態 /// </summary> private void StartProcess21() {Process21Task new Thread(() >{while (CommonUtility.IsWorking){try{if (tPAgvTasksList.Count > 0){Parallel.ForEach(tPAgvTasksList, new Paral…

gitlab push的時候需要密碼,你忘記了密碼

情景: 忘記密碼,且登入網頁端gitlab的密碼并不能在push的時候使用,應該兩者是兩個不同的密碼 解決方法: 直接設置ssh密鑰登入,不使用密碼gitlab添加SSH密鑰——查看本地密鑰 & 生成ssh密鑰_gitlab生成ssh密鑰-CSDN博客

[OC]蘿卜圈Python手動機器人腳本

這是給機器人設置的端口&#xff0c;對照用 代碼 # #作者:溥哥’ ##機器人驅動主程序 #請在main中編寫您自己的機器人驅動代碼 import msvcrt def main():a"none"while True:key_input msvcrt.getch()akey_inputif abw:print(a)robot_drv.set_motors(1,40,2,40,3,…

uniapp學習筆記

uniapp官網地址&#xff1a;https://uniapp.dcloud.net.cn/ 學習源碼&#xff1a;https://gitee.com/qingnian8/uniapp-ling_project.git 顏色網址&#xff1a;https://colordrop.io/ uniapp中如何獲取導航中的路由信息&#xff1f; onLoad(e){console.log(e)console.log(e.w…