【筆記】訓練步驟代碼解析

目錄

config參數配置

setup_dirs創建訓練文件夾

?load_data加載數據

build_model創建模型

train訓練


記錄一下訓練代碼中不理解的地方

config參數配置

config = {'data_root': r"D:\project\megnetometer\datasets\WISDM_ar_latest\organized_dataset",'train_dir': 'train','test_dir': 'test','seq_length': 300,  # 序列長度'batch_size': 32,  # 可能需減小batch_size'epochs': 60,'initial_lr': 3e-4,  # 初始學習率'max_lr': 5e-4,'patience': 20}

配置好需要用到的參數,比如數據集地址,訓練輪數,批次大小,學習率等

setup_dirs創建訓練文件夾

    def setup_dirs(self):self.run_dir = os.path.join(self.config['data_root'], 'run')  os.makedirs(self.run_dir, exist_ok=True)print('創建運行目錄run_dir  = ', self.run_dir)# 創建帶時間戳的實驗目錄timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")print('時間戳 = ', timestamp)self.exp_dir = os.path.join(self.run_dir, f"exp_{timestamp}")os.makedirs(self.exp_dir, exist_ok=True)# 保存當前配置with open(os.path.join(self.exp_dir, 'config.json'), 'w') as f:json.dump(self.config, f, indent=2)  # 兩個字符縮進,沒有則壓縮成一行,把config內容存在config.json里

os.path.join(self.config['data_root'], 'run')  

用于拼接文件路徑data_root的路徑加上run,中間的連接符會根據系統自動調整

os.makedirs(self.exp_dir, exist_ok=True)

創建文件,exist_ok=True當文件夾存在的時候不報錯

創建的文件夾用于存放后續訓練生成的模型以及保存訓練參數等文件

?load_data加載數據

    def load_data(self):"""從按行為分類的目錄加載數據(帶多級進度條)"""def load_activity_data(subset_dir):"""加載train或test子目錄下的數據"""data = []subset_path = os.path.join(self.config['data_root'], subset_dir)  #在數據集路徑內讀取,由subset_dir決定讀取的是訓練集還是測試集# 獲取所有活動類別目錄activities = [d for d in os.listdir(subset_path)if os.path.isdir(os.path.join(subset_path, d))]#print('activities=',activities)#activities= ['Downstairs', 'Jogging', 'Sitting', 'Standing', 'Upstairs', 'Walking']# 第一層進度條:活動類別pbar_activities = tqdm(activities, desc=f"掃描{subset_dir}目錄", position=0)for activity in pbar_activities:activity_lower = activity.lower()if activity_lower not in self.label_map:continueactivity_dir = os.path.join(subset_path, activity)#當前活動的目錄# 獲取所有用戶文件user_files = [f for f in os.listdir(activity_dir)if f.endswith('.txt')]#獲取所有txt結尾的文件# 第二層進度條:用戶文件#pbar_users = tqdm(user_files, desc="讀取用戶文件", leave=False, position=1)#后面要close,但是已經把所有的進度注釋掉了只留下來一個總的第一層進度#print('pbar_users=',pbar_users)for user_file in user_files:file_path = os.path.join(activity_dir, user_file)# 獲取文件行數用于進度條with open(file_path, 'r') as f:num_lines = sum(1 for _ in f)# 第三層進度條:讀取文件內容with open(file_path, 'r') as f:for line in f:line = line.strip()if not line:continuetry:x, y, z = map(float, line.split(','))data.append({'x': x,'y': y,'z': z,'activity': activity_lower})except ValueError:continuepbar_activities.close()return data# 調用示例print("\n" + "=" * 50)print("開始加載數據集...")train_data = load_activity_data(self.config['train_dir'])#print(train_data)#{'x': 5.33, 'y': 8.73, 'z': -0.42, 'activity': 'walking'},test_data = load_activity_data(self.config['test_dir'])
pbar_activities = tqdm(activities, desc=f"掃描{subset_dir}目錄", position=0)

tqdm創建進度條,desc是進度條前面的描述,position用于多級進度條之間的嵌套,以免位置混亂,在運行完之后要關閉進度條

pbar_activities.close()
with open('data.txt', 'r') as f:打開文件夾,r為只讀模式

# 轉換為模型輸入格式(帶優化進度條)def create_sequences(data, desc="生成序列"):seq_length = self.config['seq_length']features, labels = [], []total_windows = len(data) - seq_lengthpbar = tqdm(range(total_windows),desc=desc,position=0,bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [速度:{rate_fmt}]")for i in pbar:window = data[i:i + seq_length]# 檢查窗口內活動是否一致if len(set(d['activity'] for d in window)) != 1:continuefeatures.append([[d['x'], d['y'], d['z']] for d in window])labels.append(self.label_map[window[0]['activity']])# 每1000次更新一次進度信息if i % 1000 == 0:pbar.set_postfix({"有效窗口": len(features),"跳過窗口": i - len(features) + 1}, refresh=True)return np.array(features), np.array(labels)print("\n正在預處理訓練集...")X_train, y_train = create_sequences(train_data, "訓練集序列化")#返回的x是數據,y是標簽print("\n正在預處理測試集...")X_test, y_test = create_sequences(test_data, "測試集序列化")# 標準化(顯示進度)print("\n正在計算標準化參數...")self.mean = np.mean(X_train, axis=(0, 1))self.std = np.std(X_train, axis=(0, 1))print("應用標準化...")X_train = (X_train - self.mean) / (self.std + 1e-8)X_test = (X_test - self.mean) / (self.std + 1e-8)# One-hot編碼# 將 NumPy 數組轉為 PyTorch 張量,并指定類型為 int64(等價于 .long())y_train = torch.from_numpy(y_train).long()  # 或 .to(torch.int64)y_train = torch.nn.functional.one_hot(y_train.long(), num_classes=len(self.label_map))y_test = torch.from_numpy(y_test).long()  # 或 .to(torch.int64)y_test = torch.nn.functional.one_hot(y_test.long(), num_classes=len(self.label_map))print("\n" + "=" * 50)print("數據預處理完成!")print(f"訓練集形狀: X_train{X_train.shape}, y_train{y_train.shape}")print(f"測試集形狀: X_test{X_test.shape}, y_test{y_test.shape}")print("=" * 50 + "\n")return (X_train, y_train), (X_test, y_test)

滑動窗口開銷大,改用向量化滑動窗口(NumPy)

參數標準化全部使用訓練集數據

1e-8的作用:防止除零的小常數,特別適用于某些標準差接近0的特征

axis=(0,1):假設您的數據是3D張量(樣本×時間步/空間×特征),這樣計算每個特征通道的統計量

消除量綱影響:當特征的單位/量綱不同時(如年齡0-100 vs 工資0-100000),標準化使所有特征具有可比性

只使用訓練集統計量:測試集必須使用訓練集的mean/std,這是為了避免數據泄露(data leakage)

數據泄露:是機器學習中一個常見但嚴重的問題,指在模型訓練過程中意外地使用了測試集或未來數據的信息,導致模型評估結果被高估,無法反映真實性能。這種現象會使模型在實際應用中表現遠差于預期。

將分類標簽(整數形式)轉換為 One-hot 編碼,這是機器學習中處理分類任務的常見方法。

build_model創建模型

    def build_model(self):"""構建改進的BiLSTM分類模型"""model = tf.keras.Sequential([tf.keras.layers.InputLayer(input_shape=(self.config['seq_length'], 3)),# 雙向LSTM層tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),tf.keras.layers.BatchNormalization(),tf.keras.layers.Dropout(0.2),tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),tf.keras.layers.BatchNormalization(),# 全連接層tf.keras.layers.Dense(32, activation='relu'),tf.keras.layers.Dropout(0.3),tf.keras.layers.Dense(len(self.label_map), activation='softmax')])model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss='categorical_crossentropy',metrics=['accuracy'])return model

兩個模型框架TensorFlow更早,但PyTorch的初始設計更現代,以上是TensorFlow的模型。

計算圖(Computational Graph) 是描述數學運算和數據處理流程的抽象結構,而 靜態圖動態圖 是兩種不同的計算圖構建和執行方式。

計算圖 是一個有向無環圖(DAG),用于表示計算過程:

  • 節點(Node):代表運算(如加法、矩陣乘法)或數據(如張量、變量)。

  • 邊(Edge):描述數據流動方向(如張量從一層傳遞到下一層)。

改用PyTorch模型需要注意

PyTorch更推薦類式構建,而且保存時僅保存模型的參數(權重和偏置),不包含模型結構。如果需要測試,加載時必須先實例化一個結構完全相同的模型,再加載參數。

先創建一個模型類,再去調用?

class BiLSTMModel(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes, bidirectional=True):super(BiLSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.num_directions = 2 if bidirectional else 1# 雙向LSTMself.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True,bidirectional=bidirectional)# 全連接層(雙向時hidden_size需*2)self.fc = nn.Linear(hidden_size * self.num_directions, num_classes)def forward(self, x):# 初始化隱藏狀態(可選,PyTorch默認全零)h0 = torch.zeros(self.num_layers * self.num_directions, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers * self.num_directions, x.size(0), self.hidden_size).to(x.device)# LSTM前向傳播out, _ = self.lstm(x, (h0, c0))  # out形狀: (batch, seq_len, hidden_size * num_directions)# 取最后一個時間步的輸出out = out[:, -1, :]  # 形狀: (batch, hidden_size * num_directions)# 分類層out = self.fc(out)return out

此處構建的就是雙向LSTM模型,然后再構建函數調用

    def build_model(self):# 使用示例model = LSTMModel(input_size=3,  # 對應x/y/z特征hidden_size=32,num_layers=2,num_classes=6,  # 類別數bidirectional=True)return model

train訓練

    def train(self):"""PyTorch版本訓練流程"""# 1. 數據加載與預處理(X_train, y_train), (X_test, y_test) = self.load_data()# 轉換為PyTorch張量并移至設備device = torch.device("cuda" if torch.cuda.is_available() else "cpu")X_train = torch.FloatTensor(X_train).to(device)y_train = torch.LongTensor(y_train.argmax(axis=1)).to(device)  # 如果y是one-hotX_test = torch.FloatTensor(X_test).to(device)y_test = torch.LongTensor(y_test.argmax(axis=1)).to(device)# 創建DataLoadertrain_dataset = TensorDataset(X_train, y_train)# 類似zip(features, labels)train_loader = DataLoader(train_dataset,batch_size=self.config['batch_size'],shuffle=True)# 2. 模型初始化self.model = self.build_model().to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(self.model.parameters(),lr=self.config.get('lr', 0.001))# 3. 回調函數設置"""# 早停early_stopping = EarlyStopping(patience=self.config['patience'],verbose=True,path=os.path.join(self.exp_dir, 'best_model.pth'))"""# 學習率調度scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor=0.1,patience=5,verbose=True)# TensorBoard日志writer = SummaryWriter(log_dir=os.path.join(self.exp_dir, 'logs'))print("\n開始訓練...")print(f"實驗目錄: {self.exp_dir}")print(f"使用設備: {device}")# 4. 訓練循環for epoch in range(self.config['epochs']):self.model.train()train_loss = 0.0# 訓練批次for inputs, labels in train_loader:optimizer.zero_grad()outputs = self.model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item()# 驗證階段self.model.eval()with torch.no_grad():test_outputs = self.model(X_test)test_loss = criterion(test_outputs, y_test)_, predicted = torch.max(test_outputs, 1)accuracy = (predicted == y_test).float().mean()# 記錄日志writer.add_scalar('Loss/train', train_loss / len(train_loader), epoch)writer.add_scalar('Loss/test', test_loss.item(), epoch)writer.add_scalar('Accuracy/test', accuracy.item(), epoch)# 打印進度print(f"Epoch {epoch + 1}/{self.config['epochs']} | "f"Train Loss: {train_loss / len(train_loader):.4f} | "f"Test Loss: {test_loss.item():.4f} | "f"Accuracy: {accuracy.item():.4f}")# 學習率調整scheduler.step(test_loss)"""# 早停檢查early_stopping(test_loss, self.model)if early_stopping.early_stop:print("Early stopping triggered")break"""# 5. 保存最終結果writer.close()self.save_results(X_test, y_test)  # 需要適配PyTorch的保存方法

TensorDatasetDataLoader 都是 PyTorch 官方庫中的核心組件,專門用于高效的數據加載和批處理。

torch.utils.data.TensorDataset將多個張量(如特征張量和標簽張量)打包成一個數據集對象

dataset = TensorDataset(features, labels)  # 類似zip(features, labels)

torch.utils.data.DataLoader將數據集按批次加載,支持自動批處理、打亂數據、多進程加載等

shuffle=True代表打亂數據,此處是時序信號,但是由于從長序列中通過滑動窗口提取樣本每個窗口本身就是一個獨立樣本,此時打亂窗口順序是安全的

損失函數 criterion = nn.CrossEntropyLoss()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/90727.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/90727.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/90727.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Java填充Word模板

文章目錄前言一、設置word模板普通字段列表字段復選框二、代碼1. 引入POM2. 模板放入項目3.代碼實體類工具類三、測試四、運行結果五、注意事項前言 最近有個Java填充Word模板的需求,包括文本,列表和復選框勾選,寫一個工具類,以此…

【MYSQL8】springboot項目,開啟ssl證書安全連接

文章目錄一、開啟ssl證書1、msysql部署時默認開啟ssl證書2、配置文件3、創建用戶并指定ssl二、添加Java信任庫1、使用 keytool 導入證書2、驗證證書是否已導入三、修改連接配置一、開啟ssl證書 1、msysql部署時默認開啟ssl證書 可通過命令查看: SHOW VARIABLES L…

Telegraf vs. Logstash:實時數據處理架構中的關鍵組件對比

在現代數據基礎設施中,Telegraf 和 Logstash 是兩種廣泛使用的開源數據收集與處理工具,但它們在設計目標、應用場景和架構角色上存在顯著差異。本文將從實時數據處理架構、時序數據庫集成、消息代理支持等方面對比兩者的核心功能,并結合實際應…

Vue Vue-route (4)

Vue 漸進式JavaScript 框架 基于Vue2的學習筆記 - Vue-route 編程式導航和幾種路由 目錄 編程式導航 詳情組件 創建組件 設置路由 電影列表 傳參 另一種方式 動態路由 命名路由 別名 總結 編程式導航 點擊電影列表 跳轉電影詳情 詳情組件 創建組件 在views中創…

存在兩個cuda環境,在conda中切換到另一個

進入 openmmlab 環境 conda activate openmmlab 設置環境變量為 CUDA 12.4(只影響當前 shell 會話) export PATH/usr/local/cuda-12.4/bin:PATHexportLDLIBRARYPATH/usr/local/cuda?12.4/lib64:PATH export LD_LIBRARY_PATH/usr/local/cuda-12.4/lib64:…

Django 視圖(View)

1. 視圖簡介 視圖負責接收 web 請求并返回 web 響應。視圖就是一個 python 函數,被定義在 views.py 中。響應可以是一張網頁的 HTML 內容、一個重定向、一個 404 錯誤等等。響應處理過程如下圖: 用戶在瀏覽器中輸入網址:www.demo.com/1/100Django 獲取網址信息,去除域名和端…

HarmonyOS基礎概念

一、OpenHarmony、HarmonyOS和Harmony NEXT區別OpenHarmony是由開放原子開源基金會(OpenAtom Foundation)孵化及運營的開源項目,開放原子開源基金會由華為、阿里、騰訊、百度、浪潮、招商銀行、360等十家互聯網企業共同發起組建。目標是面向全…

spark3 streaming 讀kafka寫es

1. 代碼 package data_import import org.apache.spark.sql.{DataFrame, Row, SparkSession, SaveMode} import org.apache.spark.sql.types.{ArrayType, DoubleType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.functions._…

【跟著PMP學習項目管理】每日一練 - 3

1、你是一個建筑項目的項目經理。電工已經開始鋪設路線,此時客戶帶著一個變更請求來找你。他需要增加插座,你認為這會增加相關工作的成本。你要做的第一件事? A、拒絕做出變更,因為這會增加項目的成本并超出預算 B、參考項目管理計劃,查看是否應當處理這個變更 C、查閱…

CentOS 安裝 JDK+ NGINX+ Tomcat + Redis + MySQL搭建項目環境

目錄第一步:安裝JDK 1.8方法 1:安裝 Oracle JDK 1.8方法 2:安裝 OpenJDK 1.8第二步:使用yum安裝NGINX第三步:安裝Tomcat第四步:安裝Redis第五步:安裝MySQL第六步:MySQL版本兼容性問題…

如何設計一個登錄管理系統:單點登錄系統架構設計

關鍵詞:如何設計一個登錄管理系統、登錄系統架構、用戶認證、系統安全設計 📋 目錄 開篇:為什么登錄系統這么重要?整體架構設計核心功能模塊安全設計要點技術實現細節性能優化策略總結與展望 開篇:為什么登錄系統這么…

論跡不論心

2025年7月11日,16~26℃,陰 緊急不緊急重要 備考ing 備課不重要 遇見:免費人格測試 | 16Personalities,下面是我的結果 INFJ分析與優化建議 User: Anonymous (隱藏) Created: 2025/7/11 23:38 Updated: 2025/7/11 23:43 Exported:…

【面板數據】省級泰爾指數及城鄉收入差距測算(1990-2024年)

對中國各地區1990-2024年的泰爾指數、城鄉收入差距進行測算。本文參考龍海明等(2015),程名望、張家平(2019)的做法,采用泰爾指數測算城鄉收入差距。參考陳斌開、林毅夫(2013)的做法&…

http get和http post的區別

HTTP GET 和 HTTP POST 是兩種最常用的 HTTP 請求方法,它們在用途、數據傳輸方式、安全性等方面存在顯著差異。以下是它們的主要區別:1. 用途GET:主要用于請求從服務器獲取資源,比如獲取網頁內容、查詢數據庫等。GET 請求不應該用…

I2C集成電路總線

(摘要:空閑時,時鐘線數據線都是高電平,主機發送數據前,要在時鐘為高電平時,把數據線從高電平拉低,數據發送采取高位先行,時鐘線低電平時可以修改數據線,時鐘線高電平時要…

為了安全應該使用非root用戶啟動nginx

nginx基線安全,修復步驟。主要是由于使用了root用戶啟動nginx。為了安全應該使用非root用戶啟動nginx一、檢查項和問題檢查項分類檢查項名稱身份鑒別檢查是否配置Nginx賬號鎖定策略。服務配置檢查Nginx進程啟動賬號。服務配置Nginx后端服務指定的Header隱藏狀態服務…

論文解析篇 | YOLOv12:以注意力機制為核心的實時目標檢測算法

前言:Hello大家好,我是小哥談。長期以來,改進YOLO框架的網絡架構一直至關重要,但盡管注意力機制在建模能力方面已被證明具有優越性,相關改進仍主要集中在基于卷積神經網絡(CNN)的方法上。這是因…

學習C++、QT---20(C++的常用的4種信號與槽、自定義信號與槽的講解)

每日一言相信自己,你比想象中更接近成功,繼續勇往直前吧!那么我們開始用這4種方法進行信號與槽的通信第一種信號與槽的綁定方式我們將按鍵右鍵后轉到槽會自動跳轉到這個widget.h文件里面并自動生成了定義,我們要記住我們這個按鈕叫…

Anolis OS 23 架構支持家族新成員:Anolis OS 23.3 版本及 RISC-V 預覽版發布

自 Anolis OS 23 版本發布之始,龍蜥社區就一直致力于探索同源異構的發行版能力,從 Anolis OS 23.1 版本支持龍芯架構同源異構開始,社區就在持續不斷地尋找更多的異構可能性。 RISC-V 作為開放、模塊化、可擴展的指令集架構,正成為…

4萬億英偉達,憑什么?

CUDA正是英偉達所有神話的起點。它不是一個產品,而是一個生態系統。當越多的開發者使用CUDA,就會催生越多的基于CUDA的應用程序和框架;這些殺手級應用又會吸引更多的用戶和開發者投身于CUDA生態。這個正向飛輪一旦轉動起來,其產生…