数据集 Dataset
介绍
之前说过,MindSpore是基于Pipeline,通过Dataset和Transformer进行数据处理。Dataset在其中是用来加载原始数据的。mindSpore提供了数据集加载接口,可以加载文本、图像、音频等,同时也可以自定义加载接口。此外还提供了预加载的数据集,可直接使用。
环境配置
import numpy as np
from mindspore.dataset import vision
from mindspore.dataset import MnistDataset, GeneratorDataset
import matplotlib.pyplot as plt
加载dataset
依然使用之前的图片及其标签数据集Mnist
train_dataset = MnistDataset("MNIST_Data/train", shuffle=False)
数据集迭代
数据集加载后,一般使用迭代的方式获取数据,再送入神经网络中训练。
访问的数据类型默认为Tensor,可以设置为Numpy output_numpy=True
def visualize(dataset):figure = plt.figure(figsize=(4, 4))cols, rows = 3, 3plt.subplots_adjust(wspace=0.5, hspace=0.5)# 这里进行每个数据点的迭代处理for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):figure.add_subplot(rows, cols, idx + 1)plt.title(int(label))plt.axis("off")plt.imshow(image.asnumpy().squeeze(), cmap="gray")# 直到达到指定的数量再结束if idx == cols * rows - 1:breakplt.show()
常用操作
数据集操作采用了异步的执行方式(多亏了pipeline)。具体的表现是,执行操作后会先返回新的dataset,当前未执行具体的操练做,而是在pipeline中加入节点,迭代时才执行整个pipeline。
shuffle
shuffle意思是洗牌,可以改善数据分布不均的问题。
train_dataset = train_dataset.shuffle(buffer_size=64)
map
map实际上不是一个具体的操作,而是对数据集的每一个元素执行指定的数据变换(transformer)并返回这个数据集。变换可能包括简单的数据清洗函数(如删除空值)、更复杂的特征工程函数(如对数变换或独热编码),甚至是深度学习模型进行数据增强的函数。
train_dataset = train_dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')
这里对数据进行了归一化,即缩放到除以255之后变为0-1之间。
归一化之前数据类型是uInt8,除以255后自然的产生了小数,变成了float32
batch
这个操作将数据集打包成了固定大小。实际上就是把数据切成了指定大小的小块。搞成batch之后,可以每次只用加载一小部分到内存中。这解决了大规模数据集无法一次性加载到内存中的问题。
train_dataset = train_dataset.batch(batch_size=32)
经过batch操作之后的dataset会增加一个维度,标记了这个数据的batch_size。
自定义数据集
对于没有预加载和不能使用api加载的数据集,可构造自定义数据加载类或自定义数据集生成函数的方式来生成数据集。再通过GeneratorDataset接口实现自定义方式的数据集加载。这个接口支持通过以下三种方式构造自定义数据集。
可随机访问数据集
实现了__getitem__和__len__方法,可以通过索引或键直接访问相应的数据。
class RandomAccessDataset:
# 初始化data和label为(5,2)形状的1和(5,1)形状的0def __init__(self):self._data = np.ones((5, 2))self._label = np.zeros((5, 1))def __getitem__(self, index):return self._data[index], self._label[index]def __len__(self):return len(self._data)# RAD作为loader,加载进GeneratorDataset的source,并指定列名
loader = RandomAccessDataset()
dataset = GeneratorDataset(source=loader, column_names=["data", "label"])# 同时source也支持list和tuple
loader = [np.array(0), np.array(1), np.array(2)]
dataset = GeneratorDataset(source=loader, column_names=["data"])
可迭代数据集
实现了__iter__和__next__方法,可以通过迭代的方式逐步获取数据。
class IterableDataset():def __init__(self, start, end):# 初始化开始和结束数字,用在了后面的_iter_方法中 self.start = startself.end = enddef __next__(self):'''iter one data and return'''return next(self.data)def __iter__(self):'''reset the iter'''self.data = iter(range(self.start, self.end))return self
loader = IterableDataset(1, 5)
dataset = GeneratorDataset(source=loader, column_names=["data"])
# 这个dataset的输出就是【1,2,3,4】
生成器
可迭代,直接依赖Python的生成器类型generator返回数据,直至生成器抛出StopIteration异常。
# 经典的使用yield实现生成器
def my_generator(start, end):for i in range(start, end):yield i
dataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=["data"])
# dataset的内容是3,4,5
总结
这节学了一些dataset的加载、操作、以及自定义数据集。