時序預測 | Pytorch實現CNN-KAN電力負荷時間序列預測模型

預測效果

在這里插入圖片描述

代碼功能

該代碼實現了一個結合卷積神經網絡(CNN)和Kolmogorov–Arnold網絡(KAN)的混合模型(CNN-KAN),用于時間序列預測任務。核心功能包括:

  1. 數據加載與預處理:加載標準化后的訓練集和測試集(時間序列數據)。
  2. 模型構建
    • CNN部分:提取時間序列的局部特征(使用1D卷積層和池化層)。
    • KAN部分:替代全連接層,通過樣條基函數增強非線性擬合能力,提高預測精度。
  3. 模型訓練與評估:使用MSE損失和Adam優化器訓練模型,保存最佳模型參數,并在測試集上計算評估指標(MSE、RMSE、MAE、R2)。
  4. 結果可視化:繪制訓練/測試損失曲線,并反歸一化預測結果。

算法步驟

  1. 數據加載

    • 使用joblib加載預處理后的訓練集(train_set, train_label)和測試集(test_set, test_label)。
    • 封裝為DataLoader(批量大小=64)。
  2. 模型定義

    • KANLinear
      • 基礎線性變換 + 樣條基函數(B-splines)的非線性變換。
      • 支持動態網格更新和正則化損失計算。
    • CNN1DKANModel
      • 卷積塊:多個Conv1d + ReLU + MaxPool1d層(參考VGG架構)。
      • 自適應平均池化:替代全連接層,減少參數量。
      • KAN輸出層:生成最終預測結果。
  3. 模型訓練

    • 損失函數:均方誤差(MSELoss)。
    • 優化器:Adam(學習率=0.0003)。
    • 訓練循環
      • 前向傳播 → 計算損失 → 反向傳播 → 參數更新。
      • 記錄每個epoch的訓練/測試MSE,保存最佳模型(最低測試MSE)。
  4. 模型評估

    • 加載最佳模型進行預測。
    • 計算指標:R2(模型擬合優度)、MSERMSEMAE
    • 反歸一化預測結果(使用預訓練的StandardScaler)。
  5. 可視化

    • 繪制訓練/測試MSE隨epoch的變化曲線。
    • 輸出評估指標和反歸一化后的結果。

技術路線

  1. 框架:PyTorch(模型構建、訓練、評估)。
  2. 數據預處理:使用StandardScaler標準化數據(通過joblib保存/加載)。
  3. 模型架構
    • 特征提取:CNN(1D卷積層)捕獲時間序列局部模式。
    • 非線性映射:KAN層替代傳統全連接層,通過樣條函數靈活擬合復雜關系。
  4. 評估指標sklearn計算R2MSE等。
  5. 可視化matplotlib繪制損失曲線。

關鍵參數設定

參數說明
batch_size64數據批量大小
epochs50訓練輪數
learn_rate0.0003Adam優化器學習率
conv_archs((2, 32), (2, 64))CNN層配置(卷積層數×通道數)
grid_size5KAN樣條網格大小
spline_order3樣條多項式階數
output_dim1預測輸出維度(回歸任務)

運行環境

  • Python庫
    torch, joblib, numpy, pandas, sklearn, matplotlib
    
  • 硬件:支持CUDA的GPU(優先)或CPU(自動切換)。
  • 數據依賴
    • 預處理的訓練/測試集文件(train_set, train_label等)。
    • 預訓練的StandardScalerscaler文件)。

應用場景

  1. 時間序列預測
    • 如股票價格、氣象數據、電力負荷等序列數據的未來值預測。
  2. 高非線性關系建模
    • KAN層通過樣條基函數靈活擬合復雜非線性模式,優于傳統全連接層。
  3. 輕量化模型需求
    • 自適應池化替代全連接層,減少參數量(模型總參數量:22,432)。
  4. 研究驗證
    • 探索CNN與KAN結合的混合架構在預測任務中的有效性(最終R2=0.995,擬合優度高)。

補充說明

  • 創新點:KAN作為輸出層,通過動態網格更新和正則化約束(L1 + 熵),增強模型表達能力。
  • 性能:50個epoch后測試集MSE=0.2627(反歸一化后MSE=0.0041),預測精度高。
  • 擴展性:可通過調整卷積架構、KAN參數適配不同時間序列長度和復雜度。

完整代碼

  • 完整代碼訂閱專欄獲取
# 模型預測
# 模型 測試集 驗證  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 模型加載
model = torch.load('best_model_cnn_kan.pt')
model = model.to(device)# 預測數據
original_data = []
pre_data = []
with torch.no_grad():for data, label in test_loader:origin_lable = label.tolist()original_data += origin_lablemodel.eval()  # 將模型設置為評估模式data, label = data.to(device), label.to(device)# 預測test_pred = model(data)  # 對測試集進行預測test_pred = test_pred.tolist()pre_data += test_pred
[8]
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score# 模型分數
score = r2_score(original_data, pre_data)
print('*'*50)
print('模型分數--R^2:',score)print('*'*50)
# 測試集上的預測誤差
test_mse = mean_squared_error(original_data, pre_data)
test_rmse = np.sqrt(test_mse)
test_mae = mean_absolute_error(original_data, pre_data)
print('測試數據集上的均方誤差--MSE: ',test_mse)
print('測試數據集上的均方根誤差--RMSE: ',test_rmse)
print('測試數據集上的平均絕對誤差--MAE: ',test_mae)
**************************************************
模型分數--R^2: 0.9954956071920047
**************************************************
測試數據集上的均方誤差--MSE:  0.004104453060426307
測試數據集上的均方根誤差--RMSE:  0.06406600549766082
測試數據集上的平均絕對誤差--MAE:  0.047805079976603375[19]
from sklearn.preprocessing import StandardScaler, MinMaxScaler# 將列表轉換為 NumPy 數組
original_data = np.array(original_data)
pre_data = np.array(pre_data)# 反歸一化處理
# 使用相同的均值和標準差對預測結果進行反歸一化處理
# 反標準化
scaler  = load('scaler')
original_data = scaler.inverse_transform(original_data)
pre_data = scaler.inverse_transform(pre_data)
[20]
# 可視化結果
plt.figure(figsize=(12, 6), dpi=100)
plt.plot(original_data, label='原始值',color='orange')  # 真實值
plt.plot(pre_data, label='CNN-KAN預測值',color='green')  # 預測值
plt.legend()
plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/914377.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/914377.shtml
英文地址,請注明出處:http://en.pswp.cn/news/914377.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

UI前端與數字孿生結合實踐探索:智慧物流的倉儲優化與管理系統

hello寶子們...我們是艾斯視覺擅長ui設計和前端數字孿生、大數據、三維建模、三維動畫10年經驗!希望我的分享能幫助到您!如需幫助可以評論關注私信我們一起探討!致敬感謝感恩!一、引言:倉儲管理的 “數字孿生革命”傳統物流倉儲正面臨 “效率瓶頸、可視化差、響應滯…

【Android】在平板上實現Rs485的數據通訊

前言 在工業控制領域,Android 設備通過 RS485 接口與 PLC(可編程邏輯控制器)通信是一種常見的技術方案。最近在實現一個項目需要和plc使用485進行通訊,記錄下實現的方式。 我這邊使用的從平的Android平板,從平里面已經…

MySQL技術筆記-備份與恢復完全指南

目錄 前言 一、備份概述 (一)備份方式 (二)備份策略 二、物理備份及恢復 (一)備份操作 (二)恢復操作 三、邏輯備份及恢復 (一)邏輯備份 &#xff0…

SpringBoot或OpenFeign中 Jackson 配置參數名蛇形、小駝峰、大駝峰、自定義命名

SpringBoot或OpenFeign中 Jackson 配置參數名蛇形、小駝峰、大駝峰、自定義命名 前言 在調用外部接口時,對方給出的接口文檔中,入參參數名一會大寫加下劃線,一會又是駝峰命名。 示例如下: {"MOF_DIV_CODE": "xx…

uni-app 途徑站點組件開發與實現分享

在移動應用開發中,涉及到出行、物流等場景時,途徑站點的展示是一個常見的需求。本文將為大家分享一個基于 uni-app 開發的途徑站點組件,該組件能夠清晰展示路線中的各個站點信息,包括站點名稱、到達時間、是否已到達等狀態&#x…

kotlin中集合的用法

從一個實際應用看起以下kotlin中代碼語法正確嗎 var testBeanAIP0200()var testList:List<AIP0200> ArrayList()testList.add(testBean)這段Kotlin代碼存在語法錯誤&#xff0c;主要問題在于&#xff1a;List<AIP0200> 是Kotlin中的不可變集合接口&#xff0c;不能…

深入理解 Java Map 與 Set

文章目錄前言1. 搜索樹1.1 什么是搜索樹1.2 查找1.3 插入1.4 刪除情況一&#xff1a;cur 沒有子節點&#xff08;即為葉子節點&#xff09;情況二&#xff1a;cur 只有一個子節點&#xff08;只有左子樹或右子樹&#xff09;情況三&#xff1a;cur 有兩個子節點&#xff08;左右…

excel如何只保留前幾行

方法一&#xff1a;手動刪除多余行 選中你想保留的最后一行的下一行&#xff08;比如你只保留前10行&#xff0c;那選第11行&#xff09;。按住 Shift Ctrl ↓&#xff08;Windows&#xff09;或 Shift Command ↓&#xff08;Mac&#xff09;&#xff0c;選中從第11行到最…

實時連接,精準監控:風丘科技數據遠程顯示方案提升試驗車隊管理效率

風丘科技推出的數據遠程實時顯示方案更好地滿足了客戶對于試驗車隊遠程實時監控的需求&#xff0c;并真正實現了試驗車隊的遠程管理。隨著新的數據記錄儀軟件IPEmotion RT和相應的跨平臺顯示解決方案的引入&#xff0c;讓我們的客戶端不僅可在線訪問記錄器系統狀態&#xff0c;…

灰盒級SOA測試工具Parasoft SOAtest重新定義端到端測試

還在為脆弱的測試環境、強外部依賴和低效的測試復用拖慢交付而頭疼&#xff1f;尤其在銀行、醫療、制造等關鍵領域&#xff0c;傳統的端到端測試常因環境不穩、接口難模擬、用例難共享而舉步維艱。 灰盒級SOA測試工具Parasoft SOAtest以可視化編排簡化復雜測試流程&#xff0c…

OKHttp 核心知識點詳解

OKHttp 核心知識點詳解 一、基本概念與架構 1. OKHttp 簡介 類型&#xff1a;高效的HTTP客戶端特點&#xff1a; 支持HTTP/2和SPDY&#xff08;多路復用&#xff09;連接池減少請求延遲透明的GZIP壓縮響應緩存自動恢復網絡故障2. 核心組件組件功能OkHttpClient客戶端入口&#…

從“被動巡檢”到“主動預警”:塔能物聯運維平臺重構路燈管理模式

從以往的‘被動巡檢’轉變至如今的‘主動預警’&#xff0c;塔能物聯運維平臺對路燈管理模式展開了重新構建。城市路燈屬于極為重要的市政基礎設施范疇&#xff0c;它的實際運行狀態和市民出行安全以及城市形象有著直接且緊密的關聯。不過呢&#xff0c;傳統的路燈管理模式當下…

10. 常見的 http 狀態碼有哪些

總結 1xx: 正在處理2xx: 成功3xx: 重定向&#xff0c;302 重定向&#xff0c;304 協商緩存4xx: 客戶端錯誤&#xff0c;401 未登錄&#xff0c;403 沒權限&#xff0c;404 資源不存在5xx: 服務器錯誤常見的 HTTP 狀態碼詳解 HTTP 狀態碼&#xff08;HTTP Status Code&#xff0…

springBoot對接第三方系統

yml文件 yun:ip: port: username: password: controller package com.ruoyi.web.controller.materials;import com.ruoyi.common.core.controller.BaseController; import com.ruoyi.common.core.domain.AjaxResult; import com.ruoyi.materials.service.IYunService; import o…

【PTA數據結構 | C語言版】車廂重排

本專欄持續輸出數據結構題目集&#xff0c;歡迎訂閱。 文章目錄題目代碼題目 一列掛有 n 節車廂&#xff08;編號從 1 到 n&#xff09;的貨運列車途徑 n 個車站&#xff0c;計劃在行車途中將各節車廂停放在不同的車站。假設 n 個車站的編號從 1 到 n&#xff0c;貨運列車按照…

量子計算能為我們做什么?

科技公司正斥資數十億美元投入量子計算領域&#xff0c;盡管這項技術距離實際應用還有數年時間。那么&#xff0c;未來的量子計算機將用于哪些方面&#xff1f;為何眾多專家堅信它們會帶來顛覆性變革&#xff1f; 自 20 世紀 80 年代起&#xff0c;打造一臺利用量子力學獨特性質…

BKD 樹(Block KD-Tree)Lucene

BKD 樹&#xff08;Block KD-Tree&#xff09;是 Lucene 用來存儲和快速查詢 **多維數值型數據** 的一種磁盤友好型數據結構&#xff0c;可以把它想成&#xff1a;> **“把 KD-Tree 分塊壓縮后落到磁盤上&#xff0c;既能做磁盤順序讀&#xff0c;又能像內存 KD-Tree 一樣做…

【Mysql作業】

第一次作業要求1.首先打開Windows PowerShell2.連接到MYSQL服務器3.執行以下SQL語句&#xff1a;-- 創建數據庫 CREATE DATABASE mydb6_product;-- 使用數據庫 USE mydb6_product;-- 創建employees表 CREATE TABLE employees (id INT PRIMARY KEY,name VARCHAR(50) NOT NULL,ag…

(C++)STL:list認識與使用全解析

本篇基于https://cplusplus.com/reference/list/list/講解 認識 list是一個帶頭結點的雙向循環鏈表翻譯總結&#xff1a; 序列容器&#xff1a;list是一種序列容器&#xff0c;允許在序列的任何位置進行常數時間的插入和刪除操作。雙向迭代&#xff1a;list支持雙向迭代&#x…

Bash函數詳解

目錄**1. 基礎函數****2. 參數處理函數****3. 文件操作函數****4. 日志與錯誤處理****5. 實用工具函數****6. 高級函數技巧****7. 常用函數庫示例****總結&#xff1a;Bash 函數核心要點**1. 基礎函數 1.1 定義與調用 可以自定義函數名稱&#xff0c;例如將greet改為yana。?…