• mindspore-softmax进行鸢尾花多分类模型


    版本:mindspore1.3.0

    代码:

    import os
    # os.environ['DEVICE_ID'] = '6'
    import csv
    import numpy as np

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import dataset
    from mindspore.train.callback import LossMonitor
    from mindspore.common.api import ms_function
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    with open('iris.data') as csv_file:
        data = list(csv.reader(csv_file, delimiter=','))

    label_map = {
        'Iris-setosa': 0,
        'Iris-versicolor': 1,
        'Iris-virginica':2,
    }

    X = np.array([[float(x) for x in s[:-1]] for s in data[:150]], np.float32)
     

    Y = np.array([[label_map[s[-1]]] for s in data[:150]], np.float32)
     

    train_idx = np.random.choice(150, 120, replace=False)
    test_idx = np.array(list(set(range(150)) - set(train_idx)))
    X_train, Y_train = X[train_idx], Y[train_idx]
    X_test, Y_test = X[test_idx], Y[test_idx]
    XY_train = list(zip(X_train, Y_train))
    ds_train = dataset.GeneratorDataset(XY_train, ['x', 'y'])

    ds_train = ds_train.shuffle(buffer_size=80).batch(32, drop_remainder=True)
    XY_test = list(zip(X_test, Y_test))
    ds_test = dataset.GeneratorDataset(XY_test, ['x', 'y'])
    ds_test = ds_test.batch(30)

    net = nn.Dense(4, 3)
    loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opt = nn.optim.Momentum(net.trainable_params(), learning_rate=0.05, momentum=0.9)

    model = ms.train.Model(net, loss, opt, metrics={'acc', 'loss'})
    model.train(15, ds_train, callbacks=[LossMonitor(per_print_times=ds_train.get_dataset_size())], dataset_sink_mode=False)
    metrics = model.eval(ds_test)
    print(metrics)

    1. 按照报错提示,是因为你的dataset对象给多个model使用了。

    2. 但是我们拿了你上面的脚本,先从脚本上看,没有发现ds_train / ds_eval给多个model使用的情况,另:在本地运行后,也没有报你上面的报错。

    故:你可以再试下你上条评论里的脚本,或者还有没有其他信息提供?我们再分析下。

  • 相关阅读:
    C语言之switch语句详解
    数据结构基本概念-Java常用算法
    Java学习笔记 --- 内部类
    Vue2.0 —— Vue.set(vm.$set) 源码探秘
    【算法设计与分析】— —实现最优载的贪心算法
    使用 webpack-cli 零配置打包,真香
    这五个步骤,网络有故障,自己都能解决
    【历史上的今天】8 月 14 日:新浪微博开始内测;阿塔纳索夫完成论文;登上太空的计算机病毒
    HighTec 工程配置详解
    大数据之路阿里巴巴实践
  • 原文地址:https://blog.csdn.net/weixin_45666880/article/details/126409739