Pytorch-Transformer軸承故障一維信號分類(三)

目錄

前言

1?數據集制作與加載

1.1 導入數據

第一步,導入十分類數據

第二步,讀取MAT文件驅動端數據

第三步,制作數據集

第四步,制作訓練集和標簽

1.2 數據加載,訓練數據、測試數據分組,數據分batch

2 Transformer分類模型和超參數選取

2.1 定義Transformer分類模型,采用Transformer架構中的編碼器:

2.2 定義模型參數

2.3 模型結構

3 Transformer模型訓練與評估

3.1?模型訓練

3.2 模型評估


往期精彩內容:

Python-凱斯西儲大學(CWRU)軸承數據解讀與分類處理

Python軸承故障診斷 (一)短時傅里葉變換STFT

Python軸承故障診斷 (二)連續小波變換CWT

Python軸承故障診斷 (三)經驗模態分解EMD

Python軸承故障診斷 (四)基于EMD-CNN的故障分類

Python軸承故障診斷 (五)基于EMD-LSTM的故障分類

Pytorch-LSTM軸承故障一維信號分類(一)

Pytorch-CNN軸承故障一維信號分類(二)

前言

本文基于凱斯西儲大學(CWRU)軸承數據,先經過數據預處理進行數據集的制作和加載,最后通過Pytorch實現Transformer模型對故障數據的分類,并介紹Transformer模型的超參數。凱斯西儲大學軸承數據的詳細介紹可以參考下文:

Python-凱斯西儲大學(CWRU)軸承數據解讀與分類處理

1?數據集制作與加載

1.1 導入數據

參考之前的文章,進行故障10分類的預處理,凱斯西儲大學軸承數據10分類數據集:

第一步,導入十分類數據

import numpy as np
import pandas as pd
from scipy.io import loadmatfile_names = ['0_0.mat','7_1.mat','7_2.mat','7_3.mat','14_1.mat','14_2.mat','14_3.mat','21_1.mat','21_2.mat','21_3.mat']for file in file_names:# 讀取MAT文件data = loadmat(f'matfiles\\{file}')print(list(data.keys()))

第二步,讀取MAT文件驅動端數據

# 采用驅動端數據
data_columns = ['X097_DE_time', 'X105_DE_time', 'X118_DE_time', 'X130_DE_time', 'X169_DE_time','X185_DE_time','X197_DE_time','X209_DE_time','X222_DE_time','X234_DE_time']
columns_name = ['de_normal','de_7_inner','de_7_ball','de_7_outer','de_14_inner','de_14_ball','de_14_outer','de_21_inner','de_21_ball','de_21_outer']
data_12k_10c = pd.DataFrame()
for index in range(10):# 讀取MAT文件data = loadmat(f'matfiles\\{file_names[index]}')dataList = data[data_columns[index]].reshape(-1)data_12k_10c[columns_name[index]] = dataList[:119808]  # 121048  min: 121265
print(data_12k_10c.shape)
data_12k_10c

第三步,制作數據集

train_set、val_set、test_set 均為按照7:2:1劃分訓練集、驗證集、測試集,最后保存數據

第四步,制作訓練集和標簽

# 制作數據集和標簽
import torch# 這些轉換是為了將數據和標簽從Pandas數據結構轉換為PyTorch可以處理的張量,
# 以便在神經網絡中進行訓練和預測。def make_data_labels(dataframe):'''參數 dataframe: 數據框返回 x_data: 數據集     torch.tensory_label: 對應標簽值  torch.tensor'''# 信號值x_data = dataframe.iloc[:,0:-1]# 標簽值y_label = dataframe.iloc[:,-1]x_data = torch.tensor(x_data.values).float()y_label = torch.tensor(y_label.values.astype('int64')) # 指定了這些張量的數據類型為64位整數,通常用于分類任務的類別標簽return x_data, y_label# 加載數據
train_set = load('train_set')
val_set = load('val_set')
test_set = load('test_set')# 制作標簽
train_xdata, train_ylabel = make_data_labels(train_set)
val_xdata, val_ylabel = make_data_labels(val_set)
test_xdata, test_ylabel = make_data_labels(test_set)
# 保存數據
dump(train_xdata, 'trainX_1024_10c')
dump(val_xdata, 'valX_1024_10c')
dump(test_xdata, 'testX_1024_10c')
dump(train_ylabel, 'trainY_1024_10c')
dump(val_ylabel, 'valY_1024_10c')
dump(test_ylabel, 'testY_1024_10c')

1.2 數據加載,訓練數據、測試數據分組,數據分batch

import torch
from joblib import dump, load
import torch.utils.data as Data
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# 參數與配置
torch.manual_seed(100)  # 設置隨機種子,以使實驗結果具有可重復性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU訓練# 加載數據集
def dataloader(batch_size, workers=2):# 訓練集train_xdata = load('trainX_1024_10c')train_ylabel = load('trainY_1024_10c')# 驗證集val_xdata = load('valX_1024_10c')val_ylabel = load('valY_1024_10c')# 測試集test_xdata = load('testX_1024_10c')test_ylabel = load('testY_1024_10c')# 加載數據train_loader = Data.DataLoader(dataset=Data.TensorDataset(train_xdata, train_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)val_loader = Data.DataLoader(dataset=Data.TensorDataset(val_xdata, val_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)test_loader = Data.DataLoader(dataset=Data.TensorDataset(test_xdata, test_ylabel),batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)return train_loader, val_loader, test_loaderbatch_size = 32
# 加載數據
train_loader, val_loader, test_loader = dataloader(batch_size)

2 Transformer分類模型和超參數選取

2.1 定義Transformer分類模型,采用Transformer架構中的編碼器:

注意:輸入數據進行了堆疊 ,把一個1*1024 的序列 進行劃分堆疊成形狀為 32 * 32, 就使輸入序列的長度降下來了

2.2 定義模型參數

# 模型參數
input_dim = 32 # 輸入維度
hidden_dim = 512  # 注意力維度
output_dim  = 10  # 輸出維度
num_layers = 4   # 編碼器層數
num_heads = 8    # 多頭注意力頭數
batch_size = 32
# 模型
model = TransformerModel(input_dim, output_dim, hidden_dim, num_layers, num_heads, batch_size)  
model = model.to(device)
loss_function = nn.CrossEntropyLoss(reduction='sum')  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)  # 優化器

2.3 模型結構

3 Transformer模型訓練與評估

3.1?模型訓練

訓練結果

100個epoch,準確率將近90%,Transformer模型分類效果良好,參數過擬合了,適當調整模型參數,降低模型復雜度,還可以進一步提高分類準確率。

注意調整參數:

  • 可以適當增加 Transforme編碼器層數 和隱藏層的維度,微調學習率;

  • 調整多頭注意力的頭數,增加更多的 epoch (注意防止過擬合)

  • 可以改變一維信號堆疊的形狀(設置合適的長度和維度)

3.2 模型評估

# 模型 測試集 驗證  
import torch.nn.functional as F# 加載模型
model =torch.load('best_model_transformer.pt')
# model = torch.load('best_model_cnn2d.pt', map_location=torch.device('cpu'))# 將模型設置為評估模式
model.eval()
# 使用測試集數據進行推斷
with torch.no_grad():correct_test = 0test_loss = 0for test_data, test_label in test_loader:test_data, test_label = test_data.to(device), test_label.to(device)test_output = model(test_data)probabilities = F.softmax(test_output, dim=1)predicted_labels = torch.argmax(probabilities, dim=1)correct_test += (predicted_labels == test_label).sum().item()loss = loss_function(test_output, test_label)test_loss += loss.item()test_accuracy = correct_test / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
print(f'Test Accuracy: {test_accuracy:4.4f}  Test Loss: {test_loss:10.8f}')Test Accuracy: 0.9570  Test Loss: 0.12100271

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/215382.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/215382.shtml
英文地址,請注明出處:http://en.pswp.cn/news/215382.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

據房間Id是否存在,判斷當前房間是否到期且實時更改顏色

重點代碼展示&#xff1a; <template><el-col style"width: 100%;height: 100%;"><el-col :span"20"><el-card class"room_info"><avue-data-icons :option"option"></avue-data-icons></el-…

RT-DETR算法優化改進:輕量化自研設計雙卷積重新設計backbone和neck,完成漲點且計算量和參數量顯著下降

??????本文自研創新改進:雙卷積由組卷積和異構卷積組成,執行 33 和 11 卷積運算代替其他卷積核僅執行 11 卷積,YOLOv8 Conv,從而輕量化RT-DETR,性能如下表,GFLOPs 8.1降低至7.6,參數量6.3MB降低至5.8MB RT-DETR魔術師專欄介紹: https://blog.csdn.net/m0_637742…

ubuntu-c++-可執行模塊-動態鏈接庫-鏈接庫搜索-基礎知識

文章目錄 1.動態鏈接庫簡介2.動態庫搜索路徑3.運行時鏈接及搜索順序4.查看可運行模塊的鏈接庫5.總結 1.動態鏈接庫簡介 動態庫又叫動態鏈接庫&#xff0c;是程序運行的時候加載的庫&#xff0c;當動態鏈接庫正確安裝后&#xff0c;所有的程序都可以使用動態庫來運行程序。動態…

Android帝國之日志系統--logd、logcat

本文概要 這是Android系統進程系列的第四篇文章&#xff0c;本文以自述的方式來介紹logd進程&#xff0c;通過本文您將了解到logd進程存在的意義&#xff0c;以及日志系統的實現原理。&#xff08;文中的代碼是基于android13&#xff09; Android系統進程系列的前三篇文章如下…

C#基礎與進階擴展合集-基礎篇(持續更新)

目錄 本文分兩篇&#xff0c;進階篇點擊&#xff1a;C#基礎與進階擴展合集-進階篇 一、基礎入門 Ⅰ 關鍵字 Ⅱ 特性 Ⅲ 常見異常 Ⅳ 基礎擴展 1、哈希表 2、擴展方法 3、自定義集合與索引器 4、迭代器與分部類 5、yield return 6、注冊表 7、不安全代碼 8、方法…

MATLAB中cell函數的用法

cell用法 在MATLAB中&#xff0c;cell 是一種特殊的數據類型&#xff0c;用于存儲不同大小和類型的數據。cell 數組是一種容器&#xff0c;每個元素可以包含任意類型的數據&#xff0c;包括數值、字符串、矩陣、甚至其他的 cell 數組。 以下是 cell 數組的基本語法和示例&…

gitblit自建git倉庫

安裝 java sudo apt-get update sudo apt-get install openjdk-8-jdk # 或者其它你喜歡的版本 驗證&#xff1a; java -version 下載 gitblit https://github.com/gitblit-org/gitblit/releases 解壓/usr/local tar -zxvf gitblit-1.9.3.tar.gz 修改配置文件 nano /usr/local/…

【React】useCallback 使用的說明

文章目錄 useCallback的優缺點優點缺點JavaScript 的內聯優化 使用場景 用了兩年多的react&#xff0c;今天抽空寫點小內容 useCallback的優缺點 緩存了每次渲染時候 inline callback的實例 優點 關鍵點&#xff1a;利用memoize減少無效的re-render&#xff0c;通常配合shouldC…

ElasticSearch之cat trained model API

命令樣例如下&#xff1a; curl -X GET "https://localhost:9200/_cat/ml/trained_models?vtrue&pretty" --cacert $ES_HOME/config/certs/http_ca.crt -u "elastic:ohCxPHQBEs5*lo7F9"執行結果輸出如下&#xff1a; id heap_size …

如何在OpenWRT軟路由系統部署uhttpd搭建web服務器實現遠程訪問——“cpolar內網穿透”

文章目錄 前言1. 檢查uhttpd安裝2. 部署web站點3. 安裝cpolar內網穿透4. 配置遠程訪問地址5. 配置固定遠程地址 前言 uhttpd 是 OpenWrt/LuCI 開發者從零開始編寫的 Web 服務器&#xff0c;目的是成為優秀穩定的、適合嵌入式設備的輕量級任務的 HTTP 服務器&#xff0c;并且和…

docker-compose的介紹與使用

一、docker-compose 常用命令和指令 1. 概要 默認的模板文件是 docker-compose.yml&#xff0c;其中定義的每個服務可以通過 image 指令指定鏡像或 build 指令&#xff08;需要 Dockerfile&#xff09;來自動構建。 注意如果使用 build 指令&#xff0c;在 Dockerfile 中設置…

RHEL網絡服務器

目錄 1.時間同步的重要性 2.配置時間服務器 &#xff08;1&#xff09;指定所使用的上層時間服務器。 (2&#xff09;指定允許訪問的客戶端 (3&#xff09;把local stratum 前的注釋符#去掉。 3.配置chrony客戶端 &#xff08;1&#xff09;修改pool那行,指定要從哪臺時間…

Python常見面試知識總結(一):迭代器、拷貝、線程及底層結構

前言&#xff1a; Hello大家好&#xff0c;我是Dream。 今天來總結一下Python和C語言中常見的面試知識&#xff0c;歡迎大家一起前來探討學習~ 【一】Python中迭代器的概念&#xff1f; 可迭代對象是迭代器、生成器和裝飾器的基礎。簡單來說&#xff0c;可以使用for來循環遍歷…

[古劍山2023] pwn

最近這個打stdout的題真多。這個比賽沒打。拿到附件作了一天。 choice 32位&#xff0c;libc-2.23-i386&#xff0c;nbytes初始值為0x14,讀入0x804A04C 0x14字節后會覆蓋到nbytes 1個字節。當再次向v1讀入nbytes字節時會造成溢出。 先寫0x14p8(0xff)覆蓋到nbytes然后溢出寫傳…

初次參加軟考就想報高級,哪個相對容易考?

如果你想第一次參加軟考時就報考高級科目&#xff0c;但是卻不知道該報考高級中的哪個科目好、 ? ?那么今天的這篇文章你一定不要錯過&#xff01;首先&#xff0c;我們一起來了解一下&#xff0c;軟考高級中的5個科目。 ? ?軟考高級科目 ? 信息系統項目管理師 ? …

記錄一次postgresql臨時表丟失問題

項目相關技術棧 springboot hikari連接池pgbouncerpostgresql數據庫 背景 為了優化一個任務執行的速度&#xff0c;我將任務的sql中部分語句抽出生成臨時表&#xff08;create temp table tempqw as xxxxxxxxx&#xff09;&#xff0c;再和其他表關聯&#xff0c;提高查詢速…

三翼鳥2023輝煌收官, 定盤2024高質量棋局

最近在不同平臺上接連看到這樣的熱搜話題&#xff1a;用時間膠囊記錄2023的自己、2023年度問答、2023十大網絡流行語公布… 顯然&#xff0c; 2023年進入最后一個月&#xff0c;時間匆匆&#xff0c;這也意味著又到了總結過去和規劃未來的時候。拿到結果、取得成績當然是對202…

算法通關村第十五關 | 白銀 | 海量數據場景下的熱門算法題

1.從 40 個億中產生一個不存在的整數 可以采用位圖存儲數據&#xff0c;申請一個 bit 類型的數組 bitArr &#xff0c;每個位置只表示 0 或者 1 狀態&#xff0c;可以將占用內存縮小為使用哈希表的 1/32 。 遍歷給定的 40 億個數&#xff0c;遇到數時就將 bitArr 相應位置設置…

短視頻引流獲客系統:引領未來營銷的新潮流

在這個信息爆炸的時代&#xff0c;短視頻已經成為了人們獲取信息的主要渠道之一。而隨著短視頻的火爆&#xff0c;引流獲客系統也逐漸成為了營銷領域的新寵。本文將詳細介紹短視頻引流獲客系統的開發流程以及涉及到的技術&#xff0c;讓我們一起來看看這個引領未來營銷的新潮流…

華清遠見作業第二十四天

使用消息隊列完成兩個進程之間相互通信 代碼 #include<stdio.h> #include<string.h> #include<stdlib.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <sys/ipc.h> #include <sys/msg.h> #in…