pytorch 數據預處理,加載,訓練,可視化流程

流程

    • 定義自定義數據集類
    • 定義訓練和驗證的數據增強
    • 定義模型、損失函數和優化器
    • 訓練循環,包括驗證
    • 訓練可視化
    • 整個流程
    • 模型評估
    • 高級功能擴展
      • 混合精度訓練?
      • 分布式訓練?

外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳{:width=“50%” height=“50%”}

定義自定義數據集類

#======================
#1. 自定義數據集類
#======================
class CustomImageDataset(Dataset):def __init__(self, root_dir, transform=None):"""自定義數據集初始化:param root_dir: 數據集根目錄:param transform: 數據增強和預處理"""self.root_dir = root_dirself.transform = transformself.classes = sorted(os.listdir(root_dir))self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}# 收集所有圖像路徑和標簽self.image_paths = []self.labels = []for cls_name in self.classes:cls_dir = os.path.join(root_dir, cls_name)for img_name in os.listdir(cls_dir):if img_name.lower().endswith(('.jpg', '.png', '.jpeg')):self.image_paths.append(os.path.join(cls_dir, img_name))self.labels.append(self.class_to_idx[cls_name])def __len__(self):return len(self.image_paths)def __getitem__(self, idx):# 加載圖像img_path = self.image_paths[idx]try:image = Image.open(img_path).convert('RGB')except Exception as e:print(f"Error loading image {img_path}: {e}")# 返回空白圖像作為占位符image = Image.new('RGB', (224, 224), (0, 0, 0))# 應用數據增強和預處理if self.transform:image = self.transform(image)# 獲取標簽label = self.labels[idx]return image, label

定義訓練和驗證的數據增強

#======================
#2. 數據增強與預處理
#======================
def get_transforms():"""返回訓練和驗證的數據增強管道"""# 訓練集增強(更豐富)train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomRotation(15),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 驗證集預處理(無隨機增強)val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])return train_transform, val_transform

定義模型、損失函數和優化器

#======================
#3. 模型定義
#======================
def create_model(num_classes):"""創建模型(使用預訓練ResNet18)"""model = resnet18(pretrained=True)num_ftrs = model.fc.in_featuresmodel.fc = nn.Linear(num_ftrs, num_classes)return model

訓練循環,包括驗證

#======================
#4. 訓練函數
#======================
def train_model(model, dataloaders, criterion, optimizer, scheduler, device, num_epochs=25, checkpoint_path='checkpoint.pth', resume=False):"""訓練模型并支持中斷恢復:param resume: 是否從檢查點恢復訓練"""# 訓練歷史記錄history = {'train_loss': [], 'val_loss': [],'train_acc': [], 'val_acc': [],'epoch': 0, 'best_acc': 0.0}# 從檢查點恢復start_epoch = 0if resume and os.path.exists(checkpoint_path):print(f"Loading checkpoint from {checkpoint_path}")checkpoint = torch.load(checkpoint_path)model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])scheduler.load_state_dict(checkpoint['scheduler_state_dict'])history = checkpoint['history']start_epoch = history['epoch'] + 1print(f"Resuming training from epoch {start_epoch}")# 訓練循環for epoch in range(start_epoch, num_epochs):print(f'Epoch {epoch+1}/{num_epochs}')print('-' * 10)# 更新歷史記錄history['epoch'] = epoch# 每個epoch都有訓練和驗證階段for phase in ['train', 'val']:if phase == 'train':model.train()  # 設置訓練模式else:model.eval()   # 設置評估模式running_loss = 0.0running_corrects = 0# 迭代數據for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 梯度清零optimizer.zero_grad()# 前向傳播with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 訓練階段反向傳播和優化if phase == 'train':loss.backward()optimizer.step()# 統計running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)# 記錄歷史history[f'{phase}_loss'].append(epoch_loss)history[f'{phase}_acc'].append(epoch_acc.item())print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 保存最佳模型if phase == 'val' and epoch_acc > history['best_acc']:history['best_acc'] = epoch_acc.item()torch.save(model.state_dict(), 'best_model.pth')print(f"New best model saved with accuracy: {epoch_acc:.4f}")# 保存檢查點(每個epoch結束后)checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'scheduler_state_dict': scheduler.state_dict(),'history': history}torch.save(checkpoint, checkpoint_path)print(f"Checkpoint saved at epoch {epoch+1}")print()# 保存最終模型torch.save(model.state_dict(), 'final_model.pth')print('Training finished!')return model, history

訓練可視化

#======================
#5. 可視化訓練歷史
#======================
def plot_history(history):plt.figure(figsize=(12, 4))# 損失曲線plt.subplot(1, 2, 1)plt.plot(history['train_loss'], label='Train Loss')plt.plot(history['val_loss'], label='Validation Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()plt.title('Training and Validation Loss')# 準確率曲線plt.subplot(1, 2, 2)plt.plot(history['train_acc'], label='Train Accuracy')plt.plot(history['val_acc'], label='Validation Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.title('Training and Validation Accuracy')plt.tight_layout()plt.savefig('training_history.png')plt.show()

整個流程

#======================
#6. 主函數
#======================
def main():# 設置隨機種子(確保可復現性)torch.manual_seed(42)np.random.seed(42)# 檢查設備device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")# 創建數據增強管道train_transform, val_transform = get_transforms()# 創建數據集train_dataset = CustomImageDataset(root_dir='path/to/your/train_data',  # 替換為你的訓練數據路徑transform=train_transform)val_dataset = CustomImageDataset(root_dir='path/to/your/val_data',    # 替換為你的驗證數據路徑transform=val_transform)# 創建數據加載器train_loader = DataLoader(train_dataset,batch_size=32,shuffle=True,num_workers=4,pin_memory=True)val_loader = DataLoader(val_dataset,batch_size=32,shuffle=False,num_workers=4,pin_memory=True)dataloaders = {'train': train_loader, 'val': val_loader}# 創建模型num_classes = len(train_dataset.classes)model = create_model(num_classes)model = model.to(device)# 定義損失函數和優化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 檢查是否要恢復訓練resume_training = Falsecheckpoint_path = 'checkpoint.pth'# 檢查是否存在檢查點文件if os.path.exists(checkpoint_path):print("Checkpoint file found. Do you want to resume training? (y/n)")response = input().lower()if response == 'y':resume_training = True# 開始訓練start_time = time.time()model, history = train_model(model=model,dataloaders=dataloaders,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,num_epochs=25,checkpoint_path=checkpoint_path,resume=resume_training)end_time = time.time()# 保存訓練歷史with open('training_history.json', 'w') as f:json.dump(history, f, indent=4)# 打印訓練時間training_time = end_time - start_timeprint(f"Total training time: {training_time//3600}h {(training_time%3600)//60}m {training_time%60:.2f}s")# 可視化訓練歷史plot_history(history)if __name__ == "__main__":main()

模型評估

#======================
#模型評估
#======================
def evaluate_model(model, dataloader, device):model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in dataloader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f'Test Accuracy: {accuracy:.2f}%')return accuracy
test_dataset = CustomImageDataset('path/to/test_data', transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
evaluate_model(model, test_loader, device)

高級功能擴展

混合精度訓練?

from torch.cuda.amp import autocast, GradScaler
#在訓練函數中添加
scaler = GradScaler()
#修改訓練循環
with autocast():outputs = model(inputs)loss = criterion(outputs, labels)scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

分布式訓練?

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
#初始化分布式環境
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
#包裝模型
model = DDP(model.to(local_rank), device_ids=[local_rank])
#修改數據加載器
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = DataLoader(..., sampler=train_sampler)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/93681.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/93681.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/93681.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Prompt工程:OCR+LLM文檔處理的精準制導系統

在PDF OCR與大模型結合的實際應用中,很多團隊會發現一個現象:同樣的OCR文本,不同的Prompt設計會產生截然不同的提取效果。有時候準確率能達到95%,有時候卻只有60%。這背后的關鍵就在于Prompt工程的精細化程度。 🎯 為什…

RecSys:粗排模型和精排特征體系

粗排 在推薦系統鏈路中,排序階段至關重要,通常分為召回、粗排和精排三個環節。粗排作為精排前的預處理階段,需要在效果和性能之間取得平衡。 雙塔模型 后期融合:把用戶、物品特征分別輸入不同的神經網絡,不對用戶、…

spring聲明式事務,finally 中return對事務回滾的影響

finally 塊中使用 return 是一個常見的編程錯誤,它會: 跳過正常的事務提交流程。吞掉異常,使錯誤處理失效 導致不可預測的事務行為Java 中 finally 和 return 的執行機制:1. finally 塊的基本特性 在 Java 中,finally …

WPF 打印報告圖片大小的自適應(含完整示例與詳解)

目標:在 FlowDocument 報告里,根據 1~6 張圖片的數量, 自動選擇 2 行 3 列 的最佳布局;在只有 1、2、4 張時保持“占滿感”,打印清晰且不變形。規則一覽:1 張 → 占滿 23(大圖居中)…

【AI大模型前沿】百度飛槳PaddleOCR 3.0開源發布,支持多語言、手寫體識別,賦能智能文檔處理

系列篇章💥 No.文章1【AI大模型前沿】深度剖析瑞智病理大模型 RuiPath:如何革新癌癥病理診斷技術2【AI大模型前沿】清華大學 CLAMP-3:多模態技術引領音樂檢索新潮流3【AI大模型前沿】浙大攜手阿里推出HealthGPT:醫學視覺語言大模…

迅為RK3588開發板Android12 制作使用系統簽名

在 Android 源碼 build/make/target/product/security/下存放著簽名文件,如下所示:將北京迅為提供的 keytool 工具拷貝到 ubuntu 中,然后將 Android11 或 Android12 源碼build/make/target/product/security/下的 platform.pk8 platform.x509…

Day08 Go語言學習

1.安裝Go和Goland 2.新建demo項目實踐語法并使用git實踐版本控制操作 2.1 Goland配置 路徑**:** GOPATH workspace GOROOT golang 文件夾: bin 編譯后的可執行文件 pkg 編譯后的包文件 src 源文件 遇到問題1:運行 ‘go build awesomeProject…

Linux-文件創建拷貝刪除剪切

文章目錄Linux文件相關命令ls通配符含義touch 創建文件命令示例cp 拷貝文件rm 刪除文件mv剪切文件Linux文件相關命令 ls ls是英文單詞list的簡寫,其功能為列出目錄的內容,是用戶最常用的命令之一,它類似于DOS下的dir命令。 Linux文件或者目…

RabbitMQ:交換機(Exchange)

目錄一、概述二、Direct Exchange (直連型交換機)三、Fanout Exchange(扇型交換機)四、Topic Exchange(主題交換機)五、Header Exchange(頭交換機)六、Default Exchange(…

【實時Linux實戰系列】基于實時Linux的物聯網系統設計

隨著物聯網(IoT)技術的飛速發展,越來越多的設備被連接到互聯網,形成了一個龐大而復雜的網絡。這些設備從簡單的傳感器到復雜的工業控制系統,都在實時地產生和交換數據。實時Linux作為一種強大的操作系統,為…

第五天~提取Arxml中描述信息New_CanCluster--Expert

?? ARXML描述信息提取:挖掘汽車電子設計的"知識寶藏" 在AUTOSAR工程中,描述信息如同埋藏在ARXML文件中的金礦,而New_CanCluster--Expert正是打開這座寶藏的密鑰。本文將帶您深度探索ARXML描述信息的提取藝術,解鎖汽車電子設計的核心知識資產! ?? 為什么描述…

開源 C++ QT Widget 開發(一)工程文件結構

文章的目的為了記錄使用C 進行QT Widget 開發學習的經歷。臨時學習,完成app的開發。開發流程和要點有些記憶模糊,趕緊記錄,防止忘記。 相關鏈接: 開源 C QT Widget 開發(一)工程文件結構-CSDN博客 開源 C…

手寫C++ string類實現詳解

類定義cppnamespace ym {class string {private:char* _str; // 字符串數據size_t _size; // 當前字符串長度size_t _capacity; // 當前分配的內存容量static const size_t npos -1; // 特殊值,表示最大可能位置public:// 構造函數和析構函數string(…

C++信息學奧賽一本通-第一部分-基礎一-第3章-第2節

C信息學奧賽一本通-第一部分-基礎一-第3章-第2節 2057 星期幾 #include <iostream>using namespace std;int main() {int day; cin >> day;switch (day) {case 1:cout << "Monday";break;case 2:cout << "Tuesday";break;case 3:c…

【leetcode 3】最長連續序列 (Longest Consecutive Sequence) - 解題思路 + Golang實現

最長連續序列 (Longest Consecutive Sequence) - LeetCode 題解 題目描述 給定一個未排序的整數數組 nums&#xff0c;找出數字連續的最長序列&#xff08;不要求序列元素在原數組中連續&#xff09;的長度。要求設計并實現時間復雜度為 O(n) 的算法解決此問題。 示例 1&#x…

礦物分類系統開發筆記(一):數據預處理

目錄 一、數據基礎與預處理目標 二、具體預處理步驟及代碼解析 2.1 數據加載與初步清洗 2.2 標簽編碼 2.3 缺失值處理 &#xff08;1&#xff09;刪除含缺失值的樣本 &#xff08;2&#xff09;按類別均值填充 &#xff08;3&#xff09;按類別中位數填充 &#xff08;…

《UE5_C++多人TPS完整教程》學習筆記43 ——《P44 奔跑混合空間(Running Blending Space)》

本文為B站系列教學視頻 《UE5_C多人TPS完整教程》 —— 《P44 奔跑混合空間&#xff08;Running Blending Space&#xff09;》 的學習筆記&#xff0c;該系列教學視頻為計算機工程師、程序員、游戲開發者、作家&#xff08;Engineer, Programmer, Game Developer, Author&…

TensorRT-LLM.V1.1.0rc1:Dockerfile.multi文件解讀

一、TensorRT-LLM有三種安裝方式&#xff0c;從簡單到難 1.NGC上的預構建發布容器進行部署,見《tensorrt-llm0.20.0離線部署DeepSeek-R1-Distill-Qwen-32B》。 2.通過pip進行部署。 3.從源頭構建再部署&#xff0c;《TensorRT-LLM.V1.1.0rc0:在無 GitHub 訪問權限的服務器上編…

UniApp 實現pdf上傳和預覽

一、上傳1、html<template><button click"takeFile">pdf上傳</button> </template>2、JStakeFile() {// #ifdef H5// H5端使用input方式選擇文件const input document.createElement(input);input.type file;input.accept .pdf;input.onc…

《用Proxy解構前端壁壘:跨框架狀態共享庫的從零到優之路》

一個項目中同時出現React的函數式組件、Vue的模板語法、Angular的依賴注入時,數據在不同框架體系間的流轉便成了開發者不得不面對的難題—狀態管理,這個本就復雜的命題,在跨框架場景下更顯棘手。而Proxy,作為JavaScript語言賦予開發者的“元編程利器”,正為打破這道壁壘提…