• MindSpore入门教程03 数据加载



    在神经网络训练和推理流程中,原始数据一般存储在磁盘或数据库中,需要首先通过数据加载步骤将其读取到内存空间,转换成张量(Tensor)格式,然后进行数据处理和增强步骤,增加样本的数量和泛化性,最后输入到网络进行计算。

    这里主要介绍MindSpore数据加载的方法。MindSpore数据加载方法根据数据的类型可以分为标准数据集用户自定义数据集,以及MindRecord格式数据集。其中,常用标准数据集包括如MNIST、CIFAR-10、CIFAR-100、VOC、COCO、ImageNet、CelebA、CLUE等;MindRecordMindSpore开发的一种高效数据格式,mindspore.mindrecord模块提供了一些方法帮助用户将不同数据集转换为MindRecord格式, 也提供了一些方法读取、写入或者检索MindRecord格式文件。

    1、标准数据集加载

    MindSpore.Dataset类提供了很多标准数据集加载的方法,如视觉图像类数据集mindspore.dataset.Cifar10Dataset, mindspore.dataset.CocoDataset;文本类数据集mindspore.dataset.CLUEDataset, mindspore.dataset.IWSLT2017Dataset; 音频类数据集mindspore.dataset.LJSpeechDataset, mindspore.dataset.TedliumDataset等等。

    接下来介绍CIFAR-10数据集的加载方法。
    关于CIFAR-10数据集:

    CIFAR-10数据集由60000张32x32彩色图片组成,总共有10个类别,每类6000张图片。有50000个训练样本和10000个测试样本。10个类别包含飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。详情可参考CIFAR-10

    以下为原始CIFAR-10数据集的结构,可以将数据集文件解压得到如下的文件结构,并通过MindSpore的API进行读取。

    .
    └── cifar-10-batches-bin
        ├── data_batch_1.bin
        ├── data_batch_2.bin
        ├── data_batch_3.bin
        ├── data_batch_4.bin
        ├── data_batch_5.bin
        ├── test_batch.bin
        ├── readme.html
        └── batches.meta.text
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    MindSpore.Dataset 提供的读取数据集的方法是

    class mindspore.dataset.Cifar10Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None)
    
    • 1

    其中,

    dataset_dir 是数据集存放路径。

    usage 指定数据集的子集,可取值为’train’,’test’或’all’,取值为’train’时,将会读取50,000个训练样本;取值为’test’时,将会读取10,000个测试样本;取值为’all’时将会读取全部60,000个样本;默认值为None,读取全部样本图片。

    num_samples 指定从数据集中读取的样本数,小于数据集总数,默认时读取全部样本图片。

    num_parallel_workers 指定读取数据的线程数。默认使用mindspore.dataset.config中配置的线程数。

    shuffle 表示是否混洗数据集。

    sampler 指定从数据集中选取样本的采样器,配置 sampler shuffle 的不同组合可以得到不同的排序结果。

    num_shards 指定分布式训练时将数据集进行划分的分片数,指定此参数后, num_samples 表示每个分片的最大样本数。

    shard_id 指定分布式训练时使用的分片ID号,只有当指定了 num_shards 时才能指定此参数。

    cache 表示单节点数据缓存服务,用于加快数据集处理,详情请参考 单节点数据缓存 。默认不使用缓存。

    使用方法如下:

    from mindspore import dataset as ds
    
    cifar10_dataset_dir = r"/path/to/cifar10_dataset_directory"
    
    # 1) Get all samples from CIFAR10 dataset in sequence
    dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, shuffle=False, num_samples=10)
    data_size = dataset.get_dataset_size()
    print("data_size:", data_size)
    data = next(dataset.create_dict_iterator())
    print("data:", data)
    
    #output
    #data_size: 10
    #data: {'image': Tensor(shape=[32, 32, 3], dtype=UInt8, value=[...]), 'label': Tensor(shape=[], #dtype=UInt32, value= 6)}
    
    # In CIFAR10 dataset, each dictionary has keys "image" and "label"
    print("image shape:", data['image'].shape)
    print("label:", data['label'])
    
    #output
    #image shape: (32, 32, 3)
    #label: 6
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    其中,get_dataset_size()获取数据集的大小,create_dict_iterator()以字典格式返回数据集的迭代器,通过print("data:", data)的打印结果可以看出,该数据集的字典的key分别是“image”,“label”,分别表示图像和分类标签。

    如果设置shuffle=True

    from mindspore import dataset as ds
    
    cifar10_dataset_dir = r"/path/to/cifar10_dataset_directory"
    
    dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=350, shuffle=True)
    data_size = dataset.get_dataset_size()
    print("data_size:", data_size)
    iter = dataset.create_dict_iterator()
    for i in range(5): # 选取部分数据
        data = next(iter)
        print("i=", i)
        print("image shape:", data['image'].shape)
        print("label:", data['label'])
        
    #output
    """
    data_size: 350
    i= 0
    image shape: (32, 32, 3)
    label: 1
    i= 1
    image shape: (32, 32, 3)
    label: 2
    i= 2
    image shape: (32, 32, 3)
    label: 8
    i= 3
    image shape: (32, 32, 3)
    label: 7
    i= 4
    image shape: (32, 32, 3)
    label: 0
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    此时,读取数据集的大小为350, 第一个数据的标签不再与前面的相等(之前等于6,现在等于1)。并且,重新执行上面的代码,得到数据标签也将发生变化。

    关于Cifar10Dataset的其他用法可以参考MindSpore官网资料:mindspore.dataset.Cifar10Dataset

    关于其他标准数据集的加载方法,可以参考MindSpore官网资料:mindspore.dataset

    2、用户自定义数据集加载

    MindSpore提供接口GeneratorDataset支持用户自定义数据集的加载。接口定义如下:

    class mindspore.dataset.GeneratorDataset(source, column_names=None, column_types=None, schema=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=6)
    
    • 1

    其中,必须设置的参数为:

    • source (Union[Callable, Iterable, Random Accessible]) - 是一个Python的可调用对象,支持可迭代或支持随机访问。
      • 如果 source 是可调用对象,要求 source 对象可以通过source().next()的方式返回一个由NumPy数组构成的元组。
      • 如果 source 是可迭代对象,要求 source 对象通过 iter(source).next() 的方式返回一个由NumPy数组构成的元组。
      • 如果 source 是支持随机访问的对象,要求 source 对象通过 source[idx] 的方式返回一个由NumPy数组构成的元组。

    然后可选参数包括:

    • column_names (Union[str, list[str]],可选) - 指定数据集生成的列名,默认值:None,不指定。用户可以通过此参数或 schema 参数指定列名。
    • column_types (list[mindspore.dtype],可选) - 指定生成数据集各个数据列的数据类型,默认值:None,不指定。 如果未指定该参数,则自动推断类型;如果指定了该参数,将在数据输出时做类型匹配检查。
    • schema (Union[Schema, str],可选) - 读取模式策略,用于指定读取数据列的数据类型、数据维度等信息。 支持传入JSON文件路径或 mindspore.dataset.Schema 构造的对象。默认值:None,不指定。 用户可以通过提供 column_namesschema指定数据集的列名,但如果同时指定两者,则将优先从schema中获取列名信息。
    • num_samples (int,可选) - 指定从数据集中读取的样本数,默认值:None,读取全部样本。
    • num_parallel_workers (int,可选) - 指定读取数据的工作进程数/线程数(由参数 python_multiprocessing 决定当前为多进程模式或多线程模式),默认值:1。
    • shuffle (bool,可选) - 是否混洗数据集。只有输入的 source 参数带有可随机访问属性(__getitem__)时,才可以指定该参数。默认值:None,下表中会展示不同配置的预期行为。
    • sampler (Union[Sampler, Iterable],可选) - 指定从数据集中选取样本的采样器。只有输入的 source 参数带有可随机访问属性(__getitem__)时,才可以指定该参数。默认值:None,下表中会展示不同配置的预期行为。
    • num_shards (int, 可选) - 指定分布式训练时将数据集进行划分的分片数,默认值:None。指定此参数后, num_samples 表示每个分片的最大样本数。
    • shard_id (int, 可选) - 指定分布式训练时使用的分片ID号,默认值:None。只有当指定了 num_shards 时才能指定此参数。
    • python_multiprocessing (bool,可选) - 启用Python多进程模式加速运算,默认值:True。当传入 source 的Python对象的计算量很大时,开启此选项可能会有较好效果。
    • max_rowsize (int,可选) - 指定在多进程之间复制数据时,共享内存分配的最大空间,默认值:6,单位为MB。仅当参数 python_multiprocessing=True时,此参数才会生效。

    接下来看使用示例。

    2.1 可调用的Python对象(单列)

    import numpy as np
    from mindspore import dataset as ds
    
    # 1) Multidimensional generator function as callable input.
    
    def generator_multidimensional():
        for i in range(64):
            yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
    
    # 支持next方法
    print(next(generator_multidimensional()))
    
    dataset = ds.GeneratorDataset(source=generator_multidimensional, column_names=["multi_dimensional_data"])
    data_size = dataset.get_dataset_size()
    print("data_size:", data_size)
    iter = dataset.create_dict_iterator()
    print(next(iter))
    
    #output
    """
    (array([[0, 1],
           [2, 3]]),)
    [WARNING] ME(91776:103592,MainProcess):2022-08-21-20:28:36.939.24 [mindspore\dataset\engine\datasets_user_defined.py:656] Python multiprocessing is not supported on Windows platform.
    data_size: 64
    {'multi_dimensional_data': Tensor(shape=[2, 2], dtype=Int32, value=
    [[0, 1],
     [2, 3]])}
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28

    函数generator_multidimensional() 支持next()方法,并返回一个numpy.array对象。

    2.2 可调用的Python对象(多列)

    import numpy as np
    from mindspore import dataset as ds
    
    # 2) Multi-column generator function as callable input.
    def generator_multi_column():
        for i in range(4):
            yield np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])
    
    dataset = ds.GeneratorDataset(source=generator_multi_column, column_names=["col1", "col2"])
    data_size = dataset.get_dataset_size()
    print("data_size:", data_size)
    iter = dataset.create_dict_iterator()
    for data in iter:
        print("col1:", data['col1'])
        print("col2:", data['col2'])
    
    #output
    """
    data_size: 4
    col1: [0]
    col2: [[0 1]
     [2 3]]
    col1: [1]
    col2: [[1 2]
     [3 4]]
    col1: [2]
    col2: [[2 3]
     [4 5]]
    col1: [3]
    col2: [[3 4]
     [5 6]]
    """
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    函数generator_multi_column返回2列数据, GeneratorDataset为每列设置了column_name, 然后可以利用column name分别对每列数据进行访问迭代。

    2.3 可迭代Python对象

    # 3) Iterable dataset as iterable input.
    class MyIterable:
        def __init__(self):
            self._index = 0
            self._data = np.random.sample((5, 2))
            self._label = np.random.sample((5, 1))
    
        def __next__(self):
            if self._index >= len(self._data):
                raise StopIteration
            else:
                item = (self._data[self._index], self._label[self._index])
                self._index += 1
                return item
    
        def __iter__(self):
            self._index = 0
            return self
    
        def __len__(self):
            return len(self._data)
    
    dataset = ds.GeneratorDataset(source=MyIterable(), column_names=["data", "label"])
    data_size = dataset.get_dataset_size()
    print("data_size:", data_size)
    iter = dataset.create_dict_iterator()
    for data in iter:
        print("data shape:", data['data'].shape)
        print("label:", data['label'])
     
    #output
    """
    data_size: 5
    data shape: (2,)
    label: [0.5573979]
    data shape: (2,)
    label: [0.85054132]
    data shape: (2,)
    label: [0.27849295]
    data shape: (2,)
    label: [0.47892104]
    data shape: (2,)
    label: [0.81040191]
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44

    如上,是一个可迭代的对象的定义方式。包含方法:__init__, __next__, __iter__以及 __len__。其中,__next__方法返回一个元组,元组的大小为2,分别表示data, label,其与column_names的长度保持一致。此外,data, label的长度也是保持一致的,如果不一致,数据迭代时会出现错误。

    2.4 支持随机访问的Python对象

    # 4) Random accessible dataset as random accessible input.
    class MyAccessible:
        def __init__(self):
            self._data = np.random.sample((5, 2))
            self._label = np.random.sample((5, 1))
    
        def __getitem__(self, index):
            return self._data[index], self._label[index]
    
        def __len__(self):
            return len(self._data)
    
    dataset = ds.GeneratorDataset(source=MyAccessible(), column_names=["data", "label"])
    data_size = dataset.get_dataset_size()
    print("data_size:", data_size)
    iter = dataset.create_dict_iterator()
    for data in iter:
        print("data shape:", data['data'].shape)
        print("label:", data['label'])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    如上,是一个支持随机访问的对象的定义方式。包含方法:__init__, __getitem__, __len__。同样地,__next__方法返回一个元组,元组的大小为2,分别表示data, label,其与column_names的长度保持一致。此外,data, label的长度也是保持一致的,如果不一致,数据迭代时会出现错误。

    需要指出的是,list, dict, tuple 均支持随机访问。

    3、MindRecord数据集加载

    MindRecordMindSpore开发的一种高效数据格式,mindspore.mindrecord模块提供了一些方法帮助用户将不同数据集转换为MindRecord格式, 也提供了一些方法读取、写入或者检索MindRecord格式文件。 用户可以使用mindspore.mindrecord.FileWriter生成MindRecord格式数据集,然后可以使用mindspore.dataset.MindDataset加载MindRecord格式数据集。

    MindSpore针对一些数据加载场景进行了性能优化,使用MindRecord数据格式可以减少磁盘IO、网络IO开销,从而获得更好的使用体验。

    MindRecord数据格式具备的特征有:

    1. 实现数据统一存储、统一访问,使得训练时数据读取更加简便。
    2. 实现数据聚合存储、高效读取,使得训练时数据方便管理和移动。
    3. 实现高效的数据编解码操作,使得用户对数据操作无感知。
    4. 实现灵活控制数据切分的分区大小,实现分布式数据处理。

    下面主要介绍如何将CV类数据集转换为MindRecord文件格式,并通过MindDataset接口,实现MindRecord数据集加载。

    3.1 将CV类数据集转换成为MindRecord格式数据集

    使用的接口如下:

    class mindspore.mindrecord.FileWriter(file_name, shard_num=1, overwrite=False)
    
    • 1

    其中,

    • file_name (str) - 转换生成的MindRecord文件路径。
    • shard_num (int,可选) - 生成MindRecord的文件个数。取值范围为[1, 1000]。默认值:1。
    • overwrite (bool,可选) - 当指定目录存在同名文件时是否覆盖写。默认值:False。

    如下示例,生成100张图像,并转换成MindRecord文件格式。其中样本包含file_name(字符串)、label(整型)、 data(二进制)三个字段。

    import os
    from PIL import Image
    from io import BytesIO
    
    import mindspore.mindrecord as record
    
    
    # 输出的MindSpore Record文件完整路径
    MINDRECORD_FILE = "dataset/test.mindrecord"
    
    if os.path.exists(MINDRECORD_FILE):
        os.remove(MINDRECORD_FILE)
    if os.path.exists(MINDRECORD_FILE + ".db"):
        os.remove(MINDRECORD_FILE + ".db")
    
    # 定义包含的字段
    cv_schema = {"file_name": {"type": "string"},
                 "label": {"type": "int32"},
                 "data": {"type": "bytes"}}
    
    # 声明MindSpore Record文件格式
    writer = record.FileWriter(file_name=MINDRECORD_FILE, shard_num=1)
    writer.add_schema(cv_schema, "it is a cv dataset")
    writer.add_index(["file_name", "label"])
    
    # 创建数据集
    data = []
    for i in range(100):
        i += 1
        sample = {}
        white_io = BytesIO()
        Image.new('RGB', (i*10, i*10), (255, 255, 255)).save(white_io, 'JPEG')
        image_bytes = white_io.getvalue()
        sample['file_name'] = str(i) + ".jpg"
        sample['label'] = i
        sample['data'] = white_io.getvalue()
    
        data.append(sample)
        if i % 10 == 0:
            writer.write_raw_data(data)
            if len(data) > 0 and i == 10:
                sample = data[0]
                print("data type", type(sample['data']))
                print("img name:", sample['file_name'])
                print("label:",sample['label'])
            data = []
    
    if data:
        writer.write_raw_data(data)
    
    writer.commit()
    
    #output
    """
    data type 
    img name: 1.jpg
    label: 1
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58

    保存后的文件结构如下:

    └── dataset
        ├── test.mindrecord
        ├── test.mindrecord.db
    
    • 1
    • 2
    • 3

    关于MindRecord文件格式的更多内容以及其他格式的数据集转换可以参考格式转换

    3.2 MindDataset接口读取MindRecord文件格式

    MindSpore提供接口mindspore.dataset.MindDataset用于读取和解析MindRecord数据文件构建数据集。生成的数据集的列名和列类型取决于MindRecord文件中的保存的列名与类型。

    mindspore.dataset.MindDataset接口定义如下:

    class mindspore.dataset.MindDataset(dataset_files, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None, shard_id=None, sampler=None, padded_sample=None, num_padded=None, num_samples=None, cache=None)
    
    • 1

    其中,必须设置的参数是需要读取的文件路径:

    • dataset_files (Union[str, list[str]]) - MindRecord文件路径,支持单文件路径字符串、多文件路径字符串列表。如果 dataset_files 的类型是字符串,则它代表一组具有相同前缀名的MindRecord文件,同一路径下具有相同前缀名的其他MindRecord文件将会被自动寻找并加载。

    其他是可选参数,根据实际使用情况选择。如下:

    • columns_list (list[str],可选) - 指定从MindRecord文件中读取的数据列。默认值:None,读取所有列。
    • num_parallel_workers (int, 可选) - 指定读取数据的工作线程数。默认值:None,使用mindspore.dataset.config中配置的线程数。
    • shuffle (Union[bool, Shuffle], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定,默认值:mindspore.dataset.Shuffle.GLOBAL。 如果 shuffle =False,则不混洗,如果 shuffle =True,等同于将 shuffle 设置为mindspore.dataset.Shuffle.GLOBAL。 通过传入枚举变量设置数据混洗的模式:
      • Shuffle.GLOBAL:混洗文件和文件中的数据。
      • Shuffle.FILES:仅混洗文件。
      • Shuffle.INFILE:保持读入文件的序列,仅混洗每个文件中的数据。
    • num_shards (int, 可选) - 指定分布式训练时将数据集进行划分的分片数,默认值:None。指定此参数后, num_samples 表示每个分片的最大样本数。
    • shard_id (int, 可选) - 指定分布式训练时使用的分片ID号,默认值:None。只有当指定了 num_shards 时才能指定此参数。
    • sampler (Sampler, 可选) - 指定从数据集中选取样本的采样器,默认值:None,下表中会展示不同配置的预期行为。当前此数据集仅支持以下采样器:SubsetRandomSampler、PkSampler、RandomSampler、SequentialSamplerDistributedSampler
    • padded_sample (dict, 可选): 指定额外添加到数据集的样本,可用于在分布式训练时补齐分片数据,注意字典的键名需要与 column_list 指定的列名相同。默认值:None,不添加样本。需要与 num_padded 参数同时使用。
    • num_padded (int, 可选) - 指定额外添加的数据集样本的数量。在分布式训练时可用于为数据集补齐样本,使得总样本数量可被num_shards整除。默认值:None,不添加样本。需要与 padded_sample 参数同时使用。
    • num_samples (int, 可选) - 指定从数据集中读取的样本数。默认值:None,读取所有样本。
    • cache (DatasetCache, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 单节点数据缓存 。默认值:None,不使用缓存。

    读取文件的示例如下:

    import mindspore.dataset as ds
    import mindspore.dataset.vision as vision
    
    # 读取MindSpore Record文件格式
    MINDRECORD_FILE = "dataset/test.mindrecord"
    dataset = ds.MindDataset(dataset_files=MINDRECORD_FILE)
    
    decode_op = vision.Decode()
    dataset = dataset.map(operations=decode_op, input_columns=["data"], num_parallel_workers=2)
    
    # 样本计数
    print("Got {} samples".format(data_set.get_dataset_size()))
    iter = dataset.create_dict_iterator()
    for data in iter:
        img = data["data"]
        print("image type:", type(img))
        break
    
    #output
    """
    Got 100 samples
    image type: 
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    其中,mindspore.dataset.vision.Decode() 是将输入的压缩图像解码为RGB格式。具体说明参考mindspore.dataset.vision.Decodedataset.map操作将指定的函数作用于数据集的指定列数据,实现数据映射操作。其具体使用方法将在数据处理部分进行介绍。这里,是将数据集中的图像进行解码操作。

    此外,还可以通过mindspore.mindrecord.FileReader接口读取MindRecord文件。

    接口定义如下:

    class mindspore.mindrecord.FileReader(file_name, num_consumer=4, columns=None, operator=None)
    
    • 1

    其中,

    • file_name (str, list[str]) - MindRecord格式的数据集文件路径或文件路径组成的列表。
    • num_consumer (int,可选) - 加载数据的并发数。默认值:4。不应小于1或大于处理器的核数。
    • columns (list[str],可选) - MindRecord中待读取数据列的列表。默认值:None,读取所有的数据列。
    • operator (int,可选) - 保留参数。默认值:None。

    3.3 标准数据集转换成MindRecord格式数据集

    我们可以通过Cifar10ToMR类,将CIFAR-10原始数据转换为MindRecord,并使用MindDataset接口读取。

    此时,使用的是CIFAR-10的python格式的数据集。数据集的结构与之前不同:

    ./datasets/cifar-10-batches-py
    ├── batches.meta
    ├── data_batch_1
    ├── data_batch_2
    ├── data_batch_3
    ├── data_batch_4
    ├── data_batch_5
    ├── readme.html
    └── test_batch
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    数据集转换的方法:创建Cifar10ToMR对象,调用transform接口,将CIFAR-10数据集转换为MindRecord文件格式

    import os
    from mindspore.mindrecord import Cifar10ToMR
    
    ds_target_path = "./dataset/cifar10/" # 提前创建目录
    
    # CIFAR-10数据集路径
    CIFAR10_DIR = r"path/to/cifar-10-batches-py"
    
    # 输出的MindSpore Record文件路径
    MINDRECORD_FILE = os.path.join(ds_target_path,"cifar10.mindrecord")
    
    cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
    cifar10_transformer.transform(['label'])
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    转换后保存的文件如下所示:

    ./datasets/cifar10
    ├── cifar10.mindrecord
    ├── cifar10.mindrecord.db
    ├── cifar10.mindrecord_test
    ├── cifar10.mindrecord_test.db
    
    • 1
    • 2
    • 3
    • 4
    • 5

    关于Cifar10ToMR的更多内容可以参考mindspore.mindrecord.Cifar10ToMR

    此外,MindSpore还提供了其他标准数据集转换为MindRecord的方法,更多内容参考Cifar100ToMR, ImageNetToMR, 以及TFRecordToMR

  • 相关阅读:
    SQL ORDER BY Keyword(按关键字排序)
    第六章 dubbo接口测试
    【nosql】redis之高可用(主从复制、哨兵、集群)搭建
    c++primer 2.1.1 算数类型
    现代 CSS 解决方案:数学函数 Round
    【安全函数】常用的安全函数的使用
    python实现读取,修改excel数据
    Java EE——线程(2)
    C++ 基础二
    微服务实践k8s&dapr开发部署实验(1)服务调用
  • 原文地址:https://blog.csdn.net/liujiabin076/article/details/126460157