PyTorch 的 DataLoader 是數據加載的核心組件,它能高效地批量加載數據并進行預處理。
Pytorch DataLoader基礎概念
DataLoader基礎概念
DataLoader是PyTorch基礎概念
DataLoader是PyTorch中用于加載數據的工具,它可以:批量加載數據(batch loading)打亂數據(shuffling)并行加載數據(多線程)
自定義數據加載方式Dataloader的基本使用from torch.utils.data import Dataset, DataLoader
自定義數據集類
class MyDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __getitem__(self, index):return self.data[index], self.labels[index]def __len__(self):return len(self.data)
創建數據集實例
dataset = MyDataset(data, labels)
創建DataLoader
dataloader = DataLoader(dataset=dataset, # 數據集batch_size=32, # 批次大小shuffle=True, # 是否打亂數據num_workers=4, # 多進程加載數據的線程數drop_last=False # 當樣本數不能被batch_size整除時,是否丟棄最后一個不完整的batch
)
# 使用DataLoader迭代數據
for batch_data, batch_labels in dataloader:# 訓練或推理代碼pass
DataLoader重要參數詳解
- dataset: 要加載的數據集,必須是Dataset類的實例 batch_size: 每個批次的樣本數
- shuffle:是否在每個epoch重新打亂數據
- sampler:自定義從數據集中抽取樣本的策略,如果指定了sampler,則shuffle必須為False
- num_workers:使用多少個子進程加載數據,0表示在主進程中加載。
- collate_fn:將一批數據整合成一個批次的函數,特別使用于處理不同長度的序列數據
- Pin_memory:如果為True,數據加載器會將張量復制到CUDA固定內存中,加速CPU到GPU的數據傳輸
- drop_last: 如果數據集大小不能被batch_size整除,是否丟棄最后一個不完整的批次。
- timeout:收集一個批次的超時值
- worker_init_fn:每個worker初始化時被調用的函數
- weight_sampler:參數決定是都使用加權采樣器來平衡類別分布
if infinite_data_loader:data_loader = InfiniteDataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)
else:data_loader = DataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)n_class = len(data.classes)
return data_loader, n_class
這段代碼決定了如何創建數據加載器,根據infinite_data_loader參數選擇不同的加載器類型:
if infinite_data_loader:data_loader = InfiniteDataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)
else:data_loader = DataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)n_class = len(data.classes)
return data_loader, n_class
代碼解析
這段代碼基于infinite_data_loader參數創建不同類型的數據加載器:
當infinite_data_loader為True時:
創建InfiniteDataLoader實例
自定義的無限循環數據加載器,會持續提供數據而不會在一個epoch結束時停止
當infinite_data_loader為False時:
創建標準的PyTorch DataLoader實例
這是普通的數據加載器,一個epoch結束后會停止
共同參數:
dataset=data:要加載的數據集
batch_size=batch_size:每批數據的大小
shuffle=shuffle:是否打亂數據(之前代碼中已設置)
num_workers=num_workers:用于并行加載數據的線程數
sampler=sampler:用于采樣的策略(之前代碼中已設置,可能是加權采樣器)
**kwargs:其他可能的參數,如pin_memory、drop_last等
返回值:
data_loader:創建好的數據加載器
n_class = len(data.classes):數據集中的類別數量
InfiniteDataLoader的作用
在您的代碼中定義了兩種InfiniteDataLoader實現:一種作為DataLoader的子類,另一種是完全自定義的類。它們的共同目的是:
持續提供數據:當一個epoch結束后,自動重新開始,不會引發StopIteration異常
支持長時間訓練:在需要長時間訓練的場景中特別有用,如半監督學習或者領域適應
避免手動重置:不需要在每個epoch結束后手動重置數據加載器
使用場景
無限數據加載器特別適用于:
持續訓練:模型需要無限期地訓練,如自監督學習或強化學習
不均勻更新:源域和目標域數據需要不同頻率的更新
流式訓練:數據以流的形式到達,不需要明確的epoch邊界
基于迭代而非epoch的訓練:訓練基于迭代次數而非數據epoch
最后的返回值n_class提供了數據集的類別數量,這對模型構建和評估都很重要,比如設置分類層的輸出維度或計算平均類別準確率。
高級用法
1.自定義collate_fn處理變長序列
def collate_fn(batch):# 排序批次數據,按序列長度降序batch.sort(key=lambda x: len(x[0]), reverse=True)# 分離數據和標簽sequences, labels = zip(*batch)# 計算每個序列的長度lengths = [len(seq) for seq in sequences]# 填充序列到相同長度padded_seqs = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)return padded_seqs, torch.tensor(labels), lengths
使用自定義的collate_fn
dataloader = DataLoader(dataset=text_dataset,batch_size=16,shuffle=True,collate_fn=collate_fn
)
2.使用Sampler進行不均衡數據采樣
from torch.utils.data import WeightedRandomSampler
假設我們有類別不平衡問題,計算采樣權重
class_count = [100, 1000, 500] # 每個類別的樣本數量
weights = 1.0 / torch.tensor(class_count, dtype=torch.float)
sample_weights = weights[target_list] # target_list是每個樣本的類別索引
創建WeightedRandomSampler
sampler = WeightedRandomSampler(weights=sample_weights,num_samples=len(sample_weights),replacement=True
)
使用sampler
dataloader = DataLoader(dataset=dataset,batch_size=32,sampler=sampler, # 使用sampler時,shuffle必須為Falsenum_workers=4
)