昇思25天學習打卡營第04天 | 數據集 Dataset
文章目錄
- 昇思25天學習打卡營第04天 | 數據集 Dataset
- 數據集加載
- 數據集迭代
- 數據集的變換
- shuffle
- map
- batch
- 自定義數據集
- 可隨機訪問數據集對象
- 可迭代數據集
- 生成器
- 總結
- 打卡
數據集Dataset對原始數據進行封裝、變換,為神經網絡提供高質量的輸入數據。
mindspore.dataset
內置了的文本、圖像、音頻等數據集的加載接口,也提供創建自定義數據集的方法。
數據集加載
mindspore.dataset
僅支持解壓后的數據文件,對于壓縮包文件,需要先解壓才能創建數據集:
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True) # 下載并解壓到"./"目錄下train_dataset = MnistDataset("MNIST_Data/train", shuffle=False) # 通過數據集的文件目錄直接創建dataset
數據集迭代
數據集可以通過create_tuple_iterator
或create_dict_iterator
接口創建迭代器進行訪問。
訪問的類型默認為Tensor
,如果設置output_numpy=True
,則訪問Numpy
。
def visualize(dataset):figure = plt.figure(figsize=(4, 4))cols, rows = 3, 3plt.subplots_adjust(wspace=0.5, hspace=0.5)for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):figure.add_subplot(rows, cols, idx + 1)plt.title(int(label))plt.axis("off")plt.imshow(image.asnumpy().squeeze(), cmap="gray")if idx == cols * rows - 1:breakplt.show()
數據集的變換
shuffle
shuffle
可以對數據集進行打亂,消除數據排列不均勻的問題。
mindspore.dataset
可以在加載數據集的時候配置shuffle=True
,或者通過mindspore.dataset.shuffle(buffer_size)
進行打亂。
map
map
是數據預處理的關鍵,可以針對數據集的指定列添加變換,將指定變換應用于該列數據的每個元素,并返回變換后的新元素。
train_dataset = train_dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')
batch
將數據集中的數據打包為batch有利于在資源受限的情況下使用梯度下降,可以保證梯度下降的隨機性和優化計算量。
train_dataset = train_dataset.batch(batch_size = 32)
通過batch
后數據增加一維,大小為batch_size
。
自定義數據集
通過構造自定義數據加載類和自定義數據集生成函數的方法來生成數據集,并通過GeneratorDataset
接口實現數據集的加載
GeneratorDataset
支持可隨機訪問數據集對象,可迭代數據集對象和生成器構造數據集。
可隨機訪問數據集對象
是指實現了__getitem__
和__len__
方法的數據集,可以通過索引/鍵直接訪問對應位置的數據。
class RandomAccessDataset:def __init__(self):self._data = np.ones((5, 2))self._label = np.zeros((5, 1))def __getitem__(self, index):return self._data[index], self._label[index]def __len__(self):return len(self._data)loader = RandomAccessDataset()
dataset = GeneratorDataset(source=loader, column_names=["data", "label"])
可迭代數據集
是指實現了__iter__
和__next__
方法的數據集,可以通過迭代的方式獲取數據樣本,適用于隨機訪問成本太高或不可行的情況。
class IterableDataset():def __init__(self, start, end):'''init the class object to hold the data'''self.start = startself.end = enddef __next__(self):'''iter one data and return'''return next(self.data)def __iter__(self):'''reset the iter'''self.data = iter(range(self.start, self.end))return selfloader = IterableDataset(1, 5)
dataset = GeneratorDataset(source=loader, column_names=["data"])
生成器
生成器屬于可迭代的數據集類型,依賴于python的generator
返回數據,直到生成器拋出StopIteration
異常。
def my_generator(start, end):for i in range(start, end):yield idataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=["data"])
總結
通過這一節的學習,對MindSpore中Dataset的使用方法有了一定的了解,掌握了通過download
下載/解壓數據并創建數據集,通過迭代器方位數據集中的元素,通過shuffle
打亂數據,通過map
變換數據以及將數據打包為固定大小的patch
。此外,還了解了自定義數據集的創建以及自定義數據集需要實現的方法,為之后的訓練數據管理和處理打下了基礎。