深度學習中Dataset類通用的架構思路
Dataset 類設計的必備部分
1. 初始化 __init__
- 配置和路徑管理:保存
config
,區分train/val/test
路徑。 - 加載原始數據:CSV、JSON、Numpy、Parquet 等。
- 預處理器/歸一化器:如
StandardScaler
,或者 Tokenizer(在 NLP 任務里)。 - 準備輔助信息:比如 meta 特征、文本 embedding。
- 構造樣本列表(self.samples):保證后面取樣時直接
O(1)
訪問。
2. 數據預處理
- normalize / inverse_transform:數值數據標準化和反變換。
- tokenize / pad:文本分詞、對齊。
- feature engineering:特征拼接、缺失值處理。
3. 核心接口
__len__
: 返回數據集樣本數。__getitem__
: 返回一個樣本(通常是(features, label)
的 tuple 或 dict)。
4. 可選接口
get_scaler()
: 返回歸一化器。get_vocab()
: NLP 任務里返回詞表。collate_fn
: 定義 batch 內如何拼接(特別是變長序列)。save_cache
/load_cache
: 大數據集可以存緩存,避免每次都重新處理。
5. 繼承關系
-
BaseDataset:負責
- 通用邏輯(加載文件、歸一化、拼裝 sample)。
- 提供鉤子函數,比如
load_paths(flag)
、process_sample(sample)
。
-
子類:只需要實現 路徑差異 或 樣本加工方式差異。
通用代碼結構示意
class BaseDataset(Dataset):def __init__(self, config, flag="train", scaler=None):self.config = configself.flag = flagself.scaler = scaler or StandardScaler()self.samples = []self._load_data()self._build_samples()def _load_data(self):"""子類可重寫,加載原始數據"""raise NotImplementedErrordef _build_samples(self):"""子類可重寫,拼裝每個樣本的x, y, feats"""raise NotImplementedErrordef __len__(self):return len(self.samples)def __getitem__(self, idx):return self.samples[idx]def get_scaler(self):return self.scalerdef inverse_transform(self, x):return x * self.std + self.mean
子類只管:
class ElectricityDataset(BaseDataset):def _load_data(self):# 只寫路徑和文件加載邏輯passdef _build_samples(self):# 根據任務需要定義樣本結構pass
調用示例
data_config = {"root": "data/electricity/","train_file": "train.json","train_meta_file": "train_meta.npy","train_news_file": "train_news.npy"
}train_config = {"batch_size": 64,"learning_rate": 1e-3,"epochs": 20
}train_ds = ElectricityDataset(data_config, flag="train")train_loader = DataLoader(train_ds,batch_size=train_config["batch_size"],shuffle=True,collate_fn=custom_collate_fn
))