信號處理學習——文獻精讀與code復現之TFN——嵌入時頻變換的可解釋神經網絡(下)

書接上文:

信號處理學習——文獻精讀與code復現之TFN——嵌入時頻變換的可解釋神經網絡(上)-CSDN博客

接下來是重要的代碼復現!!!GitHub - ChenQian0618/TFN: this is the open code of paper entitled "TFN: An Interpretable Neural Network With Time Frequency Transform Embedded for Intelligent Fault Diagnosis".



一. 準備工作

因為我的論文中所使用的數據的樣本量是2048,而不是TFN文獻中的1024,所以有些地方需要調整一下。

先看TFN-main\Models\BackboneCNN.py中的代碼,查看是否需要調整。(不需要)

再看TFN-main\Models\TFconvlayer.py中的代碼,也同樣的(不需要)。

具體可去github上查看作者大大們的完整代碼。

模塊是否依賴輸入長度?原因
TFconv_* 類中的 forward? 不依賴它接受的輸入是 [B, C, L],不限制 L(你的2048沒問題)
weightforward()? 不依賴只與 kernel_sizesuperparams 有關,與你輸入的信號長度無關
AdaptiveMaxPool1d in CNN? 不依賴輸入長度自動調整為固定輸出維度,兼容任意長度
T = torch.arange(...)??但與 kernel_size 相關這個是內部卷積核構造,與輸入數據 2048 無關

二. 設置

參考文獻中的部分設置,分類損失函數采用交叉熵,訓練優化器選用Adam,動量參數設為0.9,初始學習率為0.001,總訓練周期為50次。總共重復10次平均實驗。

三.Code

1.數據劃分

這里用到的數據集是比CWRU要稍微難識別一些的變速軸承信號數據集,加拿大渥太華數據集。

import scipy.io as sio
import numpy as np
import random
import os# 定義基礎路徑
base_path = 'D:/0A_Gotoyourdream/00_BOSS_WHQ/A_Code/A_Data/'# 定義各類別對應的mat文件
file_mapping = {'H': 'H-B-1.mat','I': 'I-B-1.mat','B': 'B-B-1.mat','O': 'O-B-2.mat','C': 'C-B-2.mat'
}# 定義每個類別需要抽取的數量
sample_limit = {'H': 200,'I': 200,'B': 200,'O': 200,'C': 200
}# 保存最終數據
X_list = []
y_list = []# 固定參數
fs = 200000
window_size = 2048
step_size = int(fs * 0.015)  # 步長 0.015秒# 類別編碼
#label_mapping = {'H': 0, 'I': 1, 'B': 3, 'O': 2, 'C': 4}  # 注意和你之前保持一致label_mapping = {'H': 0, 'I': 1, 'O': 2, 'B': 3, 'C': 4}
inverse_mapping = {v: k for k, v in label_mapping.items()}
labels = [inverse_mapping[i] for i in range(len(inverse_mapping))]
# 再替換這些縮寫為全名
label_fullnames = {'H': 'Health','I': 'F_Inner','O': 'F_Outer','B': 'F_Ball','C': 'F_Combined'
}
labels = [label_fullnames[c] for c in labels]# 創建保存目錄(可選)
output_dir = os.path.join(base_path, "ClassBD-Processed_Samples")
os.makedirs(output_dir, exist_ok=True)# 遍歷每一類數據
for label_name, file_name in file_mapping.items():print(f"正在處理類別 {label_name}...")mat_path = os.path.join(base_path, file_name)dataset = sio.loadmat(mat_path)# 提取振動信號并去直流分量vib_data = np.array(dataset["Channel_1"].flatten().tolist()[:fs * 10])vib_data = vib_data - np.mean(vib_data)# 滑窗切分樣本vib_samples = []start = 0while start + window_size <= len(vib_data):sample = vib_data[start:start + window_size].astype(np.float32)  # 降低內存占用vib_samples.append(sample)start += step_sizevib_samples = np.array(vib_samples)print(f"共切分得到 {vib_samples.shape[0]} 個樣本")# 抽樣if vib_samples.shape[0] < sample_limit[label_name]:raise ValueError(f"類別 {label_name} 樣本不足(僅 {vib_samples.shape[0]}),無法抽取 {sample_limit[label_name]} 個")selected_indices = random.sample(range(vib_samples.shape[0]), sample_limit[label_name])selected_X = vib_samples[selected_indices]selected_y = np.full(sample_limit[label_name], label_mapping[label_name], dtype=np.int64)# 保存save_path_X = os.path.join(output_dir, f"X_{label_name}.mat")save_path_y = os.path.join(output_dir, f"y_{label_name}.mat")sio.savemat(save_path_X, {'X': selected_X})sio.savemat(save_path_y, {'y': selected_y})print(f"已保存類別 {label_name} 的數據:{save_path_X}, {save_path_y}")

2. 存儲為dataloder

import os
import scipy.io as sio
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader# ========== 1. 讀取四類數據 ==========
base_path = "D:/0A_Gotoyourdream/00_BOSS_WHQ/A_Code/A_Data/ClassBD-Processed_Samples"def load_data(label):X = sio.loadmat(os.path.join(base_path, f"X_{label}.mat"))["X"]y = sio.loadmat(os.path.join(base_path, f"y_{label}.mat"))["y"].flatten()return X.astype(np.float32), y.astype(np.int64)X_H, y_H = load_data("H")
X_I, y_I = load_data("I")
X_B, y_B = load_data("B")
X_O, y_O = load_data("O")
X_C, y_C = load_data("C")# ========== 2. 合并數據 + reshape ==========
X_all = np.concatenate([X_H, X_I, X_B, X_O, X_C], axis=0)
y_all = np.concatenate([y_H, y_I, y_B, y_O, y_C], axis=0)
X_all = X_all[:, np.newaxis, :] ?# (N, 1, 200000)# ========== 3. 劃分訓練/測試集 ==========
X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=0.4, stratify=y_all, random_state=42)# ========== 4. DataLoader ==========
train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

3. 定義模型及一些設置

需要注意的部分在代碼的注釋中有寫

from Models.TFN import TFN_STTF ?# 你也可以換成 TFN_Chirplet、TFN_Morlet
model = TFN_STTF(in_channels=1, out_channels=5, kernel_size=15) ?# out_channels = 類別數device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

4. 訓練及測試

# 開始訓練
for epoch in range(1, 51):model.train()running_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# TFN模型支持返回多個輸出(output, _, _)outputs = model(inputs)if isinstance(outputs, tuple):outputs = outputs[0]loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()#scheduler.step()train_acc = correct / total * 100# 測試集評估model.eval()correct_test = 0total_test = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)if isinstance(outputs, tuple):outputs = outputs[0]_, predicted = outputs.max(1)total_test += labels.size(0)correct_test += predicted.eq(labels).sum().item()test_acc = correct_test / total_test * 100print(f"Epoch {epoch:03d}: Loss={running_loss:.4f}, Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

四. 結果

代碼復現成功!!!

接著后面就是拿來做對比實驗啦~~~

(感恩大佬們提供github代碼!!!)

?

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/89398.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/89398.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/89398.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

線上故障排查:簽單合同提交報錯分析-對接e簽寶

在企業管理系統中&#xff0c;合同生成與簽署環節至關重要&#xff0c;尤其是在使用第三方平臺進行電子簽署時。本文將通過實際的報錯信息&#xff0c;分析如何進行線上故障排查&#xff0c;解決合同生成過程中出現的問題。 #### 1. 錯誤描述 在嘗試生成合同并提交至電子簽署…

知攻善防靶機 Linux easy溯源

知攻善防 【護網訓練-Linux】應急響應靶場-Easy溯源 小張是個剛入門的程序猿&#xff0c;在公司開發產品的時候突然被叫去應急&#xff0c;小張心想"早知道簡歷上不寫會應急了"&#xff0c;于是call了運維小王的電話&#xff0c;小王說"你面試的時候不是說會應急…

原神八分屏角色展示頁面(純前端html,學習交流)

原神八分屏角色展示頁面 - 一個精美的前端交互項目 項目簡介 這是一個基于原神游戲角色制作的八分屏展示頁面&#xff0c;采用純前端技術實現&#xff0c;包含了豐富的動畫效果、音頻交互和視覺設計。項目展示了一些熱門原神角色&#xff0c;每個角色都有獨立的介紹頁面和專屬…

華為認證二選一:物聯網 VS 人工智能,你的賽道在哪里?

一篇不講情懷只講干貨的科普指南 一、華為物聯網 & 人工智能到底在搞什么&#xff1f; 華為物聯網&#xff08;IoT&#xff09; 的核心是 “萬物互聯”。 通過傳感器、通信技術&#xff08;如NB-IoT/5G&#xff09;、云計算平臺&#xff08;如OceanConnect&#xff09;&…

CloudLens for PolarDB:解鎖數據庫性能優化與智能運維的終極指南

隨著企業數據規模的爆炸式增長,數據庫性能管理已成為技術團隊的關鍵挑戰。本文深入探討如何利用CloudLens for PolarDB實現高級監控、智能診斷和自動化運維,幫助您構建一個自我修復、高效運行的數據庫環境。 引言:數據庫監控的演進 在云原生時代,傳統的數據庫監控方式已不…

MySQL中TINYINT/INT/BIGINT的典型應用場景及實例

以下是MySQL中TINYINT/INT/BIGINT的典型應用場景及實例說明&#xff1a; 一、TINYINT&#xff08;1字節&#xff09; 1.狀態標識 -- 用戶激活狀態&#xff08;0未激活/1已激活&#xff09; ALTER TABLE users ADD is_active TINYINT(1) DEFAULT 0; 適用于布爾值存儲和狀態碼…

YOLOv13:最新的YOLO目標檢測算法

[2506.17733] YOLOv13: Real-Time Object Detection with Hypergraph-Enhanced Adaptive Visual Perception Github: https://github.com/iMoonLab/yolov13 YOLOv13&#xff1a;利用超圖增強型自適應視覺感知進行實時物體檢測 主要的創新點提出了HyperACE機制、FullPAD范式、輕…

【深入淺出:計算流體力學(CFD)基礎與核心原理--從NS方程到工業仿真實踐】

關鍵詞&#xff1a;#CFD、#Navier-Stokes方程、#有限體積法、#湍流模型、#網格收斂性、#工業仿真驗證 一、CFD是什么&#xff1f;為何重要&#xff1f; 計算流體力學&#xff08;Computational Fluid Dynamics, CFD&#xff09; 是通過數值方法求解流體流動控制方程&#xff0…

qt常用控件--04

文章目錄 qt常用控件labelLCD NumberProgressBar結語 很高興和大家見面&#xff0c;給生活加點impetus&#xff01;&#xff01;開啟今天的編程之路&#xff01;&#xff01; 今天我們進一步c11中常見的新增表達 作者&#xff1a;?( ‘ω’ )?260 我的專欄&#xff1a;qt&am…

Redmine:一款基于Web的開源項目管理軟件

Redmine 是一款基于 Ruby on Rails 框架開發的開源、跨平臺、基于 Web 的項目管理、問題跟蹤和文檔協作軟件。 Redmine 官方網站自身就是基于它構建的一個 Web 應用。 功能特性 Redmine 的主要特點和功能包括&#xff1a; 多項目管理&#xff1a; Redmine 可以同時管理多個項…

FPGA FMC 接口

1 FMC 介紹 FMC 接口即 FPGA Mezzanine Card 接口,中文名為 FPGA 中間層板卡接口。以下是對它的詳細介紹: 標準起源:2008 年 7 月,美國國家標準協會(ANSI)批準和發布了 VITA 57 FMC 標準。該標準由從 FPGA 供應商到最終用戶的公司聯盟開發,旨在為位于基板(載卡)上的 …

C++中std::atomic_bool詳解和實戰示例

std::atomic_bool 是 C 標準庫中提供的一種 原子類型&#xff0c;用于在多線程環境下對布爾值進行 線程安全的讀寫操作&#xff0c;避免使用 std::mutex 帶來的性能開銷。 1. 基本作用 在多線程環境中&#xff0c;多個線程同時訪問一個 bool 類型變量可能會出現 競態條件&…

深度學習之分類手寫數字的網絡

面臨的問題 定義神經?絡后&#xff0c;我們回到?寫識別上來。我們可以把識別?寫數字問題分成兩個?問題&#xff1a; 把包含許多數字的圖像分成?系列單獨的圖像&#xff0c;每個包含單個數字&#xff1b; 也就是把圖像 &#xff0c;分成6個單獨的圖像 分類單獨的數字 我們將…

nginx基本使用 linux(mac下的)

目錄結構 編譯后會有&#xff1a;conf html logs sbin 四個文件 &#xff08;其他兩個是之前下載的安裝包&#xff09; conf&#xff1a;配置文件html&#xff1a;頁面資源logs&#xff1a;日志sbin&#xff1a;啟動文件&#xff0c;nginx主程序 運行后多了文件&#xff1a;&l…

基于大眾點評的重慶火鍋在線評論數據挖掘分析(情感分析、主題分析、EDA探索性數據分析)

文章目錄 有需要本項目的代碼或文檔以及全部資源&#xff0c;或者部署調試可以私信博主項目介紹數據采集數據預處理EDA探索性數據分析關鍵詞提取算法情感分析LDA主題分析總結每文一語 有需要本項目的代碼或文檔以及全部資源&#xff0c;或者部署調試可以私信博主 項目介紹 本…

鴻蒙系統(HarmonyOS)應用開發之經典藍色風格登錄頁布局、圖文驗證碼

一、項目概述 本項目是一款基于鴻蒙 ArkTS&#xff08;ETS&#xff09;開發的用戶登錄頁面&#xff0c;集成了圖文驗證碼功能&#xff0c;旨在為應用提供安全、便捷的用戶身份驗證入口。項目采用現代化 UI 設計&#xff0c;兼顧用戶體驗與安全性&#xff0c;適用于多種需要用戶…

0.96寸OLED顯示屏 江協科技學習筆記(36個知識點)

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 32 33 34 35 36

Flutter SnackBar 控件詳細介紹

文章目錄 Flutter SnackBar 控件詳細介紹基本特性基本用法1. 顯示簡單 SnackBar2. 自定義持續時間 主要屬性高級用法1. 帶操作的 SnackBar2. 自定義樣式3. 浮動式 SnackBar SnackBarAction 屬性實際應用場景注意事項完整示例建議 Flutter SnackBar 控件詳細介紹 SnackBar 是 F…

【C++】頭文件的能力與禁忌

在C中&#xff0c;?頭文件&#xff08;.h/.hpp&#xff09;?? 的主要作用是聲明接口和共享代碼&#xff0c;但如果不規范使用&#xff0c;會導致編譯或鏈接錯誤。以下是詳細總結&#xff1a; 一、頭文件中可以做的事情 1.1 聲明 函數聲明&#xff08;無需inline&#xff…

騰訊 iOA 零信任產品:安全遠程訪問的革新者

在當今數字化時代&#xff0c;企業面臨著前所未有的挑戰與機遇。隨著遠程辦公、多分支運營以及云計算的廣泛應用&#xff0c;傳統的網絡安全架構逐漸暴露出諸多不足。騰訊 iOA 零信任產品憑借其創新的安全理念和強大的功能特性&#xff0c;為企業提供了一種全新的解決方案&…