2025.3.9機器學習筆記:文獻閱讀

2025.3.9周報

  • 一、文獻閱讀
    • 題目信息
    • 摘要
    • Abstract
    • 創新點
    • 網絡架構
    • 實驗
    • 結論
    • 不足以及展望

一、文獻閱讀

題目信息

  • 題目: Time-series generative adversarial networks for flood forecasting
  • 期刊: Journal of Hydrology
  • 作者: Peiyao Weng, Yu Tian, Yingfei Liu, Ying Zheng
  • 發表時間: 2023/5/20
  • 文章鏈接: https://www.sciencedirect.com/science/article/pii/S0022169423006443?via%3Dihub

摘要

洪水每年對全世界造成了巨大的損失,有效的洪水預警是重要防洪方式,但是現有洪水預測方法并不可靠,比如,傳統的物理計算方法和數據驅動模型需大量水文和地貌數據,又因為極端洪水事件觀測數據有限,這會導致模型的預測不準確。本論文基于時間序列GANs,探究其在洪水時間序列生成及預報中的作用,利用時間序列生成對抗網絡進行洪水預測,應用TimeGAN和RTSGAN生成合成洪水時間序列。以中國西江流域為例,實驗結果表明TimeGAN能準確高效模擬多站點洪水序列的時空相關性,RTSGAN在長序列時表現更優,且合成數據集可減少常規深度學習模型的預測誤差。這種方法為解決洪水預報中數據稀缺問題提供了新思路,對提升洪水預報具有重要作用,有助于減少洪水造成的損失。

Abstract

Floods cause substantial losses worldwide each year, making effective flood warning systems a critical component of flood prevention. However, existing flood prediction methods are often unreliable. For instance, traditional physically-based computational approaches and data-driven models require extensive hydrological and geomorphological data. Moreover, the limited availability of observational data for extreme flood events often leads to inaccurate predictions. This paper investigates the application of time series Generative Adversarial Networks (GANs) in generating and forecasting flood time series. Specifically, we employ TimeGAN and RTSGAN to produce synthetic flood time series data. Using the Xijiang River Basin in China as a case study, experimental results demonstrate that TimeGAN can accurately and efficiently simulate the spatiotemporal correlations of multi-site flood sequences. RTSGAN, on the other hand, exhibits superior performance for longer sequences. Furthermore, the synthetic datasets generated by these models can reduce prediction errors in conventional deep learning approaches. This methodology offers a novel solution to the challenge of data scarcity in flood forecasting, significantly enhancing flood prediction accuracy and contributing to the mitigation of flood-related damages.

創新點

作者首次將時間序列GANs應用于洪水時間序列生成中,以提升洪水預報模型精度,其利用歷史序列數據生成新數據,且無需預處理,生成的序列保留原始時間特征。

網絡架構

在論文第3.1節“Flood generating model”中,提到在生成合成洪水數據之前,需要對原始洪水時間序列進行預處理。數據為洪水時間序列,包含多個洪水事件𝑘,特征維度 𝑛,時間步長𝑚,窗口大小𝑤,序列長度𝐿。將原始洪水時間序列通過滑動窗口的方式分割成三維數據,其表示為: [ ∑ i = 1 k ( L i ? w + 1 ) , n , w ] [\sum_{i=1}^{k}\left(L_{i}-w+1\right), n, w] [i=1k?(Li??w+1),n,w]
窗口大小w決定每次處理的數據長度,論文中將小時尺度定為48小時每3小時一個時間步,則分為16個時間步;日尺度則為9天,則分為9個時間步;
論文中“the sliding step 𝑚 is set to 1”,將步長m設為1,表示窗口每次滑動1個時間步,生成重疊窗口。
舉個例子:對于第𝑖個洪水事件,時間序列長度為 L i L_i Li?,設 L i = 100 L_i = 100 Li?=100(洪水持續了100個小事),窗口大小w=16,步長m=1,有三個特征,即n=3。
則第一個窗口時間步1到16;第二個窗口時間步2到17;最后一個窗口時間步85到100。?
總窗口數則為100?16+1=85,且一個窗口是一個二維矩陣[n,w]。
所以這個三維矩陣表達為: [ ∑ i = 1 k ( L i ? w + 1 ) , n , w ] [\sum_{i=1}^{k}\left(L_{i}-w+1\right), n, w] [i=1k?(Li??w+1),n,w] = [85,3,16]
在這里插入圖片描述
TimeGAN包含嵌入函數、恢復函數、序列生成器和序列鑒別器的生成模型,通過有監督損失和無監督損失的學習嵌入空間對抗性和聯合訓練,讓模型得以同時學習編碼特征、生成表示和跨時間迭代。有關于TimeGAN的知識點,我在之前的博客記錄過,這里就不再贅述了,大家需要的話可以點擊這個鏈接查看:https://blog.csdn.net/Zcymatics/article/details/145011621?spm=1001.2014.3001.5501
在這里插入圖片描述
接下來我們詳細分析一下作者提到的RTSGAN,其全稱為“Recurrent Time-Series Generative Adversarial Network”,它是TimeGAN的改進版本,專門用于生成高質量時間序列數據。
其核心是通過結合WGAN 和自回歸解碼器,克服了傳統TimeGAN在長序列生成中的不穩定性和長序列信息捕捉能力弱問題。
傳統GAN使用的是JS散度或KL散度優化生成分布與真實分布的相似性。想要了解GAN可以去我這篇博客看一下原理:https://blog.csdn.net/Zcymatics/article/details/145011685?spm=1001.2014.3001.5501
而WGAN使用Wasserstein距離替代JS散度,減少模式崩塌,提高訓練穩定性。

模型坍塌是指模型在訓練過程中逐漸失去多樣性和泛化能力,導致性能嚴重退化甚至無法繼續優化。

在這里插入圖片描述
其損失函數被定義為: L = E r ~ P r [ D ( r ) ] ? E r ^ ~ P g [ D ( r ^ ) ] L=E_{r \sim P_{r}}[D(r)]-E_{\hat{r} \sim P_{g}}[D(\hat{r})] L=ErPr??[D(r)]?Er^Pg??[D(r^)]
D是判別器 E r ~ P r [ D ( r ) ] E_{r \sim P_{r}}[D(r)] ErPr??[D(r)]是真實樣本𝑟通過𝐷的期望輸出; E r ^ ~ P g [ D ( r ^ ) ] E_{\hat{r} \sim P_{g}}[D(\hat{r})] Er^Pg??[D(r^)]合成樣本 r ^ \hat{r} r^通過𝐷的期望輸出。

說完了Wasserstein距離,下面來說一下RTSGAN的另一個改進點:
自回歸編碼器
TimeGAN中,使用的是嵌入函數和恢復函數完成。其嵌入函數是逐時間步生成潛在序列𝐻,每個時間步都有一個對應的潛在表示 h t h_t ht? ;其恢復函數從整個潛在序列𝐻一次性生成重建序列 X ^ \hat X X^
RTSGAN中使用的則是自回歸編碼器和解碼器。編碼器同樣負責將時間序列X映射到潛在空間,但其目標是生成一個固定維度的潛在表示𝑟,而不是逐時間步的潛在序列𝐻;其解碼器從潛在表示逐步生成時間序列(不是一次性輸出),采用自回歸方式生成每個時間步的輸出。
TimeGAN的一次性生成會導致長序列的時間相關性可能被削弱。RTSGAN中自回歸解碼器通過逐步生成的方式學習時間依賴性,每一步的生成依賴于前一步的輸出,更加能夠捕捉時間動態變化。
在這里插入圖片描述
下圖是論文工作的總體的流程圖:
首先數據經過預處理后給TimeGAN和RTSGAN生成數據,然后真實的訓練數據和生成數據進入滑動窗口。滑動窗口將真實和合成數據分割成兩個部分,為監督學習提供數據,真實數據為covariates,如 x t ? 2 1 , x t ? 2 2 , x t ? 2 3 x_{t-2}^1, x_{t-2}^2, x_{t-2}^3 xt?21?,xt?22?,xt?23?,綠色框為GAN生成數據y,如 y t y_t yt?,每個實例包含輸入和輸出。這些數據對供GBRT和LSTM訓練。最后兩個預測模型輸出預測值與真實值對比,得出評估指標。
在這里插入圖片描述

實驗

本文圍繞時間序列生成對抗網絡在洪水預測中的應用展開實驗,
論文的評價指標如下圖所示,其中比較陌生的有:
1、皮爾遜相關系數,其用于衡量預測值與觀測值的線性相關性,值越大越好(最大值為1):
O C = ∑ i = 1 n ( F i ? F ˉ ) ( O i ? O ˉ ) ∑ i = 1 n ( F i ? F ˉ ) 2 ∑ i = 1 n ( O i ? O ˉ ) 2 O C=\frac{\sum_{i=1}^n\left(F_i-\bar{F}\right)\left(O_i-\bar{O}\right)}{\sqrt{\sum_{i=1}^n\left(F_i-\bar{F}\right)^2} \sqrt{\sum_{i=1}^n\left(O_i-\bar{O}\right)^2}} OC=i=1n?(Fi??Fˉ)2 ?i=1n?(Oi??Oˉ)2 ?i=1n?(Fi??Fˉ)(Oi??Oˉ)?
F i F_i Fi?為預測值; O i O_i Oi?為觀測值; F ˉ \bar{F} Fˉ O ˉ \bar{O} Oˉ分別為預測和觀測均值。
2、平均絕對相對誤差,計算預測值與觀測值的相對偏差,值越小越好(0為最小值):
A A R E = 1 n ∑ i = 1 n ∣ F i p ? O i p O i p ∣ A A R E=\frac{1}{n} \sum_{i=1}^n\left|\frac{F_i^p-O_i^p}{O_i^p}\right| AARE=n1?i=1n? ?Oip?Fip??Oip?? ?
𝑛為樣本數; F i p F_i^p Fip?為預測值; O i p O_i^p Oip?為觀測值。
3、預測區間覆蓋概率,衡量觀測值落入預測區間的比例,最佳值為1:
P I C P = 1 n ∑ i = 1 n ε i , ε i = 1 P I C P=\frac{1}{n} \sum_{i=1}^n \varepsilon_i, \quad \varepsilon_i=1 PICP=n1?i=1n?εi?,εi?=1 if O i ∈ [ L i , U i ] O_i \in\left[L_i, U_i\right] Oi?[Li?,Ui?], else 0
L i L_i Li? U i U_i Ui?為預測區間的上下界。
4、預測區間歸一化平均寬度,評估預測區間的寬度,最佳值為0:
P I N A W = 1 n R ∑ i = 1 n ( U i ? L i ) , R = max ? ( O i ) ? min ? ( O i ) PINAW =\frac{1}{n R} \sum_{i=1}^n\left(U_i-L_i\right), \quad R=\max \left(O_i\right)-\min \left(O_i\right) PINAW=nR1?i=1n?(Ui??Li?),R=max(Oi?)?min(Oi?)
R為觀測值范圍;
5、覆蓋寬度綜合準則,平衡覆蓋率和區間寬度,PICP低于𝜇時施加指數懲罰,0為最佳值:
C W C = P I N A W ( 1 + γ e η ( μ ? P I C P ) ) , γ ( P I C P ) = 1 C W C=P I N A W\left(1+\gamma e^{\eta(\mu-P I C P)}\right), \quad \gamma(P I C P)=1 CWC=PINAW(1+γeη(μ?PICP)),γ(PICP)=1 if P I C P < μ P I C P<\mu PICP<μ, else 0
μ為期望覆蓋率;𝜂為懲罰系數;
在這里插入圖片描述
結果如下:
論文通過二維t-SNE和PCA兩種降維方法對原始洪水時間序列、TimeGAN生成的合成序列以及RTSGAN生成的合成序列進行了可視化分析。結果表明,RTSGAN生成的合成序列在分布上與原始數據集更為接近,表現為數據點簇的形狀和密度更貼近原始數據,而TimeGAN生成的序列分布則稍顯分散,偏離原始數據的特征。此外,RTSGAN生成的數據點不僅更集中,還展現出更高的多樣性。
在這里插入圖片描述
圖6通過熱圖形式展示了原始數據集、TimeGAN和RTSGAN生成序列在不同時間滯后 𝜏 下的時間相關性對比。熱圖中,顏色深淺表示相關系數的大小,RTSGAN的熱圖模式與原始數據高度一致,尤其是在日尺度上,梧州站點的合成序列時間相關系數幾乎完全重合原始數據。這表明RTSGAN能夠精確再現洪水序列的時間依賴性,例如洪峰的到達時間和衰退過程。
在小時尺度上,由于時間步長更小,序列長度增加,對模型捕捉長期依賴性的要求更高。RTSGAN依然表現出色,其熱圖顯示的空間相關性和時間相關性均優于TimeGAN。這可能是由于RTSGAN的自回歸解碼器設計,能夠逐步生成序列并保留長期依賴。RTSGAN訓練時間更短。
在這里插入圖片描述
在多時間步預測任務中,論文比較了GBRT、LSTM、QRLSTM等模型的性能,圖7展示了日尺度下各模型在T、T+1、T+2的評估指標箱線圖,包括OC、AARE等。結果顯示,LSTM在所有預測步長中的OC最高,AARE最低,且隨著預測期從T到T+2增加,其優勢逐漸擴大。例如,箱線圖中LSTM的中位數OC值可能接近1,離群點較少,而其他模型的OC分布更分散,AARE更高。LSTM“三門”機制使LSTM能夠有效捕捉洪水時間序列中的長期依賴關系。相比之下,GBRT基于窗口的特征提取更適合短期依賴,而QRLSTM在不確定性預測上雖有優勢,但在點預測準確性上不如標準LSTM。
在這里插入圖片描述
RTSGAN生成的合成數據集能顯著提升基于窗口的GBRT模型性能,尤其在較長預測期。在24小時預測中,GBRT - RTSGAN在CC和NSE指標上分別比GBRT高2.08%和6.74%,RMSE降低,AARE減少56.52%。而TimeGAN對GBRT的改進效果有限且不穩定,其合成訓練集甚至會導致誤差范圍增大。合成數據集對LSTM的有效性隨預測期增加而急劇下降。
在這里插入圖片描述
SHAP分析用于解釋模型預測中各特征的重要性,通過SHAP分析發現,GBRT和GBRT - RTSGAN模型使用幾乎相同的影響因素進行預測,且形狀值模式趨勢一致;而GBRT - TimeGAN從T + 5步開始與GBRT的形狀值模式差異較大。
在這里插入圖片描述
引入合成訓練數據集可幫助QRLSTM模型獲得更低的PINAW,但會降PICP,只有少數情況能在保持區間覆蓋的同時減小區間寬度。圖12結果顯示,使用RTSGAN或TimeGAN的合成數據后,QRLSTM的PINAW降低,例如從0.15降至0.10,表明預測區間更窄、更精確。但是PICP從0.95降至0.85,說明部分觀測值未被區間覆蓋。這種權衡表明合成數據增強了模型的準確性,但犧牲了覆蓋率。只有少數情況實現了PINAW降低的同時保持PICP接0.9。
隨著合成數據集數量增加,PINAW和PICP降低,模型的CWC下降。RTSGAN對PINAW的影響更大,表明可通過調節假樣本數量使預測區間更靈活、易調整。盡管合成訓練集在QRLSTM中表現不佳,但在洪水峰值處較窄的預測區間是可取的。
在這里插入圖片描述

代碼如下:

import numpy as np  
import torch  
import torch.nn as nn  
import torch.optim as optim  # 導入PyTorch的優化器模塊
from sklearn.ensemble import GradientBoostingRegressor  # 導入sklearn的GBRT模型
from sklearn.metrics import mean_squared_error  # 導入sklearn的均方誤差計算函數
import matplotlib.pyplot as plt  # 設置隨機種子,確保結果可重復
np.random.seed(42)
torch.manual_seed(42)# 定義數據參數
n_samples = 1755  # 樣本數量,參考論文小時尺度數據
n_features = 3  # 特征數量(假設流量、降雨量等)
seq_len = 16  # 序列長度(窗口大小,48小時,每3小時一個時間步)# 生成模擬洪水時間序列數據,假設數據服從正弦波+噪聲的形式
time_steps = np.arange(seq_len)  # 創建時間步數組
data = np.zeros((n_samples, seq_len, n_features))  # 初始化數據數組
for i in range(n_samples):for f in range(n_features):# 生成正弦波數據,模擬周期性洪水特征data[i, :, f] = np.sin(0.1 * time_steps + np.random.uniform(0, 2 * np.pi)) + np.random.normal(0, 0.1, seq_len)# 將數據轉換為PyTorch張量
data = torch.FloatTensor(data)  # 轉換為浮點張量,形狀為[n_samples, seq_len, n_features]# 定義GRU網絡類,作為TimeGAN的基礎模塊
class GRU(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers):super(GRU, self).__init__()self.hidden_dim = hidden_dim  # 設置隱藏層維度self.num_layers = num_layers  # 設置GRU層數self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True)  # 定義GRU層def forward(self, x):# 初始化隱藏狀態h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)# 前向傳播out, _ = self.gru(x, h0)return out# 定義TimeGAN模型
class TimeGAN(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers):super(TimeGAN, self).__init__()self.hidden_dim = hidden_dim  # 設置隱藏層維度self.input_dim = input_dim  # 設置輸入維度# 嵌入函數(Embedding Function)self.embedder = GRU(input_dim, hidden_dim, num_layers)  # 定義嵌入函數,使用GRU網絡# 恢復函數(Recovery Function)self.recovery = nn.Sequential(nn.Linear(hidden_dim, input_dim),  # 線性層,將隱藏狀態映射回輸入維度nn.Sigmoid()  # Sigmoid激活,確保輸出在[0, 1]范圍內)# 生成器(Generator)self.generator = GRU(hidden_dim, hidden_dim, num_layers)  # 定義生成器,使用GRU網絡self.gen_output = nn.Linear(hidden_dim, hidden_dim)  # 線性層,生成潛在表示# 判別器(Discriminator)self.discriminator = GRU(hidden_dim, hidden_dim, num_layers)  # 定義判別器,使用GRU網絡self.dis_output = nn.Linear(hidden_dim, 1)  # 線性層,輸出判別結果def forward(self, x):# 嵌入:將輸入數據映射到潛在空間h = self.embedder(x)  # 通過嵌入函數生成潛在表示h# 恢復:從潛在空間重建輸入數據x_tilde = self.recovery(h)  # 通過恢復函數重建數據return h, x_tildedef generate(self, z):# 生成器:從噪聲生成潛在表示h_hat = self.generator(z)  # 通過生成器生成潛在表示h_hat = self.gen_output(h_hat)  # 映射到潛在空間# 恢復:從生成的潛在表示生成數據x_hat = self.recovery(h_hat)  # 通過恢復函數生成合成數據return x_hat# 設置模型參數
input_dim = n_features  # 輸入維度(特征數)
hidden_dim = 24  # 隱藏層維度,參考論文中的設置
num_layers = 3  # GRU層數# 初始化TimeGAN模型
timegan = TimeGAN(input_dim, hidden_dim, num_layers)# 定義損失函數
mse_loss = nn.MSELoss()  # 均方誤差損失,用于自編碼和監督損失
bce_loss = nn.BCEWithLogitsLoss()  # 二元交叉熵損失,用于對抗損失# 定義優化器
optimizer_E = optim.Adam(list(timegan.embedder.parameters()) + list(timegan.recovery.parameters()), lr=0.001)  # 嵌入和恢復優化器
optimizer_G = optim.Adam(list(timegan.generator.parameters()) + list(timegan.gen_output.parameters()), lr=0.001)  # 生成器優化器
optimizer_D = optim.Adam(list(timegan.discriminator.parameters()) + list(timegan.dis_output.parameters()), lr=0.001)  # 判別器優化器# 訓練TimeGAN
num_epochs = 100  # 訓練輪數
batch_size = 128  # 批次大小
for epoch in range(num_epochs):for i in range(0, n_samples, batch_size):# 獲取批次數據batch_data = data[i:i + batch_size].cuda() if torch.cuda.is_available() else data[i:i + batch_size]# 訓練嵌入和恢復(自編碼損失)optimizer_E.zero_grad()  # 清零梯度h, x_tilde = timegan(batch_data)  # 前向傳播,得到潛在表示和重建數據recon_loss = mse_loss(x_tilde, batch_data)  # 計算自編碼損失recon_loss.backward()  # 反向傳播optimizer_E.step()  # 更新參數# 訓練判別器optimizer_D.zero_grad()  # 清零梯度real_h = timegan.embedder(batch_data)  # 嵌入真實數據real_logit = timegan.dis_output(timegan.discriminator(real_h))  # 判別真實數據z = torch.randn(batch_size, seq_len, hidden_dim)  # 生成隨機噪聲fake_x = timegan.generate(z)  # 生成合成數據fake_h = timegan.embedder(fake_x)  # 嵌入合成數據fake_logit = timegan.dis_output(timegan.discriminator(fake_h))  # 判別合成數據d_loss = bce_loss(real_logit, torch.ones_like(real_logit)) + bce_loss(fake_logit, torch.zeros_like(fake_logit))  # 計算判別器損失d_loss.backward()  # 反向傳播optimizer_D.step()  # 更新參數# 訓練生成器optimizer_G.zero_grad()  # 清零梯度z = torch.randn(batch_size, seq_len, hidden_dim)  # 生成隨機噪聲fake_x = timegan.generate(z)  # 生成合成數據fake_h = timegan.embedder(fake_x)  # 嵌入合成數據fake_logit = timegan.dis_output(timegan.discriminator(fake_h))  # 判別合成數據g_loss = bce_loss(fake_logit, torch.ones_like(fake_logit))  # 計算生成器對抗損失# 計算監督損失(使用真實數據監督生成器)supervised_loss = mse_loss(fake_x[:, 1:, :], fake_x[:, :-1, :])  # 監督損失:預測下一步total_g_loss = g_loss + 0.1 * supervised_loss  # 總生成器損失(對抗損失+監督損失)total_g_loss.backward()  # 反向傳播optimizer_G.step()  # 更新參數# 打印損失if (epoch + 1) % 10 == 0:print(f"Epoch {epoch+1}/{num_epochs}, Recon Loss: {recon_loss.item():.4f}, D Loss: {d_loss.item():.4f}, G Loss: {total_g_loss.item():.4f}")# 生成合成數據
with torch.no_grad():z = torch.randn(n_samples, seq_len, hidden_dim)  # 生成隨機噪聲synthetic_data_timegan = timegan.generate(z)  # 生成合成數據
synthetic_data_timegan = synthetic_data_timegan.cpu().numpy()  # 轉換為numpy數組# 定義RTSGAN模型
class RTSGAN(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers):super(RTSGAN, self).__init__()self.hidden_dim = hidden_dim  # 設置隱藏層維度self.input_dim = input_dim  # 設置輸入維度# 自回歸編碼器(Autoregressive Encoder)self.encoder = GRU(input_dim, hidden_dim, num_layers)  # 定義編碼器,使用GRU網絡self.encoder_output = nn.Linear(hidden_dim * seq_len, hidden_dim)  # 線性層,壓縮為固定維度潛在表示# 自回歸解碼器(Autoregressive Decoder)self.decoder = GRU(input_dim, hidden_dim, num_layers)  # 定義解碼器,使用GRU網絡self.decoder_output = nn.Linear(hidden_dim, input_dim)  # 線性層,生成輸出數據# 判別器(Critic for WGAN)self.critic = GRU(hidden_dim, hidden_dim, num_layers)  # 定義判別器,使用GRU網絡self.critic_output = nn.Linear(hidden_dim, 1)  # 線性層,輸出Wasserstein距離def encode(self, x):# 編碼:將輸入序列壓縮為固定維度潛在表示h = self.encoder(x)  # 通過編碼器生成隱藏狀態h = h.reshape(h.size(0), -1)  # 展平隱藏狀態r = self.encoder_output(h)  # 壓縮為固定維度潛在表示return rdef decode(self, r):# 解碼:從潛在表示生成序列batch_size = r.size(0)  # 獲取批次大小# 初始化解碼器輸入x_hat = torch.zeros(batch_size, seq_len, self.input_dim).to(r.device)  # 初始化生成序列h = r.unsqueeze(1).repeat(1, seq_len, 1)  # 將潛在表示擴展到序列長度# 自回歸解碼for t in range(seq_len):h_t = self.decoder(x_hat[:, :t+1, :])[:, -1, :]  # 獲取當前時間步的隱藏狀態x_hat[:, t, :] = self.decoder_output(h_t)  # 生成當前時間步的輸出return x_hatdef forward(self, x):# 前向傳播:編碼+解碼r = self.encode(x)  # 編碼x_tilde = self.decode(r)  # 解碼return r, x_tilde# 初始化RTSGAN模型
rtsgan = RTSGAN(input_dim, hidden_dim, num_layers)# 定義優化器
optimizer_E = optim.Adam(list(rtsgan.encoder.parameters()) + list(rtsgan.decoder.parameters()), lr=0.001)  # 編碼器和解碼器優化器
optimizer_G = optim.Adam(list(rtsgan.decoder.parameters()), lr=0.001)  # 生成器優化器
optimizer_C = optim.Adam(list(rtsgan.critic.parameters()) + list(rtsgan.critic_output.parameters()), lr=0.001)  # 判別器優化器# WGAN-GP的梯度懲罰
def gradient_penalty(critic, real, fake):alpha = torch.rand(real.size(0), 1, 1).to(real.device)  # 隨機插值系數interpolates = alpha * real + (1 - alpha) * fake  # 插值數據interpolates.requires_grad_(True)  # 允許計算梯度critic_inter = critic(interpolates)  # 判別插值數據gradients = torch.autograd.grad(outputs=critic_inter, inputs=interpolates,grad_outputs=torch.ones_like(critic_inter),create_graph=True, retain_graph=True)[0]  # 計算梯度gradients = gradients.view(gradients.size(0), -1)  # 展平梯度gradient_norm = gradients.norm(2, dim=1)  # 計算梯度范數gp = ((gradient_norm - 1) ** 2).mean()  # 計算梯度懲罰return gp# 訓練RTSGAN
lambda_gp = 10  # 梯度懲罰系數,參考論文
for epoch in range(num_epochs):for i in range(0, n_samples, batch_size):# 獲取批次數據batch_data = data[i:i + batch_size].cuda() if torch.cuda.is_available() else data[i:i + batch_size]# 訓練編碼器和解碼器(自編碼損失)optimizer_E.zero_grad()  # 清零梯度r, x_tilde = rtsgan(batch_data)  # 前向傳播,得到潛在表示和重建數據recon_loss = mse_loss(x_tilde, batch_data)  # 計算自編碼損失recon_loss.backward()  # 反向傳播optimizer_E.step()  # 更新參數# 訓練判別器(WGAN-GP)optimizer_C.zero_grad()  # 清零梯度real_r = rtsgan.encode(batch_data)  # 編碼真實數據real_h = rtsgan.critic(real_r.unsqueeze(1).repeat(1, seq_len, 1))  # 判別真實數據real_score = rtsgan.critic_output(real_h)  # 計算真實數據得分z = torch.randn(batch_size, hidden_dim)  # 生成隨機噪聲fake_x = rtsgan.decode(z)  # 生成合成數據fake_r = rtsgan.encode(fake_x)  # 編碼合成數據fake_h = rtsgan.critic(fake_r.unsqueeze(1).repeat(1, seq_len, 1))  # 判別合成數據fake_score = rtsgan.critic_output(fake_h)  # 計算合成數據得分gp = gradient_penalty(rtsgan.critic, real_r.unsqueeze(1).repeat(1, seq_len, 1), fake_r.unsqueeze(1).repeat(1, seq_len, 1))  # 計算梯度懲罰c_loss = fake_score.mean() - real_score.mean() + lambda_gp * gp  # WGAN-GP損失c_loss.backward()  # 反向傳播optimizer_C.step()  # 更新參數# 訓練生成器optimizer_G.zero_grad()  # 清零梯度z = torch.randn(batch_size, hidden_dim)  # 生成隨機噪聲fake_x = rtsgan.decode(z)  # 生成合成數據fake_r = rtsgan.encode(fake_x)  # 編碼合成數據fake_h = rtsgan.critic(fake_r.unsqueeze(1).repeat(1, seq_len, 1))  # 判別合成數據fake_score = rtsgan.critic_output(fake_h)  # 計算合成數據得分g_loss = -fake_score.mean()  # 生成器損失(WGAN)g_loss.backward()  # 反向傳播optimizer_G.step()  # 更新參數# 打印損失if (epoch + 1) % 10 == 0:print(f"Epoch {epoch+1}/{num_epochs}, Recon Loss: {recon_loss.item():.4f}, C Loss: {c_loss.item():.4f}, G Loss: {g_loss.item():.4f}")# 生成合成數據
with torch.no_grad():z = torch.randn(n_samples, hidden_dim)  # 生成隨機噪聲synthetic_data_rtsgan = rtsgan.decode(z)  # 生成合成數據
synthetic_data_rtsgan = synthetic_data_rtsgan.cpu().numpy()  # 轉換為numpy數組# 準備監督學習數據
def prepare_supervised_data(data, window_size):inputs, targets = [], []  # 初始化輸入和目標列表data = data.reshape(data.shape[0], -1)  # 展平數據,形狀為[n_samples, seq_len * n_features]for i in range(len(data) - window_size):inputs.append(data[i:i + window_size].flatten())  # 提取輸入窗口targets.append(data[i + window_size, -1])  # 提取目標(最后一個特征)return np.array(inputs), np.array(targets)# 準備真實數據和合成數據
window_size = 3  # 窗口大小,參考圖4
real_inputs, real_targets = prepare_supervised_data(data.numpy(), window_size)  # 準備真實數據
syn_inputs_timegan, syn_targets_timegan = prepare_supervised_data(synthetic_data_timegan, window_size)  # 準備TimeGAN合成數據
syn_inputs_rtsgan, syn_targets_rtsgan = prepare_supervised_data(synthetic_data_rtsgan, window_size)  # 準備RTSGAN合成數據# 混合真實和合成數據
mixed_inputs = np.concatenate([real_inputs, syn_inputs_rtsgan], axis=0)  # 混合真實和RTSGAN數據
mixed_targets = np.concatenate([real_targets, syn_targets_rtsgan], axis=0)  # 混合目標# 訓練GBRT模型
gbrt = GradientBoostingRegressor(n_estimators=100, random_state=42)  # 初始化GBRT模型
gbrt.fit(mixed_inputs, mixed_targets)  # 訓練GBRT模型
gbrt_pred = gbrt.predict(real_inputs)  # 使用真實數據預測
gbrt_mse = mean_squared_error(real_targets, gbrt_pred)  # 計算均方誤差
print(f"GBRT MSE: {gbrt_mse:.4f}")  # 打印均方誤差# 定義LSTM模型
class LSTM(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers):super(LSTM, self).__init__()self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)  # 定義LSTM層self.fc = nn.Linear(hidden_dim, 1)  # 定義全連接層def forward(self, x):h0 = torch.zeros(num_layers, x.size(0), hidden_dim).to(x.device)  # 初始化隱藏狀態c0 = torch.zeros(num_layers, x.size(0), hidden_dim).to(x.device)  # 初始化單元狀態out, _ = self.lstm(x, (h0, c0))  # 前向傳播out = self.fc(out[:, -1, :])  # 取最后一個時間步輸出return out# 初始化LSTM模型
lstm = LSTM(input_dim=n_features * (window_size - 1), hidden_dim=hidden_dim, num_layers=num_layers)# 定義優化器和損失函數
optimizer = optim.Adam(lstm.parameters(), lr=0.001)  # 定義優化器
criterion = nn.MSELoss()  # 定義損失函數# 準備LSTM輸入數據
lstm_inputs = torch.FloatTensor(mixed_inputs.reshape(-1, window_size - 1, n_features))  # 重塑輸入數據
lstm_targets = torch.FloatTensor(mixed_targets)  # 轉換目標數據# 訓練LSTM模型
for epoch in range(num_epochs):optimizer.zero_grad()  # 清零梯度outputs = lstm(lstm_inputs)  # 前向傳播loss = criterion(outputs.squeeze(), lstm_targets)  # 計算損失loss.backward()  # 反向傳播optimizer.step()  # 更新參數if (epoch + 1) % 10 == 0:print(f"LSTM Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")  # 打印損失# 使用LSTM預測
with torch.no_grad():lstm_pred = lstm(torch.FloatTensor(real_inputs.reshape(-1, window_size - 1, n_features)))  # 預測
lstm_pred = lstm_pred.numpy()  # 轉換為numpy數組
lstm_mse = mean_squared_error(real_targets, lstm_pred)  # 計算均方誤差
print(f"LSTM MSE: {lstm_mse:.4f}")  # 打印均方誤差# 可視化預測結果
plt.figure(figsize=(10, 5))  # 設置畫布大小
plt.plot(real_targets[:100], label="True")  # 繪制真實值
plt.plot(gbrt_pred[:100], label="GBRT Pred")  # 繪制GBRT預測值
plt.plot(lstm_pred[:100], label="LSTM Pred")  # 繪制LSTM預測值
plt.legend()  # 顯示圖例
plt.title("Flood Forecasting Results")  # 設置標題
plt.show()  # 顯示圖像

結論

機器學習對洪水預報有著重要的作用,但極端洪水事件稀少、現場數據不足導致模型的性能并沒有充分發揮出來。TimeGAN生成的數據能夠保留洪水場景的復雜時空相關性,引入合成數據集也有助于提升機器學習預測精度。不過對LSTM的改進有限,TimeGAN對Transformer等深度學習模型的有效性也有待進一步探究。

不足以及展望

作者提出了本論文中的不足,作者認為TimeGAN對如LSTM等深度學習模型的改進效果有限,沒有發揮出實際的作用。未建立TimeGAN損失函數與下游預報性能的關聯性,缺乏了對用戶自定義參數敏感性分析。
關于對論文未來的展望,作者提出后續實驗可建立TimeGAN損失函數與下游預報性能的直接聯系。還需要考慮將天氣信息作為洪水預報模型的重要輸入變量,以提供長周期預報。此外,需要進一步探索生成的洪水場景在洪水風險管理中的各種應用。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/72599.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/72599.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/72599.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

linux固定IP并解決虛擬機無法ping其他電腦問題

linux固定IP并解決虛擬機無法ping其他電腦問題 1.找到網卡文件 vim /etc/sysconfig/network-scripts/ifcfg-ens33 2.編輯文件信息 BOOTPROTO 這個dhcp改為static#添加以下內容IPADDR<你的IP地址>NETMASK<子網掩碼>&#xff0c;例如255.255.255.0。GATEWAY<網…

Spring實戰spring-ai運行

目錄 1. 配置 2 .搭建項目 3. 查看對應依賴 3.1 OpenAI 依賴 3.2 配置 OpenAI API 密鑰 application.properties application.yml 4. openai實戰 5. 運行和測試 6. 高級配置 示例&#xff1a;配置模型和參數 解釋&#xff1a; 7. 處理異常和錯誤 示例&#xff1a;…

docker:配置 Docker 鏡像加速器

1 鏡像加速器介紹 默認情況下&#xff0c;將來從docker hub&#xff08;https://hub.docker.com/&#xff09;上下載docker鏡像&#xff0c;太慢。一般都會配置鏡像加速器&#xff1a; USTC&#xff1a;中科大鏡像加速器&#xff08;https://docker.mirrors.ustc.edu.cn&…

[內網安全] Windows 本地認證 — NTLM 哈希和 LM 哈希

關注這個專欄的其他相關筆記&#xff1a;[內網安全] 內網滲透 - 學習手冊-CSDN博客 0x01&#xff1a;SAM 文件 & Windows 本地認證流程 0x0101&#xff1a;SAM 文件簡介 Windows 本地賬戶的登錄密碼是存儲在系統本地的 SAM 文件中的&#xff0c;在登錄 Windows 的時候&am…

算法-圖-dijkstra 最短路徑

理論知識 dijkstra三部曲 樸素版dijkstra 模擬過程 堆優化版dijksra 經典模版例題 Dijkstra求最短路 I 參加科學大會&#xff08;第六期模擬筆試&#xff09;--模版題 網絡延遲 ref 理論知識 最短路是圖論中的經典問題即&#xff1a;給出一個有向圖&#xff0c;一…

Qt添加MySql數據庫驅動

文章目錄 一. 安裝MySql二.編譯mysql動態鏈接庫 Qt版本&#xff1a;5.14.2 MySql版本&#xff1a;8.0.41 一. 安裝MySql 參考這里進行安裝&#xff1a;https://blog.csdn.net/qq_30150579/article/details/146042922 將mysql安裝目錄里的bin&#xff0c;include和lib拷貝出來…

淺論數據庫聚合:合理使用LambdaQueryWrapper和XML

提示&#xff1a;文章寫完后&#xff0c;目錄可以自動生成&#xff0c;如何生成可參考右邊的幫助文檔 文章目錄 前言一、數據庫聚合替代內存計算&#xff08;關鍵優化&#xff09;二、批量處理優化四、區域特殊處理解耦五、防御性編程增強 前言 技術認知點&#xff1a;使用 XM…

Ubuntu 22.04安裝NVIDIA A30顯卡驅動

一、安裝前準備 1.禁用Nouveau驅動 Ubuntu默認使用開源Nouveau驅動&#xff0c;需要手動禁用&#xff1a; vim /etc/modprobe.d/blacklist-nouveau.conf # 添加以下內容&#xff1a; blacklist nouveau options nouveau modeset0 # 更新內核并重啟&#xff1a; update-initr…

Docker Desktop 4.38 安裝與配置全流程指南(Windows平臺)

一、軟件定位與特性 Docker Desktop 是容器化應用開發與部署的一體化工具&#xff0c;支持在本地環境創建、管理和運行Docker容器。4.38版本新增GPU加速支持、WSL 2性能優化和Kubernetes 1.28集群管理功能&#xff0c;適用于微服務開發、CI/CD流水線搭建等場景。 二、安裝環境…

音視頻入門基礎:RTP專題(15)——FFmpeg源碼中,獲取RTP的視頻信息的實現

一、引言 通過FFmpeg命令可以獲取到SDP文件描述的RTP流的視頻壓縮編碼格式、色彩格式&#xff08;像素格式&#xff09;、分辨率、幀率信息&#xff1a; ffmpeg -protocol_whitelist "file,rtp,udp" -i XXX.sdp 本文以H.264為例講述FFmpeg到底是從哪個地方獲取到這…

深度學習---卷積神經網絡

一、卷積尺寸計算公式 二、池化 池化分為最大池化和平均池化 最常用的就是最大池化&#xff0c;可以認為最大池化不需要引入計算&#xff0c;而平均池化需要引出計算&#xff08;計算平均數&#xff09; 每種池化還分為Pooling和AdaptiveAvgPool Pooling(2)就是每2*2個格子…

netty中Future和ChannelHandler

netty中的Future&#xff0c;繼承自 jdk中的Future&#xff0c;&#xff0c; jdk中的Future&#xff0c;很垃圾&#xff0c;只能同步阻塞獲取結果&#xff0c;&#xff0c;&#xff0c; netty中的Future進行了升級&#xff0c;&#xff0c;可以addListener()異步獲取結果&…

java 初學知識點總結

自己總結著玩 1.基本框架 public class HelloWorld{ public static void main(String[] args){ }//類名用大寫字母開頭 } 2.輸入&#xff1a; (1)Scanner:可讀取各種類型&#xff0c;字符串相當于cin>>; Scanner anew Scanner(System.in); Scan…

質量屬性場景描述

為了精確描述軟件系統的質量屬性&#xff0c;通常采用質量屬性場景&#xff08;Quality Attribute Scenario&#xff09;作為描述質量屬性的手段。質量屬性場景是一個具體的質量屬性需求&#xff0c;使利益相關者與系統的交互的簡短陳述。 質量屬性場景是一種用于描述系統如何…

數據可攜帶權的多重價值與實踐思考

文章目錄 前言一、數據可攜帶權的提出與立法二、數據可攜帶權的多重價值1、推動數據要素市場化配置2、促進市場競爭與創新3、強化個人數據權益 三、數據可攜帶權的實踐挑戰1、數據安全與隱私保護面臨風險2、接口差異導致數據遷移成本高昂3、可攜帶的數據范圍尚存爭議 數據可攜帶…

藍橋每日打卡--分考場

#藍橋#JAVA#分考場 題目描述 n個人參加某項特殊考試。 為了公平&#xff0c;要求任何兩個認識的人不能分在同一個考場。 求是少需要分幾個考場才能滿足條件。 輸入描述 輸入格式&#xff1a; 第一行&#xff0c;一個整數n(1≤n≤100)&#xff0c;表示參加考試的人數。 …

RMAN備份bug-審計日志暴漲(select action from gv$session)

問題概述 /oracle 文件系統使用率過大&#xff0c;經過檢查是審計日志過大,/oracle 目錄 197G 審計日志占用70G&#xff0c;每6個小時產生大量審計日志&#xff0c;日志內容全是select action from gv$session &#xff0c;猜測可能跟備份有關&#xff0c; $>df -h /oracle…

在Blender中給SP分紋理組

在Blender中怎么分SP的紋理組/紋理集 其實紋理組就是材質 把同一組的材質分給同一組的模型 導入到sp里面自然就是同一個紋理組 把模型導入SP之后 就自動分好了

Nuxt:Nuxt3框架中onBeforeMount函數 和onBeforeRouteUpdate函數區別介紹 【超詳細!】

提示&#xff1a;在 Nuxt3 中&#xff0c;onBeforeMount 和 onBeforeRouteUpdate 是兩個不同場景下使用的鉤子函數&#xff0c;分別對應 Vue 組件生命周期 和 路由導航守衛。以下是它們的詳細解釋和對比&#xff1a; 文章目錄 一、onBeforeMount&#xff08;Vue 生命周期鉤子&a…

華為 Open Gauss 數據庫在 Spring Boot 中使用 Flyway

db-migration&#xff1a;Flyway、Liquibase 擴展支持達夢&#xff08;DM&#xff09;、南大通用&#xff08;GBase 8s&#xff09;、OpenGauss 等國產數據庫。部分數據庫直接支持 Flowable 工作流。 開源代碼倉庫 Github&#xff1a;https://github.com/mengweijin/db-migrat…