一、DataLoader 是什么?
torch.utils.data.DataLoader
是 PyTorch 中用于加載數據的核心接口,它支持:
- 批量讀取(batch)
- 數據打亂(shuffle)
- 多線程并行加載(num_workers)
- 自動將數據打包成 batch
- 數據預處理和增強(搭配 Dataset 使用)
二、常見參數詳解
參數 | 含義 |
---|---|
dataset | 傳入的 Dataset 對象(如自定義或 torchvision.datasets ) |
batch_size | 每個 batch 的樣本數量 |
shuffle | 是否打亂數據(通常訓練集為 True) |
num_workers | 并行加載數據的線程數(越大越快,但依機器決定) |
drop_last | 是否丟棄最后一個不足 batch_size 的 batch |
pin_memory | 若為 True,會將數據復制到 CUDA 的 page-locked 內存中(加速 GPU 訓練) |
collate_fn | 自定義打包 batch 的函數(可用于變長序列、圖神經網絡等) |
sampler | 控制數據采樣策略,不能與 shuffle 同時使用 |
persistent_workers | 若為 True,worker 在 epoch 間保持運行狀態(提高效率,PyTorch 1.7+) |
三、基本使用示例
搭配 Dataset 使用
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self):self.data = [i for i in range(100)]def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]dataset = MyDataset()
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)for batch in loader:print(batch)
四、自定義 collate_fn 示例
適用于:變長數據(如文本、點云)或特殊處理需求
from torch.nn.utils.rnn import pad_sequencedef my_collate_fn(batch):# 假設每個樣本是 list 或 tensor(變長)batch = [torch.tensor(item) for item in batch]padded = pad_sequence(batch, batch_first=True, padding_value=0)return paddedloader = DataLoader(dataset, batch_size=4, collate_fn=my_collate_fn)
五、使用注意事項
-
Windows 平臺注意:
-
設置
num_workers > 0
時,必須使用:if __name__ == '__main__':DataLoader(...)
-
-
過多線程數可能導致瓶頸:
- 通常
num_workers = cpu_count() // 2
較穩定
- 通常
-
GPU 加速:
- 訓練時推薦設置
pin_memory=True
可提高 GPU 訓練數據傳輸效率。
- 訓練時推薦設置
-
不要同時設置
shuffle=True
和sampler
:- 否則會報錯,二者功能沖突。
六、訓練中的典型使用方式
for epoch in range(num_epochs):for i, batch in enumerate(train_loader):inputs, labels = batchinputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()
七、調試技巧與加速建議
場景 | 建議 |
---|---|
數據加載慢 | 增加 num_workers |
GPU 等數據 | 設置 pin_memory=True |
Dataset 中有耗時操作 | 考慮預處理或使用緩存 |
debug 模式 | 設置 num_workers=0 ,禁用多進程 |
八、與 TensorDataset、ImageFolder 配合
from torchvision.datasets import ImageFolder
from torchvision import transformstransform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),
])dataset = ImageFolder(root='your/image/folder', transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
九、點云數據處理場景應用實例
在 點云數據處理 場景中,使用 torch.utils.data.DataLoader
時,常遇到如下需求:
- 每幀點云大小不同(變長 Tensor)
- 點云數據 + 標簽(如語義、實例)
- 使用
.bin
、.pcd
或.npy
等格式加載 - 數據增強(如旋轉、裁剪、噪聲)
- GPU 加速 + 批量訓練
1. 點云數據 Dataset 示例(以 .npy
文件為例)
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoaderclass PointCloudDataset(Dataset):def __init__(self, root_dir, transform=None):self.root_dir = root_dirself.files = sorted([f for f in os.listdir(root_dir) if f.endswith('.npy')])self.transform = transformdef __len__(self):return len(self.files)def __getitem__(self, idx):point_cloud = np.load(os.path.join(self.root_dir, self.files[idx])) # shape: [N, 3] or [N, 6]point_cloud = torch.tensor(point_cloud, dtype=torch.float32)if self.transform:point_cloud = self.transform(point_cloud)return point_cloud
2. 自定義 collate_fn
(處理變長點云)
def collate_pointcloud_fn(batch):"""輸入: List of [N_i x 3] tensors輸出: - 合并后的 [B x N_max x 3] tensor- 每個樣本的真實點數 list"""max_points = max(pc.shape[0] for pc in batch)padded = torch.zeros((len(batch), max_points, batch[0].shape[1]))lengths = []for i, pc in enumerate(batch):lengths.append(pc.shape[0])padded[i, :pc.shape[0], :] = pcreturn padded, torch.tensor(lengths)
3. 加載器構建示例
dataset = PointCloudDataset("/path/to/your/pointclouds")loader = DataLoader(dataset,batch_size=8,shuffle=True,num_workers=4,pin_memory=True,collate_fn=collate_pointcloud_fn
)for batch_points, batch_lengths in loader:# batch_points: [B, N_max, 3]# batch_lengths: [B]print(batch_points.shape)
4. 可選擴展功能
功能 | 實現方法 |
---|---|
點云旋轉/縮放 | 自定義 transform (例如隨機旋轉矩陣乘點云) |
加載 .pcd | 使用 open3d , pypcd , 或 pclpy |
同時加載標簽 | 在 Dataset 中返回 (point_cloud, label) ,修改 collate_fn |
voxel downsampling | 使用 open3d.geometry.VoxelDownSample |
GPU 加速 | point_cloud = point_cloud.cuda(non_blocking=True) |
5. 訓練循環中使用
for epoch in range(num_epochs):for batch_pc, batch_len in loader:batch_pc = batch_pc.to(device)# 可用 batch_len 做 mask 或 attention maskout = model(batch_pc)...