數據集 Dataset
介紹
之前說過,MindSpore是基于Pipeline,通過Dataset和Transformer進行數據處理。Dataset在其中是用來加載原始數據的。mindSpore提供了數據集加載接口,可以加載文本、圖像、音頻等,同時也可以自定義加載接口。此外還提供了預加載的數據集,可直接使用。
環境配置
import numpy as np
from mindspore.dataset import vision
from mindspore.dataset import MnistDataset, GeneratorDataset
import matplotlib.pyplot as plt
加載dataset
依然使用之前的圖片及其標簽數據集Mnist
train_dataset = MnistDataset("MNIST_Data/train", shuffle=False)
數據集迭代
數據集加載后,一般使用迭代的方式獲取數據,再送入神經網絡中訓練。
訪問的數據類型默認為Tensor,可以設置為Numpy output_numpy=True
def visualize(dataset):figure = plt.figure(figsize=(4, 4))cols, rows = 3, 3plt.subplots_adjust(wspace=0.5, hspace=0.5)# 這里進行每個數據點的迭代處理for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):figure.add_subplot(rows, cols, idx + 1)plt.title(int(label))plt.axis("off")plt.imshow(image.asnumpy().squeeze(), cmap="gray")# 直到達到指定的數量再結束if idx == cols * rows - 1:breakplt.show()
常用操作
數據集操作采用了異步的執行方式(多虧了pipeline)。具體的表現是,執行操作后會先返回新的dataset,當前未執行具體的操練做,而是在pipeline中加入節點,迭代時才執行整個pipeline。
shuffle
shuffle意思是洗牌,可以改善數據分布不均的問題。
train_dataset = train_dataset.shuffle(buffer_size=64)
map
map實際上不是一個具體的操作,而是對數據集的每一個元素執行指定的數據變換(transformer)并返回這個數據集。變換可能包括簡單的數據清洗函數(如刪除空值)、更復雜的特征工程函數(如對數變換或獨熱編碼),甚至是深度學習模型進行數據增強的函數。
train_dataset = train_dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')
這里對數據進行了歸一化,即縮放到除以255之后變為0-1之間。
歸一化之前數據類型是uInt8,除以255后自然的產生了小數,變成了float32
batch
這個操作將數據集打包成了固定大小。實際上就是把數據切成了指定大小的小塊。搞成batch之后,可以每次只用加載一小部分到內存中。這解決了大規模數據集無法一次性加載到內存中的問題。
train_dataset = train_dataset.batch(batch_size=32)
經過batch操作之后的dataset會增加一個維度,標記了這個數據的batch_size。
自定義數據集
對于沒有預加載和不能使用api加載的數據集,可構造自定義數據加載類或自定義數據集生成函數的方式來生成數據集。再通過GeneratorDataset接口實現自定義方式的數據集加載。這個接口支持通過以下三種方式構造自定義數據集。
可隨機訪問數據集
實現了__getitem__和__len__方法,可以通過索引或鍵直接訪問相應的數據。
class RandomAccessDataset:
# 初始化data和label為(5,2)形狀的1和(5,1)形狀的0def __init__(self):self._data = np.ones((5, 2))self._label = np.zeros((5, 1))def __getitem__(self, index):return self._data[index], self._label[index]def __len__(self):return len(self._data)# RAD作為loader,加載進GeneratorDataset的source,并指定列名
loader = RandomAccessDataset()
dataset = GeneratorDataset(source=loader, column_names=["data", "label"])# 同時source也支持list和tuple
loader = [np.array(0), np.array(1), np.array(2)]
dataset = GeneratorDataset(source=loader, column_names=["data"])
可迭代數據集
實現了__iter__和__next__方法,可以通過迭代的方式逐步獲取數據。
class IterableDataset():def __init__(self, start, end):# 初始化開始和結束數字,用在了后面的_iter_方法中 self.start = startself.end = enddef __next__(self):'''iter one data and return'''return next(self.data)def __iter__(self):'''reset the iter'''self.data = iter(range(self.start, self.end))return self
loader = IterableDataset(1, 5)
dataset = GeneratorDataset(source=loader, column_names=["data"])
# 這個dataset的輸出就是【1,2,3,4】
生成器
可迭代,直接依賴Python的生成器類型generator返回數據,直至生成器拋出StopIteration異常。
# 經典的使用yield實現生成器
def my_generator(start, end):for i in range(start, end):yield i
dataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=["data"])
# dataset的內容是3,4,5
總結
這節學了一些dataset的加載、操作、以及自定義數據集。