使用MindSpore进行数据处理的常用方法

2025-09-26 07:06:37

1、注:本次操作使用的MNIST数据集

      这里我们把数据集处理主要分为四个步骤:

1. 定义函数create_dataset来创建数据集。

2. 定义需要进行的数据增强和处理操作,为之后进行map映射做准备。

3. 使用map映射函数,将数据操作应用到数据集。

4. 进行数据shuffle、batch操作。

2、import mindspore.dataset as ds

import mindspore.dataset.transforms.c_transforms as C

import mindspore.dataset.vision.c_transforms as CV

from mindspore.dataset.vision import Inter

from mindspore import dtype as mstype

def create_dataset(data_path, batch_size=32, repeat_size=1,

                   num_parallel_workers=1):

    # 定义数据集

    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32

    rescale = 1.0 / 255.0

    shift = 0.0

    rescale_nml = 1 / 0.3081

    shift_nml = -1 * 0.1307 / 0.3081

3、    # 定义所需要操作的map映射

    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)

    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)

    rescale_op = CV.Rescale(rescale, shift)

    hwc2chw_op = CV.HWC2CHW()

    type_cast_op = C.TypeCast(mstype.int32)

4、    # 使用map映射函数,将数据操作应用到数据集

    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

5、    # 进行shuffle、batch操作

    buffer_size = 10000

    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)

    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)

    return mnist_ds

声明:本网站引用、摘录或转载内容仅供网站访问者交流或参考,不代表本站立场,如存在版权或非法内容,请联系站长删除,联系邮箱:site.kefu@qq.com。
猜你喜欢