【一起來學AI大模型】PyTorch DataLoader 實戰指南

DataLoader 是 PyTorch 中處理數據的核心組件,它提供了高效的數據加載、批處理和并行處理功能。下面是一個全面的 DataLoader 實戰指南,包含代碼示例和最佳實踐。

基礎用法:簡單數據加載

import torch
from torch.utils.data import Dataset, DataLoader# 1. 創建自定義數據集
class SimpleDataset(Dataset):def __init__(self, size=1000):self.data = torch.randn(size, 3, 32, 32)  # 模擬圖像數據self.labels = torch.randint(0, 10, (size,))  # 0-9的標簽def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]# 2. 創建DataLoader
dataset = SimpleDataset(1000)
dataloader = DataLoader(dataset,batch_size=64,       # 批大小shuffle=True,        # 是否打亂數據num_workers=4,       # 使用4個進程加載數據pin_memory=True      # 使用固定內存(加速GPU傳輸)
)# 3. 使用DataLoader
for epoch in range(3):print(f"Epoch {epoch+1}")for batch_idx, (data, targets) in enumerate(dataloader):# 數據自動分批:data.shape = [64, 3, 32, 32], targets.shape = [64]if batch_idx % 10 == 0:print(f"  Batch {batch_idx}: {data.shape}, {targets.shape}")print("Epoch completed\n")

高級功能:自定義數據集與轉換

圖像數據集示例

import os
from PIL import Image
from torchvision import transformsclass CustomImageDataset(Dataset):def __init__(self, img_dir, transform=None):self.img_dir = img_dirself.transform = transformself.img_names = [f for f in os.listdir(img_dir) if f.endswith('.jpg')]# 假設文件名格式為 "label_imageid.jpg",例如 "3_001.jpg"self.labels = [int(f.split('_')[0]) for f in self.img_names]def __len__(self):return len(self.img_names)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_names[idx])image = Image.open(img_path).convert('RGB')label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 定義數據轉換
transform = transforms.Compose([transforms.Resize((256, 256)),      # 調整大小transforms.RandomHorizontalFlip(),   # 隨機水平翻轉transforms.RandomRotation(15),       # 隨機旋轉 ±15度transforms.ToTensor(),               # 轉為Tensor [0,1]transforms.Normalize(                # 標準化mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 創建數據集和DataLoader
dataset = CustomImageDataset('/path/to/images', transform=transform)
dataloader = DataLoader(dataset,batch_size=32,shuffle=True,num_workers=4,collate_fn=lambda batch: tuple(zip(*batch))  # 自定義批處理函數
)

文本數據集示例

from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizerclass TextDataset(Dataset):def __init__(self, file_path, max_len=100):self.max_len = max_lenself.tokenizer = get_tokenizer('basic_english')# 讀取文本數據和標簽self.texts = []self.labels = []with open(file_path, 'r', encoding='utf-8') as f:for line in f:label, text = line.split('\t')self.labels.append(int(label))self.texts.append(text.strip())# 構建詞匯表self.vocab = build_vocab_from_iterator((self.tokenizer(text) for text in self.texts),specials=['<unk>', '<pad>'])self.vocab.set_default_index(self.vocab['<unk>'])def __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]tokens = self.tokenizer(text)# 將token轉換為索引indices = [self.vocab[token] for token in tokens]# 截斷或填充序列if len(indices) > self.max_len:indices = indices[:self.max_len]else:indices = indices + [self.vocab['<pad>']] * (self.max_len - len(indices))return torch.tensor(indices), self.labels[idx]# 自定義批處理函數(處理變長序列)
def collate_fn(batch):texts, labels = zip(*batch)# 找到批次中最長序列的長度max_len = max(len(t) for t in texts)# 填充所有序列到相同長度padded_texts = []for text in texts:padding = torch.zeros(max_len - len(text), dtype=torch.long)padded_texts.append(torch.cat((text, padding)))return torch.stack(padded_texts), torch.tensor(labels)# 創建DataLoader
text_dataset = TextDataset('/path/to/text_data.txt', max_len=100)
text_dataloader = DataLoader(text_dataset,batch_size=32,shuffle=True,num_workers=2,collate_fn=collate_fn  # 使用自定義批處理函數
)

性能優化技巧

1. 使用并行加載

# 根據CPU核心數設置num_workers
import os
num_workers = min(4, os.cpu_count())  # 使用不超過4個或CPU核心數的workerdataloader = DataLoader(dataset,batch_size=64,shuffle=True,num_workers=num_workers,pin_memory=True,  # 對于GPU訓練非常重要persistent_workers=True  # 保持worker進程活動(PyTorch 1.7+)
)

2. 數據預取

from torch.utils.data import DataLoader, PrefetchGenerator# 使用預取生成器(PyTorch 1.7+)
dataloader = DataLoader(dataset,batch_size=64,shuffle=True,num_workers=4,prefetch_factor=2  # 每個worker預取的批次數
)# 或者使用自定義預取
class PrefetchLoader:def __init__(self, loader, device):self.loader = loaderself.device = deviceself.stream = torch.cuda.Stream() if device.type == 'cuda' else Nonedef __iter__(self):first = Truefor batch in self.loader:if self.stream is not None:with torch.cuda.stream(self.stream):batch = self._preprocess(batch)else:batch = self._preprocess(batch)if not first and self.stream is not None:torch.cuda.current_stream().wait_stream(self.stream)first = Falseyield batchdef _preprocess(self, batch):data, target = batchreturn data.to(self.device, non_blocking=True), target.to(self.device, non_blocking=True)# 使用自定義預取
device = torch.device('cuda')
prefetch_dataloader = PrefetchLoader(dataloader, device)

3. 內存映射文件處理大文件

import numpy as np
import torch
from torch.utils.data import Datasetclass MmapDataset(Dataset):def __init__(self, file_path, shape, dtype=np.float32):self.data = np.memmap(file_path, dtype=dtype, mode='r', shape=shape)def __len__(self):return self.data.shape[0]def __getitem__(self, idx):return torch.from_numpy(np.array(self.data[idx]))

分布式數據加載

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler# 初始化分布式環境
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()# 創建分布式采樣器
sampler = DistributedSampler(dataset,num_replicas=world_size,rank=rank,shuffle=True,seed=42
)# 創建分布式DataLoader
dist_dataloader = DataLoader(dataset,batch_size=64,sampler=sampler,num_workers=4,pin_memory=True,drop_last=True  # 丟棄最后不完整的批次
)# 在每個進程中
for epoch in range(10):# 設置epoch確保所有進程的shuffle一致dist_dataloader.sampler.set_epoch(epoch)for batch in dist_dataloader:# 處理批次數據pass

數據增強策略

圖像增強

from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2# 使用torchvision
torchvision_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 使用Albumentations(更豐富的增強)
albumentations_transform = A.Compose([A.RandomResizedCrop(224, 224),A.HorizontalFlip(p=0.5),A.VerticalFlip(p=0.2),A.Rotate(limit=30),A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.9),A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),ToTensorV2()
])# 在數據集類中使用
def __getitem__(self, idx):img_path = self.img_paths[idx]image = cv2.imread(img_path)image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)if self.transform:augmented = self.transform(image=image)image = augmented['image']return image, self.labels[idx]

文本增強

import nlpaug.augmenter.word as naw# 創建文本增強器
augmenter = naw.ContextualWordEmbsAug(model_path='bert-base-uncased', action="substitute",  # 替換、插入等aug_p=0.1  # 增強比例
)# 在數據集中使用
def __getitem__(self, idx):text = self.texts[idx]if self.augment and random.random() < 0.5:  # 50%概率增強text = augmenter.augment(text)# 后續處理...

數據可視化與調試

import matplotlib.pyplot as plt
import numpy as npdef show_batch(dataloader, n=4):"""顯示一批圖像及其標簽"""dataiter = iter(dataloader)images, labels = next(dataiter)fig, axes = plt.subplots(1, n, figsize=(15, 4))for i in range(n):img = images[i].permute(1, 2, 0).numpy()  # CHW -> HWCimg = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])  # 反歸一化img = np.clip(img, 0, 1)axes[i].imshow(img)axes[i].set_title(f"Label: {labels[i].item()}")axes[i].axis('off')plt.show()# 使用
show_batch(dataloader, n=8)

常見問題解決方案

1. 內存不足

# 解決方案1:使用更小的批大小
dataloader = DataLoader(dataset, batch_size=16)# 解決方案2:使用內存映射文件
# 如前文的MmapDataset示例# 解決方案3:使用IterableDataset
from torch.utils.data import IterableDatasetclass LargeIterableDataset(IterableDataset):def __init__(self, file_path, chunk_size=1000):self.file_path = file_pathself.chunk_size = chunk_sizedef __iter__(self):with open(self.file_path, 'r') as f:chunk = []for line in f:chunk.append(process_line(line))  # 自定義處理函數if len(chunk) == self.chunk_size:yield from chunkchunk = []if chunk:yield from chunk# 使用
dataset = LargeIterableDataset('large_file.txt')
dataloader = DataLoader(dataset, batch_size=64)

2. Windows多進程問題

# 解決方案:將主代碼放入if __name__ == '__main__'塊中
if __name__ == '__main__':# 在這里創建DataLoaderdataloader = DataLoader(dataset, num_workers=4)# 訓練代碼...

3. 數據加載成為瓶頸

# 解決方案1:增加num_workers
dataloader = DataLoader(dataset, num_workers=os.cpu_count())# 解決方案2:使用預取
# 如前文的PrefetchLoader示例# 解決方案3:使用更快的存儲(如SSD代替HDD)# 解決方案4:使用更高效的數據格式(如HDF5、LMDB)

最佳實踐總結

  1. 批大小選擇:根據GPU內存選擇最大可用批大小

  2. Worker數量:設置為CPU核心數的1-2倍

  3. 固定內存:GPU訓練時始終設置pin_memory=True

  4. 數據增強:在CPU上執行,避免占用GPU資源

  5. 分布式訓練:使用DistributedSampler確保數據正確分區

  6. 內存優化:對大文件使用內存映射或IterableDataset

  7. 預取策略:使用內置prefetch_factor或自定義預取

  8. 數據驗證:定期可視化批次數據確保數據增強有效

  9. 資源監控:監控CPU/GPU利用率,識別瓶頸

  10. 格式優化:使用高效數據格式(如TFRecord、LMDB)加速IO

通過合理配置DataLoader,你可以顯著提高模型訓練效率,充分利用硬件資源,加速模型迭代過程。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/90998.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/90998.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/90998.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SpringBoot單元測試類拿不到bean報空指針異常

原代碼package com.atguigu.gulimall.product;import com.aliyun.oss.OSSClient; import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; impo…

持續集成 簡介環境搭建

1. 持續集成簡介 1.1 持續集成的作用 隨著互聯網的蓬勃發展,軟件生命周期模型也經歷了幾個比較大的階段,從最初的瀑布模型,到 V 模型,再到現在的敏捷或者 devops,不論哪個階段,項目從立項到交付幾乎都離不開以下幾個過程,開發、構建、測試和發布,而且一直都在致力于又…

關于 java:11. 項目結構、Maven、Gradle 構建系統

一、Java 項目目錄結構標準1.1 Java 項目標準目錄結構總覽標準 Java 項目目錄結構&#xff08;以 Maven / Gradle 通用結構為基礎&#xff09;&#xff1a;project-root/ ├── src/ │ ├── main/ │ │ ├── java/ # 主業務邏輯代碼&#xff08;核心…

大數據的安全挑戰與應對

在大數據時代&#xff0c;大數據安全問題已成為開發者最為關注的核心議題之一。至少五年來&#xff0c;大數據已融入各類企業的運營體系&#xff0c;而采用先進數據分析解決方案的組織數量仍在持續增長。本文將明確當前市場中最關鍵的大數據安全問題與威脅&#xff0c;概述企業…

PostgreSQL ERROR: out of shared memory處理方式

系統允許的總鎖數 SELECT (SELECT setting::int FROM pg_settings WHERE name max_locks_per_transaction) * (SELECT setting::int FROM pg_settings WHERE name max_connections) (SELECT setting::int FROM pg_settings WHERE name max_prepared_transactions);當鎖大于…

Django 模型(Model)

1. 模型簡介 ORM 簡介 MVC 框架中一個重要的部分就是 ORM,它實現了數據模型與數據庫的解耦,即數據模型的設計不需要依賴于特定的數據庫,通過簡單的配置就可以輕松更換數據庫。即直接面向對象操作數據,無需考慮 sql 語句。 ORM 是“對象-關系-映射”的簡稱,主要任務是:…

深入解析Hadoop RPC:技術細節與推廣應用

Hadoop RPC框架概述在分布式系統的核心架構中&#xff0c;遠程過程調用&#xff08;RPC&#xff09;機制如同神經網絡般連接著各個計算節點。Hadoop作為大數據處理的基石&#xff0c;其自主研發的RPC框架不僅支撐著內部組件的協同運作&#xff0c;更以獨特的工程哲學詮釋了分布…

為什么玩游戲用UDP,看網頁用TCP?

故事場景&#xff1a;兩種不同的遠程溝通方式假設你需要和遠方的朋友溝通一件重要的事情。方式一&#xff1a;TCP — 打一個重要的電話打電話是一種非常嚴謹、可靠的溝通方式。? 1. 建立連接 (三次握手):? 你拿起電話&#xff0c;撥號&#xff08;SYN&#xff09;。? 朋友那…

【EGSR2025】材質+擴散模型+神經網絡相關論文整理隨筆(二)

High-Fidelity Texture Transfer Using Multi-Scale Depth-Aware Diffusion 這篇文章可以從一個帶有紋理的幾何物體出發&#xff0c;將其身上的紋理自動提取并映射到任意的幾何拓撲結構上&#xff08;見下圖紅線左側&#xff09;&#xff1b;或者從一個白模幾何對象出發&#x…

深度學習圖像分類數據集—玉米粒質量識別分類

該數據集為圖像分類數據集&#xff0c;適用于ResNet、VGG等卷積神經網絡&#xff0c;SENet、CBAM等注意力機制相關算法&#xff0c;Vision Transformer等Transformer相關算法。 數據集信息介紹&#xff1a;玉米粒質量識別分類&#xff1a;[crush, good, mul] 訓練數據集總共有3…

Unity VR手術模擬系統架構分析與數據流設計

Unity VR手術模擬系統架構分析與數據流設計 前言 本文將深入分析一個基于Unity引擎開發的多人VR手術模擬系統。該系統采用先進的網絡架構設計&#xff0c;支持多用戶實時協作&#xff0c;具備完整的手術流程引導和精確的工具交互功能。通過對系統架構和數據管道的詳細剖析&…

【Spring Boot】Spring Boot 4.0 的顛覆性AI特性全景解析,結合智能編碼實戰案例、底層架構革新及Prompt工程手冊

Spring Boot 4.0 的顛覆性AI特性全景解析&#xff0c;結合智能編碼實戰案例、底層架構革新及Prompt工程手冊一、Spring Boot 4.0 核心AI能力矩陣二、AI智能編碼插件實戰&#xff08;Spring AI Assistant&#xff09;1. 安裝與激活2. 實時代碼生成場景3. 缺陷預測與修復三、AI引…

audiobookshelf-web 項目怎么運行

git clone https://github.com/audiobookshelf/audiobookshelf-web.git cd audiobookshelf-web npm i 啟動項目 npm run dev http://localhost:3000/

掃描文件 PDF / 圖片 糾斜 | 圖片去黑邊 / 裁剪 / 壓縮

問題&#xff1a;掃描后形成的 PDF 或圖片文檔常存在變形傾斜等問題&#xff0c;手動調整頗為耗時費力。 一、PDF 糾斜 - Adobe Acrobat DC 1、所用功能 掃描和 OCR&#xff1a; 識別文本&#xff1a;在文件中 → 設置 確定后啟動掃描&#xff0c;識別過程中自動糾偏。 2、…

適配器模式:兼容不兼容接口

將一個類的接口轉換成客戶端期望的另一個接口&#xff0c;解決接口不兼容問題。代碼示例&#xff1a;// 目標接口&#xff08;客戶端期望的格式&#xff09; interface ModernPrinter {void printDocument(String text); }// 被適配的舊類&#xff08;不兼容&#xff09; class…

流程控制:從基礎結構到跨語言實踐與優化

流程控制 一、流程控制基礎概念與核心價值 &#xff08;一&#xff09;流程控制定義與本質 流程控制是通過特定邏輯結構決定程序執行順序的機制&#xff0c;核心是控制代碼運行路徑&#xff0c;包括順序執行、條件分支、循環迭代三大核心邏輯。其本質是將無序的指令集合轉化為有…

Http與Https區別和聯系

一、HTTP 詳解 HTTP&#xff08;HyperText Transfer Protocol&#xff09;?? 是互聯網數據通信的基礎協議&#xff0c;用于客戶端&#xff08;瀏覽器&#xff09;與服務器之間的請求-響應交互 核心特性??&#xff1a; 1.無連接&#xff08;Connectionless&#xff09;??…

飛算JavaAI:開啟 Java 開發 “人機協作” 新紀元

每日一句 明天是新的一天&#xff0c; 你也不再是昨天的你。 目錄每日一句一、需求到架構&#xff1a;AI深度介入開發“源頭設計”1.1 需求結構化&#xff1a;自然語言到技術要素的精準轉化1.2 架構方案生成&#xff1a;基于最佳實踐的動態適配二、編碼全流程&#xff1a;從“…

Qt項目鍛煉——TODO(五)

發現問題如果是自己創建的ui文件&#xff0c;怎么包含進自己的窗口類并且成為ui成員&#xff1f;一般來說Qt designer 會根據你.ui文件生成對應的ui_文件名這個類&#xff08;文件名是ui文件名&#xff09;&#xff0c;它包含了所有 UI 組件&#xff08;如按鈕、文本框、標簽等…

Vue框架之模板語法全面解析

Vue框架之模板語法全面解析一、模板語法的核心思想二、插值表達式&#xff1a;數據渲染的基礎2.1 基本用法&#xff1a;渲染文本2.2 純HTML渲染&#xff1a;v-html指令2.3 一次性插值&#xff1a;v-once指令三、指令系統&#xff1a;控制DOM的行為3.1 條件渲染&#xff1a;v-if…