• 【MindSpore易点通】How to Convert PyTorch Source Code to the MindSpore Low-Level API and Run Single-Machine, Single-Card Training on Ascend


    Source: Huawei Cloud Forum

    1 Overview
    This article describes how to convert PyTorch source code into MindSpore low-level API code and run single-machine, single-card training on Ascend chips.
    The figure below compares the training workflows of the MindSpore high-level API, the MindSpore low-level API, and PyTorch.


    As with the MindSpore high-level API, training with the low-level API still requires configuring the runtime, reading and preprocessing data, defining the network, and defining the loss function and optimizer. These steps are the same as with the high-level API; a minimal sketch of them is shown below.
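    The original post does not repeat these preliminary steps. A minimal sketch, assuming graph mode on a single Ascend device, CIFAR-10-sized inputs, and a placeholder network in place of the article's ResNet, might look like this (TRAIN_BATCH_SIZE and CLASS_NUM are assumed hyperparameters used by the later code):

import mindspore.nn as nn
from mindspore import context

# Configure running information: graph mode on a single Ascend card.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0)

# Placeholder hyperparameters (assumed values).
TRAIN_BATCH_SIZE = 128
CLASS_NUM = 10

# Network prototype: a trivial placeholder; the original article trains a ResNet.
net = nn.SequentialCell([nn.Flatten(), nn.Dense(3 * 32 * 32, CLASS_NUM)])

# Loss function and optimizer, later wrapped by the low-level training cells.
loss_function = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)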
    2 Constructing the Model (Low-Level API)
    To construct the model, first wrap the network prototype together with the loss function, then wrap that combination with the optimizer, yielding a network that can be used for training. Because accuracy on the training set must be computed during training and validation, the return value must also include the network's output.

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell
from mindspore.ops import functional as F
from mindspore.ops import operations as P

class BuildTrainNetwork(nn.Cell):
    '''Wrap the network with its loss function and expose the raw network output.'''
    def __init__(self, my_network, my_criterion, train_batch_size, class_num):
        super(BuildTrainNetwork, self).__init__()
        self.network = my_network
        self.criterion = my_criterion
        self.print = P.Print()
        # Initialize self.output to hold the network output of each step
        self.output = mindspore.Parameter(
            Tensor(np.ones((train_batch_size, class_num)), mindspore.float32),
            requires_grad=False)

    def construct(self, input_data, label):
        output = self.network(input_data)
        # Get the network output and assign it to self.output
        self.output = output
        loss0 = self.criterion(output, label)
        return loss0

class TrainOneStepCellV2(TrainOneStepCell):
    '''Train-one-step cell that returns both the loss and the network output.'''
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCellV2, self).__init__(network, optimizer, sens=sens)

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        # Obtain the output saved by BuildTrainNetwork
        output = self.network.output
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        # Get the gradient of the network parameters
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        # Optimize model parameters
        loss = F.depend(loss, self.optimizer(grads))
        return loss, output

# Construct the trainable model
model_constructed = BuildTrainNetwork(net, loss_function, TRAIN_BATCH_SIZE, CLASS_NUM)
model_constructed = TrainOneStepCellV2(model_constructed, opt)
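    As a quick sanity check (not part of the original post), the wrapped cell can be called directly on a batch of tensors; one call runs the forward pass, the backward pass, and one optimizer update, and returns both the loss and the raw network output. The input shapes below are assumptions matching the placeholder network above:

import numpy as np
from mindspore import Tensor

# Dummy batch of CIFAR-10-sized images and integer labels (shapes are assumptions).
dummy_images = Tensor(np.random.randn(TRAIN_BATCH_SIZE, 3, 32, 32).astype(np.float32))
dummy_labels = Tensor(np.random.randint(0, CLASS_NUM, size=(TRAIN_BATCH_SIZE,)).astype(np.int32))

model_constructed.set_train()
loss, output = model_constructed(dummy_images, dummy_labels)
print("loss:", loss, "output shape:", output.shape)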


    3 Training and Validation (Low-Level API)
    Similar to PyTorch, the network is trained and validated with the low-level API. The detailed steps are as follows:

# ModelCheckpoint, CheckpointConfig and RunContext come from mindspore.train.callback;
# _InternalCallbackParam is an internal class available in MindSpore 1.x.
from mindspore.train.callback import (ModelCheckpoint, CheckpointConfig,
                                      RunContext, _InternalCallbackParam)

class CorrectLabelNum(nn.Cell):
    '''Count how many predictions match the ground-truth labels.'''
    def __init__(self):
        super(CorrectLabelNum, self).__init__()
        self.print = P.Print()
        self.argmax = mindspore.ops.Argmax(axis=1)
        self.sum = mindspore.ops.ReduceSum()

    def construct(self, output, target):
        output = self.argmax(output)
        correct = self.sum((output == target).astype(mindspore.dtype.float32))
        return correct

def train_net(model, network, criterion,
              epoch_max, train_path, val_path,
              train_batch_size, val_batch_size,
              repeat_size):
    """Define the training method."""
    # Create datasets (create_dataset is the user-defined loading and
    # preprocessing helper, not shown in the original post)
    ds_train, steps_per_epoch_train = create_dataset(train_path,
        do_train=True, batch_size=train_batch_size, repeat_num=repeat_size)
    ds_val, steps_per_epoch_val = create_dataset(val_path, do_train=False,
        batch_size=val_batch_size, repeat_num=repeat_size)
    # CheckPoint CallBack definition
    config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch_train,
                                 keep_checkpoint_max=epoch_max)
    ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10",
                                 directory="./", config=config_ck)
    # Create dict to save internal callback object's parameters
    cb_params = _InternalCallbackParam()
    cb_params.train_network = model
    cb_params.epoch_num = epoch_max
    cb_params.batch_num = steps_per_epoch_train
    cb_params.cur_epoch_num = 0
    cb_params.cur_step_num = 0
    run_context = RunContext(cb_params)
    ckpoint_cb.begin(run_context)
    print("============== Starting Training ==============")
    correct_num = CorrectLabelNum()
    correct_num.set_train(False)
    for epoch in range(epoch_max):
        print("\nEpoch:", epoch + 1, "/", epoch_max)
        train_loss = 0
        train_correct = 0
        train_total = 0
        for _, (data, gt_classes) in enumerate(ds_train):
            model.set_train()
            loss, output = model(data, gt_classes)
            train_loss += loss
            correct = correct_num(output, gt_classes)
            correct = correct.asnumpy()
            train_correct += correct.sum()
            # Update current step number
            cb_params.cur_step_num += 1
            # Check whether to save checkpoint or not
            ckpoint_cb.step_end(run_context)
        cb_params.cur_epoch_num += 1
        my_train_loss = train_loss / steps_per_epoch_train
        my_train_accuracy = 100 * train_correct / (train_batch_size *
                                                   steps_per_epoch_train)
        print('Train Loss:', my_train_loss)
        print('Train Accuracy:', my_train_accuracy, '%')
        print('evaluating {}/{} ...'.format(epoch + 1, epoch_max))
        val_loss = 0
        val_correct = 0
        for _, (data, gt_classes) in enumerate(ds_val):
            network.set_train(False)
            output = network(data)
            loss = criterion(output, gt_classes)
            val_loss += loss
            correct = correct_num(output, gt_classes)
            correct = correct.asnumpy()
            val_correct += correct.sum()
        my_val_loss = val_loss / steps_per_epoch_val
        my_val_accuracy = 100 * val_correct / (val_batch_size * steps_per_epoch_val)
        print('Validation Loss:', my_val_loss)
        print('Validation Accuracy:', my_val_accuracy, '%')
    print("--------- trains out ---------")
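    The train_net function relies on a create_dataset helper that the original post does not include. A rough sketch, assuming CIFAR-10 in binary format and the MindSpore 1.x dataset API, and returning the dataset together with its number of steps per epoch, could be:

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import dtype as mstype

def create_dataset(data_path, do_train, batch_size, repeat_num=1):
    """Load CIFAR-10, apply basic augmentation/normalization, and batch it (sketch)."""
    dataset = ds.Cifar10Dataset(data_path, shuffle=do_train)
    trans = []
    if do_train:
        # Simple augmentation for the training split only
        trans += [CV.RandomCrop((32, 32), (4, 4, 4, 4)), CV.RandomHorizontalFlip()]
    trans += [CV.Rescale(1.0 / 255.0, 0.0),
              CV.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
              CV.HWC2CHW()]
    dataset = dataset.map(operations=trans, input_columns="image")
    dataset = dataset.map(operations=C.TypeCast(mstype.int32), input_columns="label")
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat(repeat_num)
    # Steps per epoch after batching, as expected by train_net
    return dataset, dataset.get_dataset_size()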


    4 Running the Script
    Launch command:
    python MindSpore_1P_low_API.py --data_path=xxx --epoch_num=xxx
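    The post does not show the entry point of MindSpore_1P_low_API.py. A minimal sketch of how it might parse these flags and start training (the flag names are taken from the command above; everything else, including the train/val sub-directory layout, is an assumption) is:

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MindSpore low-level API, single Ascend card")
    parser.add_argument("--data_path", type=str, required=True,
                        help="directory containing the train/ and val/ datasets")
    parser.add_argument("--epoch_num", type=int, default=10,
                        help="number of training epochs")
    args = parser.parse_args()

    # Kick off low-level API training with the objects constructed above.
    train_net(model_constructed, net, loss_function,
              epoch_max=args.epoch_num,
              train_path=args.data_path + "/train",
              val_path=args.data_path + "/val",
              train_batch_size=TRAIN_BATCH_SIZE,
              val_batch_size=TRAIN_BATCH_SIZE,
              repeat_size=1)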
    Run the script in the Terminal of the development environment to see the network's training output.


    Note: The high-level API trains in data-sinking mode, which the low-level API does not support, so the high-level API trains faster than the low-level API.
    Performance comparison: low-level API: 2000 imgs/sec; high-level API: 2200 imgs/sec
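    For reference, data sinking is enabled with the high-level API through Model.train. A minimal sketch, reusing the placeholder objects from the sketches above and an assumed epoch count, looks like:

from mindspore import Model
from mindspore.train.callback import LossMonitor

# High-level API counterpart with data sinking enabled on Ascend.
ds_sink, _ = create_dataset("./data/train", do_train=True,
                            batch_size=TRAIN_BATCH_SIZE, repeat_num=1)
model = Model(net, loss_fn=loss_function, optimizer=opt, metrics={'acc'})
model.train(10, ds_sink, callbacks=[LossMonitor()],
            dataset_sink_mode=True)  # batches are fed through the on-device data queue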
    The complete code can be downloaded from the MindSpore board of the Huawei Cloud Forum.

  • Original article: https://blog.csdn.net/skytttttt9394/article/details/126599748