深度學習篇---模型訓練(1)


文章目錄

  • 前言
  • 一、庫導入與配置部分
    • 介紹
  • 二、超參數配置
    • 簡介
  • 三、模型定義
    • 1. 改進殘差塊
    • 2. 完整CNN模型
  • 四、數據集類
  • 五、數據加載函數
  • 六、訓練函數
  • 七、驗證函數
  • 八、檢查點管理
  • 九、主函數
  • 十、執行入口
  • 十一、關鍵設計亮點總結
    • 1.維度管理
    • 2.數據標準化
    • 3.動態學習率
    • 4.梯度剪裁
    • 5.檢查點系統
    • 6.結果可追溯
    • 7.工業級健壯性
    • 8.高效數據加載


前言

本文再網絡結構(1)的基礎上,完善數據讀取、數據增強、數據處理、模型訓練、斷點訓練等功能。


一、庫導入與配置部分

import torch
import torch.nn as nn  # PyTorch核心神經網絡模塊
import pandas as pd    # 數據處理
import numpy as np     # 數值計算
from torch.utils.data import Dataset, DataLoader  # 數據加載工具
from sklearn.preprocessing import StandardScaler  # 數據標準化
from sklearn.model_selection import train_test_split  # 數據分割
from torch.optim.lr_scheduler import ReduceLROnPlateau  # 動態學習率調整
from collections import Counter  # 統計類別分布
import csv  # 結果記錄
import time  # 時間戳生成
import joblib  # 模型/參數持久化

介紹

導入Pytorch核心神經網路模塊、數據處理庫和數值處理庫數據標準化、數據分割、動態學習率調整、統計類別分布、結果記錄、時間戳生成、模型/參數持久化。

二、超參數配置

config = {"batch_size": 256,        # 每批數據量"num_workers": 128,       # 數據加載并行進程數"lr": 1e-3,               # 初始學習率"weight_decay": 1e-4,     # L2正則化強度"epochs": 200,            # 最大訓練輪數"patience": 15,           # 早停等待輪數"min_delta": 0.001,       # 視為改進的最小精度提升"grad_clip": 5.0,         # 梯度裁剪閾值"num_classes": None       # 自動計算類別數
}

簡介

設置每批數據量、數據加載并行進程數、初始學習率、L2正則化強度、最大訓練輪數、早停等待輪數、視為改進的最小精度提升、梯度剪裁閾值、自動計算類別數。

三、模型定義

1. 改進殘差塊

class ImprovedResBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super().__init__()  # 初始化父類# 第一個卷積層self.conv1 = nn.Conv1d(in_channels, out_channels, 5, stride, 2)# 參數解釋:輸入通道,輸出通道,卷積核大小5,步長,填充2(保持尺寸)self.bn1 = nn.BatchNorm1d(out_channels)  # 批量歸一化# 第二個卷積層self.conv2 = nn.Conv1d(out_channels, out_channels, 3, 1, 1)# 3x1卷積,步長1,填充1保持尺寸self.bn2 = nn.BatchNorm1d(out_channels)self.relu = nn.ReLU()  # 激活函數# 下采樣路徑(當需要調整維度時)self.downsample = nn.Sequential(nn.Conv1d(in_channels, out_channels, 1, stride),  # 1x1卷積調整維度nn.BatchNorm1d(out_channels)) if in_channels != out_channels or stride != 1 else None# 當輸入輸出通道不同或步長>1時啟用def forward(self, x):identity = x  # 保留原始輸入作為殘差# 主路徑處理x = self.relu(self.bn1(self.conv1(x)))  # Conv1 -> BN1 -> ReLUx = self.bn2(self.conv2(x))  # Conv2 -> BN2(無激活)# 調整殘差路徑維度if self.downsample:identity = self.downsample(identity)x += identity  # 殘差連接return self.relu(x)  # 最終激活

2. 完整CNN模型

class EnhancedCNN(nn.Module):def __init__(self, input_channels, seq_len, num_classes):super().__init__()# 初始特征提取層self.initial = nn.Sequential(nn.Conv1d(input_channels, 64, 7, stride=2, padding=3),  # 快速下采樣nn.BatchNorm1d(64),nn.ReLU(),nn.MaxPool1d(3, 2, 1)  # 核3,步長2,填充1,輸出尺寸約為輸入1/4)# 殘差塊堆疊self.blocks = nn.Sequential(ImprovedResBlock(64, 128, stride=2),  # 通道翻倍,尺寸減半ImprovedResBlock(128, 256, stride=2),ImprovedResBlock(256, 512, stride=2),nn.AdaptiveAvgPool1d(1)  # 自適應全局平均池化到長度1)# 分類器self.classifier = nn.Sequential(nn.Linear(512, 256),     # 全連接層nn.Dropout(0.5),         # 強正則化防止過擬合nn.ReLU(),nn.Linear(256, num_classes)  # 最終分類層)def forward(self, x):x = self.initial(x)  # 初始特征提取x = self.blocks(x)   # 通過殘差塊x = x.view(x.size(0), -1)  # 展平維度 (batch, 512)return self.classifier(x)  # 分類預測

四、數據集類

class SequenceDataset(Dataset):def __init__(self, sequences, labels, scaler=None):self.sequences = sequences  # 原始序列數據self.labels = labels        # 對應標簽self.scaler = scaler or StandardScaler()  # 標準化器# 如果未提供scaler,用當前數據擬合新的if scaler is None:flattened = np.concatenate(sequences)  # 展平所有數據點self.scaler.fit(flattened)  # 計算均值和方差# 對每個序列進行標準化self.normalized = [self.scaler.transform(seq) for seq in sequences]def __len__(self):return len(self.sequences)  # 返回數據集大小def __getitem__(self, idx):# 獲取單個樣本seq = torch.tensor(self.normalized[idx], dtype=torch.float32).permute(1, 0)# permute將形狀從(seq_len, features)轉為(features, seq_len)符合Conv1d輸入要求label = torch.tensor(self.labels[idx], dtype=torch.long)# 數據增強if np.random.rand() > 0.5:  # 50%概率時序翻轉seq = seq.flip(-1)  # 沿時間維度翻轉if np.random.rand() > 0.3:  # 70%概率添加噪聲seq += torch.randn_like(seq) * 0.01  # 高斯噪聲(均值0,方差0.01)return seq, label

五、數據加載函數

def load_data(excel_path):df = pd.read_excel(excel_path)  # 讀取Excel數據sequences = []labels = []for _, row in df.iterrows():  # 遍歷每一行數據try:# 處理可能存在的字符串格式異常loads = list(map(float, str(row['載荷']).split(',')))displacements = list(map(float, str(row['位移']).split(',')))powers = list(map(float, str(row['功率']).split(',')))# 對齊三列數據的長度min_len = min(len(loads), len(displacements), len(powers))# 組合成(時間步長, 3個特征)的數組combined = np.array([loads[:min_len], displacements[:min_len], powers[:min_len]).T  # 轉置為(min_len, 3)label = int(float(row['工況結果']))  # 轉換標簽sequences.append(combined)labels.append(label)except Exception as e:print(f"處理第{_}行時出錯: {str(e)}")  # 異常處理# 統計類別分布label_counts = Counter(labels)print("類別分布:", label_counts)# 創建標簽映射(將任意標簽轉換為0~N-1的索引)unique_labels = sorted(list(set(labels)))label_map = {l:i for i,l in enumerate(unique_labels)}config["num_classes"] = len(unique_labels)  # 更新配置labels = [label_map[l] for l in labels]  # 轉換所有標簽# 分層劃分訓練/驗證集(保持類別比例)return train_test_split(sequences, labels, test_size=0.2, stratify=labels)

六、訓練函數

def train_epoch(model, loader, optimizer, criterion, device):model.train()  # 訓練模式total_loss = 0for x, y in loader:  # 遍歷數據加載器x, y = x.to(device), y.to(device)  # 數據遷移到設備optimizer.zero_grad()  # 清空梯度outputs = model(x)     # 前向傳播loss = criterion(outputs, y)  # 計算損失loss.backward()        # 反向傳播# 梯度裁剪防止爆炸nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])optimizer.step()       # 參數更新total_loss += loss.item() * x.size(0)  # 累加損失(考慮批次大小)return total_loss / len(loader.dataset)  # 平均損失

七、驗證函數

def validate(model, loader, criterion, device):model.eval()  # 評估模式total_loss = 0correct = 0with torch.no_grad():  # 禁用梯度計算for x, y in loader:x, y = x.to(device), y.to(device)outputs = model(x)            loss = criterion(outputs, y)total_loss += loss.item() * x.size(0)# 計算準確率preds = outputs.argmax(dim=1)  # 取最大概率類別correct += preds.eq(y).sum().item()  # 統計正確數return (total_loss / len(loader.dataset),  # 平均損失(correct / len(loader.dataset))    # 準確率

八、檢查點管理

def save_checkpoint(epoch, model, optimizer, scheduler, best_acc, scaler, filename="checkpoint.pth"):torch.save({'epoch': epoch,                    # 當前輪數'model_state_dict': model.state_dict(),          # 模型參數'optimizer_state_dict': optimizer.state_dict(),  # 優化器狀態'scheduler_state_dict': scheduler.state_dict(),  # 學習率調度器狀態'best_acc': best_acc,              # 當前最佳準確率'scaler': scaler                   # 數據標準化參數}, filename)def load_checkpoint(filename, model, optimizer, scheduler):checkpoint = torch.load(filename)model.load_state_dict(checkpoint['model_state_dict'])       # 加載模型optimizer.load_state_dict(checkpoint['optimizer_state_dict']) scheduler.load_state_dict(checkpoint['scheduler_state_dict'])return checkpoint['epoch'], checkpoint['best_acc'], checkpoint['scaler']

九、主函數

def main(resume=False):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 自動選擇設備# 生成帶時間戳的結果文件名timestamp = time.strftime("%Y%m%d_%H%M%S")results_file = f"training_results_{timestamp}.csv"# 加載并劃分數據train_seq, val_seq, train_lb, val_lb = load_data("./dcgt.xls")# 初始化模型(恢復訓練時自動獲取序列長度)sample_seq = train_seq[0].shape[1] if resume else Nonemodel = EnhancedCNN(input_channels=3, seq_len=sample_seq,  num_classes=config["num_classes"]).to(device)# 定義損失函數和優化器criterion = nn.CrossEntropyLoss()optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])# 學習率調度器(根據驗證損失調整)scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)# 恢復訓練邏輯start_epoch = 0best_acc = 0if resume:checkpoint = torch.load("checkpoint.pth")model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])scheduler.load_state_dict(checkpoint['scheduler_state_dict'])start_epoch = checkpoint['epoch']best_acc = checkpoint['best_acc']train_set = SequenceDataset(train_seq, train_lb, scaler=checkpoint['scaler'])else:train_set = SequenceDataset(train_seq, train_lb)# 驗證集使用訓練集的scalerval_set = SequenceDataset(val_seq, val_lb, scaler=train_set.scaler)# 持久化標準化參數joblib.dump(train_set.scaler, 'scaler.save')# 創建數據加載器train_loader = DataLoader(train_set, batch_size=config["batch_size"], shuffle=True, num_workers=config["num_workers"]  # 多進程加載加速)val_loader = DataLoader(val_set, batch_size=config["batch_size"], num_workers=config["num_workers"])# 訓練循環with open(results_file, 'w', newline='') as f:writer = csv.writer(f)writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_acc', 'learning_rate'])for epoch in range(start_epoch, config["epochs"]):# 訓練一個epochtrain_loss = train_epoch(model, train_loader, optimizer, criterion, device)# 驗證val_loss, val_acc = validate(model, val_loader, criterion, device)current_lr = optimizer.param_groups[0]['lr']  # 獲取當前學習率# 更新學習率scheduler.step(val_loss)# 保存檢查點save_checkpoint(epoch+1, model, optimizer, scheduler, best_acc, train_set.scaler)# 記錄結果writer.writerow([epoch + 1, f"{train_loss:.4f}", f"{val_loss:.4f}", f"{val_acc:.4f}", f"{current_lr:.6f}"])print(f"\nEpoch {epoch+1}/{config['epochs']}")print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")print(f"Val Acc: {val_acc*100:.2f}% | Learning Rate: {current_lr:.6f}")# 早停邏輯(偽代碼示意)if val_acc > best_acc + config["min_delta"]:best_acc = val_accpatience_counter = 0else:patience_counter += 1if patience_counter >= config["patience"]:print(f"早停觸發于第{epoch+1}輪")break# 保存最終模型torch.save(model.state_dict(), "best_model.pth")

十、執行入口

if __name__ == "__main__":main(resume=False)  # 首次訓練# main(resume=True)  # 恢復訓練

十一、關鍵設計亮點總結

1.維度管理

維度管理:通過permute確保數據形狀符合Conv1d要求

2.數據標準化

數據標準化:使用全體訓練數據計算均值和方差,避免數據泄露

3.動態學習率

動態學習率:ReduceLROnPlateau根據驗證損失自動調整

4.梯度剪裁

梯度裁剪:防止梯度爆炸,穩定訓練過程

5.檢查點系統

檢查點系統:完整保存訓練狀態,支持訓練中斷恢復

6.結果可追溯

結果可追溯:帶時間戳的CSV記錄和模型保存

7.工業級健壯性

工業級健壯性:異常捕獲、參數持久化、自動類別映射

8.高效數據加載

高效數據加載:多進程并行加速數據預處理

這個實現涵蓋了從數據預處理到模型訓練的完整流程,適合工業級時間序列分類任務,具有良好的可擴展性和可維護性。


本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/76843.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/76843.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/76843.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

題解:AT_abc241_f [ABC241F] Skate

一道經典的 bfs 題。 提醒:本題解是為小白專做的,不想看的大佬請離開。 這道題首先一看就知道是 bfs,但是數據點不讓我們過: 1 ≤ H , W ≤ 1 0 9 1\le H,W\le10^9 1≤H,W≤109。 那么我們就需要優化了,從哪兒下手…

【含文檔+PPT+源碼】基于微信小程序的鄉村振興民宿管理系統

項目介紹 本課程演示的是一款基于微信小程序的鄉村振興民宿管理系統,主要針對計算機相關專業的正在做畢設的學生與需要項目實戰練習的 Java 學習者。 1.包含:項目源碼、項目文檔、數據庫腳本、軟件工具等所有資料 2.帶你從零開始部署運行本套系統 3.該…

STM32定時器通道1-4(CH1-CH4)的引腳映射關系

以下是 STM32定時器通道1-4(CH1-CH4)的引腳映射關系的詳細說明,以常見型號為例。由于不同系列/型號差異較大,請務必結合具體芯片的參考手冊確認。 一、STM32F1系列(如STM32F103C8T6) 1. TIM1(高級定時器) 通道默認引腳重映射引腳(部分/完全)備注CH1PA8無互補輸出CH1…

bge-m3+deepseek-v2-16b+離線語音能力實現離線文檔向量化問答語音版

ollama run deepseek-v2:16b ollama pull bge-m3 1、離線聽寫效果的大幅度提升。50M 1.3G(每次初始化都會很慢)---優化到首次初始化使用0延遲響應。 2、文檔問答歷史問題處理與優化,文檔問答離線策略討論與參數暴露。 3、離線大模型答復中斷…

前端界面在線excel編輯器 。node編寫post接口獲取文件流,使用傳參替換表格內容展示、前后端一把梭。

首先luckysheet插件是支持在線替換excel內容編輯得但是瀏覽器無法調用本地文件,如果只是展示,讓后端返回文件得二進制文件流就可以了,直接使用luckysheet展示。 這里我們使用xlsx-populate得node簡單應用來調用本地文件,自己寫一個…

JavaScript學習20-Event事件對象

1.屬性 即點擊誰就打印出來誰 2.方法 未添加stopPropagatio方法: 添加stopPropagatio方法后:

FreeRTOS 啟動過程中 SVC 和 PendSV 的工作流程?

在 FreeRTOS 的啟動過程中,SVC(Supervisor Call) 和 PendSV(Pendable Service Call) 是兩個關鍵的系統異常,分別用于 首次任務啟動 和 任務上下文切換。它們的協作確保了從內核初始化到多任務調度的平滑過渡。以下是詳細的工作流程分析(以 ARM Cortex-M 為例): 1. SVC…

[自制調試工具]構建高效調試利器:Debugger 類詳解

一、引言 在軟件開發的漫漫征程中,調試就像是一位忠誠的伙伴,時刻陪伴著開發者解決代碼里的各類問題。為了能更清晰地了解程序運行時變量的狀態,我們常常需要輸出各種變量的值。而 Debugger 類就像是一個貼心的調試助手,它能幫我…

foobar2000 VU Meter Visualisation 插件漢化版 VU表

原英文插件點此 界面展示 下載 https://wwtn.lanzout.com/iheI22ssoybi 安裝方式 解壓安裝文件,文件名為:foo_vis_vumeter-0.10.2_CHINIESE.fb2k-component

消息中間件對比與選型指南:Kafka、ActiveMQ、RabbitMQ與RocketMQ

目錄 引言 消息中間件的定義與作用 消息中間件在分布式系統中的重要性 對比分析的四種主流消息中間件概述 消息中間件核心特性對比 消息傳遞模型 Kafka:專注于發布-訂閱模型 ActiveMQ:支持點對點和發布-訂閱兩種模型 RabbitMQ:支持點…

liunx輸入法

1安裝fcitx5 sudo apt update sudo apt install fcitx fcitx-pinyin 2配置為默認輸入法 設置-》系統-》區域和語言 點擊系統彈出語言和支持選擇鍵盤輸入法系統 3設置設置 fcitx-configtool 如果沒顯示需要重啟電腦 4配置fcitx 把搜狗輸入法放到第一位(點擊下面…

WindowsPE文件格式入門05.PE加載器LoadPE

https://bpsend.net/thread-316-1-1.html LoadPE - pe 加載器 殼的前身 如果想訪問一個程序運行起來的內存,一種方法就是跨進程讀寫內存,但是跨進程讀寫內存需要來回調用api,不如直接訪問地址來得方便,那么如果我們需要直接訪問地址,該怎么做呢?.需要把dll注進程,注進去的代碼…

QGIS中第三方POI坐標偏移的快速校正-百度POI

1.百度POI: name,lng,lat,address 龍記黃燜雞米飯(共享區店),121.908315,30.886636,南匯新城鎮滬城環路699弄117號(A1區110室) 好福記黃燜雞(御橋路店),121.571409,31.162292,滬南路2419弄26號1層B間 御品黃燜雞米飯(安亭店),121.160322,31.305977,安亭鎮新源路792號…

SQL的調優方案

一、前言 SQL調優是提升數據庫性能的關鍵手段。需結合索引優化、SQL語句優化、執行計劃分析及數據庫架構設計等多方面綜合處理。 二、索引優化 創建合適索引 高頻查詢字段:對WHERE、JOIN、ORDER BY涉及的字段創建索引,尤其是區分度高的字段&#xff08…

【項目管理】第一部分 信息技術 1/2

相關文檔,希望互相學習,共同進步 風123456789~-CSDN博客 概要 知識點: 現代化基礎設施、數字經濟、工業互聯網、車聯網、智能制造、智慧城市、數字政府、5G、常用數據庫類型、數據倉庫、信息安全、網絡安全態勢感知、物聯網、大數…

【玩泰山派】1、mac上使用串口連接泰山派

文章目錄 前言picocom工具連接泰山派安裝picocom工具安裝ch340的驅動串口工具接線使用picocom連接泰山派 參考 前言 windows上面有xshell這個好用的工具可以使用串口連接板子,在mac上好像沒找到太好的工具,只能使用命令行工具去搞了。 之前查找說mac上…

【C++奇遇記】C++中的進階知識(繼承(一))

🎬 博客主頁:博主鏈接 🎥 本文由 M malloc 原創,首發于 CSDN🙉 🎄 學習專欄推薦:LeetCode刷題集 數據庫專欄 初階數據結構 🏅 歡迎點贊 👍 收藏 ?留言 📝 如…

【Scratch編程系列】Scratch編程軟件界面

Scratch是一款由麻省理工學院(MIT) 設計開發的少兒編程工具。其特點是:使用者可以不認識英文單詞,也可以不使用鍵盤,就可以進行編程。構成程序的命令和參數通過積木形狀的模塊來實現。用鼠標拖動指令模塊到腳本區就可以了。 這個軟…

開篇 - 配置Unlua+VsCode的智能提示、調試以及學習方法

智能提示 為要綁定Lua的藍圖創建模板文件,這會在Content/Script下生成lua文件 然后點擊生成智能代碼提示,這會在Plugins/Unlua/Intermediate/生成Intenllisense文件夾 打開VSCode,點擊文件->將工作區另存為。生成一個空工作區,放置在工程…

QEMU-KVM加SPICE,云電腦誕生了

沒錯!?QEMU-KVM SPICE? 的組合,本質上就是一套?輕量級云電腦(云桌面)?的解決方案。通過虛擬化技術將計算資源池化,再通過SPICE協議提供流暢的遠程桌面體驗,用戶用任意設備(筆記本/平板/瘦客…