使用MindSpore训练及保存模型的基本方法

2025-10-22 08:14:00

1、from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# 设置模型保存参数

config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)

# 应用模型保存参数

ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)

2、通过MindSpore提供的model.train接口可以方便地进行网络的训练,LossMonitor可以监控训练过程中loss值的变化。

# 导入模型训练需要的库

from mindspore.nn import Accuracy

from mindspore.train.callback import LossMonitor

from mindspore import Model

3、def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):

    """定义训练的方法"""

    # 加载训练数据集

    ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)

    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)

4、其中,dataset_sink_mode用于控制数据是否下沉,数据下沉是指数据通过通道直接传送到Device上,可以加快训练速度,dataset_sink_mode为True表示数据下沉,否则为非下沉。

通过模型运行测试数据集得到的结果,验证模型的泛化能力。

使用model.eval接口读入测试数据集。

使用保存后的模型参数进行推理。

5、def test_net(network, model, data_path):

    """定义验证的方法"""

    ds_eval = create_dataset(os.path.join(data_path, "test"))

    acc = model.eval(ds_eval, dataset_sink_mode=False)

    print("{}".format(acc))

6、train_epoch = 1

mnist_path = "./datasets/MNIST_Data"

dataset_size = 1

model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)

test_net(net, model, mnist_path)

声明:本网站引用、摘录或转载内容仅供网站访问者交流或参考,不代表本站立场,如存在版权或非法内容,请联系站长删除,联系邮箱:site.kefu@qq.com。
猜你喜欢