• 基于CNTK/C#实现图像分类【附部分源码及模型】



    前言

    本文基于CNTK实现图像分类。与之前的文章不同,本次使用C#实现,不使用python:python版的CNTK比较简单,个人感觉没什么必要再写;而且CNTK毕竟是微软的框架,本人强迫症犯了,所以使用C#来实现CNTK。
    环境版本:
    Visual Studio 2022
    C# (.NET Framework 4.6)
    cntk 2.7
    cuda 10.1


    一、数据集准备

    本次数据集使用中国象棋数据集,如图:
    在这里插入图片描述
    在这里插入图片描述
    DataImage_train:图像训练文件夹
    DataImage_val:图像验证
    test:图像测试


    二、图像分类程序构建

    1.变量定义

    下面定义的变量涵盖了训练、验证、测试三个阶段所需的全部配置

    // Wrapper around the loaded CNTK model; assigned in the train / val / test branches.
    static CntkModelWrapper _modelWrapper = null;
    // Names of the CNTK input (feature) and output (label) streams.
    private const string FEATURE_STREAM_NAME = "features";
    private const string LABEL_STREAM_NAME = "labels";
    // Folder that holds the pretrained base model files.
    static string BaseWorkPath = @"./Base_Model";
    // Whether to run on the GPU; selects GPU device 0 when true, the CPU device otherwise.
    static bool useGPU = true;
    static DeviceDescriptor device = useGPU ? DeviceDescriptor.GPUDevice(0) : DeviceDescriptor.CPUDevice;
    // Index of the base model used for training, and the list of available base model files.
    static int model_type = 0;
    static string[] base_model_file = new string[] { "AlexNet_ImageNet_CNTK.model", "InceptionV3_ImageNet_CNTK.model", "ResNet18_ImageNet_CNTK.model", "ResNet34_ImageNet_CNTK.model", "ResNet50_ImageNet_CNTK.model", "ResNet101_ImageNet_CNTK.model", "ResNet152_ImageNet_CNTK.model" };
    // Training image size: 227 for model_type 0, 299 for model_type 1, 224 for every other model.
    static int IMAGE_WIDTH = (model_type == 0) ? 227 : ((model_type == 1) ? 299 : 224);
    static int IMAGE_HEIGHT = (model_type == 0) ? 227 : ((model_type == 1) ? 299 : 224);
    static int IMAGE_DEPTH = 3;
    // Learning rate.
    static float learning_rate = 0.0001f;
    // Minibatch size used for each training iteration.
    static uint batch_size = 4;
    // Number of training passes (compared against the computed epoch number in the training loop).
    static int TrainNum = 300;
    
    // Class names; filled at run time from the sub-folder names of the training image folder.
    static string[] classes_names = new string[]{ };
    // Whether to regenerate the dataset list files and whether to rebuild the model.
    static bool reCreateData = false;
    static bool reCreateModel = true;
    // Folder with the training images.
    static string ImageDir_Train = @"./DataSet_Classification_Chess\DataImage_train";
    // Folder with the validation images.
    static string ImageDir_Val = @"./DataSet_Classification_Chess\DataImage_val";
    // Folder with the images used for prediction.
    static string ImageDir_Test = @"./DataSet_Classification_Chess\test";
    // Image file extension (without the dot).
    static string ext = "bmp";
    // Output folder for the resulting model and the generated files.
    static string model_path = "./result_Model";
    // File name of the trained model.
    static string model_file = "Ctu_CNTK.model";
    // File name of the saved configuration (class names and their order).
    static string config_file = "Ctu_Config.txt";
    // File name of the generated training dataset list.
    static string train_data_file = "train-dataset.txt";
    // File name of the generated validation dataset list.
    static string test_data_file = "test-dataset.txt";
    
    // Run mode of the program; the snippets in this file check for "train", "val" and "test".
    static string RunModel = "train";
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50

    2.模型文件生成

    预训练模型是针对ImageNet数据集的1000个类别训练的,而这里使用的是自定义数据集,类别数不同,因此需要对模型末端做少量修改:冻结并克隆到指定隐藏层为止的网络,再接上一个输出自定义类别数的全连接层

     /// <summary>
     /// Builds a transfer-learning model on top of a pretrained base model.
     /// The base model is cloned up to <paramref name="hiddenNodeName"/> with frozen parameters,
     /// its original feature input is replaced by a new input shifted by a constant, and a new
     /// dense layer with <paramref name="numClasses"/> outputs is appended.
     /// </summary>
     /// <param name="baseModel">Pretrained model to clone from.</param>
     /// <param name="featureNodeName">Name of the base model argument that receives the image features.</param>
     /// <param name="outputNodeName">Name given to the new dense output layer.</param>
     /// <param name="hiddenNodeName">Name of the node in the base model after which the new layer is attached.</param>
     /// <param name="imageDims">Shape of the new input variable.</param>
     /// <param name="numClasses">Number of outputs of the new dense layer.</param>
     /// <param name="device">Device passed to the dense layer construction.</param>
     /// <returns>The cloned model with the new output layer.</returns>
     /// <exception cref="ArgumentNullException"><paramref name="baseModel"/> is null.</exception>
     /// <exception cref="ArgumentException">No node named <paramref name="hiddenNodeName"/> exists in the base model.</exception>
     public static Function BuildTransferLearningModel(Function baseModel, string featureNodeName, string outputNodeName, string hiddenNodeName, int[] imageDims, int numClasses, DeviceDescriptor device)
     {
         if (ReferenceEquals(baseModel, null))
         {
             throw new ArgumentNullException(nameof(baseModel));
         }

         var input = Variable.InputVariable(imageDims, DataType.Float);
         // Subtracts the constant 114 from every input value
         // (presumably an approximate mean-pixel normalization - TODO confirm against the base model).
         var normalizedFeatureNode = CNTKLib.Minus(input, Constant.Scalar(DataType.Float, 114.0F));

         // Single() throws if the base model has no (or more than one) argument with this name.
         var oldFeatureNode = baseModel.Arguments.Single(a => a.Name == featureNodeName);

         var lastNode = baseModel.FindByName(hiddenNodeName);
         if (ReferenceEquals(lastNode, null))
         {
             // Fail with a clear message instead of a null dereference inside AsComposite/Clone.
             throw new ArgumentException(
                 $"The base model does not contain a node named '{hiddenNodeName}'.",
                 nameof(hiddenNodeName));
         }

         // Clone everything up to the hidden node with frozen weights, rewiring the old
         // feature input to the new normalized input.
         var clonedLayer = CNTKLib.AsComposite(lastNode).Clone(
             ParameterCloningMethod.Freeze,
             new Dictionary<Variable, Variable>()
             {
                 { oldFeatureNode, normalizedFeatureNode }
             });

         // New output layer sized for the custom class count; no activation is applied here.
         var clonedModel = Dense(clonedLayer, numClasses, device, Activation.None, outputNodeName);
         return clonedModel;
     }
    // Call CreateAndSaveModel when a rebuild is requested or the target model file is missing.
    if (reCreateModel || !File.Exists(Path.Combine(model_path, model_file)))
    {
         CreateAndSaveModel(
             config,
             Path.Combine(BaseWorkPath, base_model_file[model_type]),
             Path.Combine(model_path, model_file),
             device);
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    3.训练数据集生成

    // Regenerate the dataset list files when requested or when the output folder does not exist yet.
    if (reCreateData || !Directory.Exists(model_path))
    {
        // A recursive delete already removes every sub-directory and file,
        // so no separate pass over the child directories is needed.
        if (Directory.Exists(model_path))
        {
            Directory.Delete(model_path, true);
        }
        Directory.CreateDirectory(model_path);
        CreateAndSaveDatasets(config, ImageDir_Train, Path.Combine(model_path, train_data_file), ImageDir_Val, Path.Combine(model_path, test_data_file), ext);
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    在这里插入图片描述


    4.训练完整代码

    // Training mode: prepare the dataset lists and the model, train for TrainNum passes,
    // save the model, then switch the run mode to validation.
    if (RunModel == "train")
    {
        // Class names are the sub-folder names of the training image folder.
        classes_names = new DirectoryInfo(ImageDir_Train).GetDirectories().Select(d => d.Name).ToArray();

        var config = new ClassificationConfig(classes_names);
        if (reCreateData || !Directory.Exists(model_path))
        {
            // A recursive delete already removes every sub-directory and file.
            if (Directory.Exists(model_path))
            {
                Directory.Delete(model_path, true);
            }
            Directory.CreateDirectory(model_path);
            CreateAndSaveDatasets(config, ImageDir_Train, Path.Combine(model_path, train_data_file), ImageDir_Val, Path.Combine(model_path, test_data_file), ext);
        }
        if (reCreateModel || !File.Exists(Path.Combine(model_path, model_file)))
        {
            CreateAndSaveModel(config, Path.Combine(BaseWorkPath, base_model_file[model_type]), Path.Combine(model_path, model_file), device);
        }
        config.Save(Path.Combine(model_path, config_file));

        // Training
        _modelWrapper = new CntkModelWrapper(Path.Combine(model_path, model_file), device);
        string trainDataPath = Path.Combine(model_path, train_data_file);
        var dataSource = CreateDataSource(trainDataPath);
        var trainer = CreateTrainer();

        var minibatchesSeen = 0;
        int data_length = readFileLines(trainDataPath);
        if (data_length <= 0)
        {
            // The epoch computation below divides by data_length; fail early with a clear message.
            throw new InvalidOperationException($"No training samples found in '{trainDataPath}'.");
        }

        // Number of minibatches per pass, used only for the progress display. Clamped to 1 so a
        // dataset smaller than batch_size does not cause a division by zero in the modulo below.
        long minibatchesPerEpoch = Math.Max(1L, data_length / batch_size);

        while (true)
        {
            var minibatchData = dataSource.MinibatchSource.GetNextMinibatch(batch_size, device);
            var arguments = new Dictionary<Variable, MinibatchData>
            {
                { _modelWrapper.Input, minibatchData[dataSource.FeatureStreamInfo] },
                { _modelWrapper.TrainingOutput, minibatchData[dataSource.LabelStreamInfo] }
            };
            trainer.TrainMinibatch(arguments, device);
            double loss = trainer.PreviousMinibatchLossAverage();
            double eval = trainer.PreviousMinibatchEvaluationAverage();

            // 1-based pass number derived from the number of samples consumed so far.
            int epoch = Convert.ToInt32(minibatchesSeen * batch_size / data_length) + 1;
            Console.WriteLine($"[{epoch}:{TrainNum}/{minibatchesSeen % minibatchesPerEpoch + 1}] CrossEntropyLoss = {loss}, EvaluationCriterion = {eval}");

            minibatchesSeen++;

            // Stop once the next minibatch would belong to a pass beyond TrainNum.
            if (Convert.ToInt32(minibatchesSeen * batch_size / data_length) + 1 > TrainNum)
            {
                _modelWrapper.Model.Save(Path.Combine(model_path, model_file));
                break;
            }
        }
        RunModel = "val";
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60

    5.验证完整代码

    // Validation mode: evaluate the saved model on the validation list one sample at a time
    // and print the running accuracy until the data source reports the end of a sweep.
    if (RunModel == "val")
    {
        _modelWrapper = new CntkModelWrapper(Path.Combine(model_path, model_file), device);
        var dataSource_test = CreateDataSource(Path.Combine(model_path, test_data_file));

        const int minibatchSize = 1;
        int Correct = 0;
        int Total = 0;

        // Maps each sample's score vector to the index of its largest value. The result is
        // materialized so it is computed once even though it is read twice below.
        Func<IEnumerable<IList<float>>, List<int>> maxSelector =
            (collection) => collection.Select(x => x.IndexOf(x.Max())).ToList();

        while (true)
        {
            var minibatchData = dataSource_test.MinibatchSource.GetNextMinibatch(minibatchSize, device);
            var inputDataMap = new Dictionary<Variable, Value>() { { _modelWrapper.Input, minibatchData[dataSource_test.FeatureStreamInfo].data } };
            // A null value asks Evaluate to allocate the output itself.
            var outputDataMap = new Dictionary<Variable, Value>() { { _modelWrapper.EvaluationOutput, null } };

            _modelWrapper.Model.Evaluate(inputDataMap, outputDataMap, device);
            var outputVal = outputDataMap[_modelWrapper.EvaluationOutput];
            var actual = outputVal.GetDenseData<float>(_modelWrapper.EvaluationOutput);
            var labelBatch = minibatchData[dataSource_test.LabelStreamInfo].data;
            var expected = labelBatch.GetDenseData<float>(_modelWrapper.Model.Output);

            var actualLabels = maxSelector(actual);
            var expectedLabels = maxSelector(expected);

            Correct += actualLabels.Zip(expectedLabels, (a, b) => a.Equals(b) ? 1 : 0).Sum();
            Total += actualLabels.Count;
            double acc = (Convert.ToDouble(Correct) / Total);
            Console.WriteLine($"Correct = {Correct}, Total = {Total}, acc={acc}");

            if (minibatchData.Values.Any(x => x.sweepEnd))
            {
                break;
            }
        }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39

    6.预测完整代码

    // Prediction mode: classify every image with the configured extension in the test folder
    // and print "<file> : <class name>" for each one.
    if (RunModel == "test")
    {
        _modelWrapper = new CntkModelWrapper(Path.Combine(model_path, model_file), device);
        var config = ClassificationConfig.Load(Path.Combine(model_path, config_file));

        string[] imageFiles = Directory.GetFiles(ImageDir_Test, $"*.{ext}");
        foreach (string file in imageFiles)
        {
            // Wrap the loaded image data in a CNTK value of shape width x height x depth.
            var shape = new int[] { IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_DEPTH };
            var pixels = ImageHelper.Load(IMAGE_WIDTH, IMAGE_HEIGHT, file);
            var inputValue = new Value(new NDArrayView(shape, pixels, device));

            var inputDataMap = new Dictionary<Variable, Value>()
            {
                { _modelWrapper.Input, inputValue }
            };
            // A null value asks Evaluate to allocate the output itself.
            var outputDataMap = new Dictionary<Variable, Value>()
            {
                { _modelWrapper.EvaluationOutput, null }
            };

            _modelWrapper.Model.Evaluate(inputDataMap, outputDataMap, device);

            var resultValue = outputDataMap[_modelWrapper.EvaluationOutput];
            var scores = resultValue
                .GetDenseData<float>(_modelWrapper.EvaluationOutput)
                .First()
                .Select(x => (double)x)
                .ToArray();

            // The predicted class is the position of the highest score.
            var bestIndex = Array.IndexOf(scores, scores.Max());
            var bestName = config.GetClassNameByIndex(bestIndex);
            Console.WriteLine($"{file} : {bestName}");
        }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    训练效果

    在这里插入图片描述
    在这里插入图片描述

    预测效果

    在这里插入图片描述

    总结

    源码私聊

  • 相关阅读:
    嵌入式系统软件开发环境_3.主要功能和典型产品
    初识C++(2)
    自动控制原理 传递函数
    C#:轮询调度算法​(附完整源码)
    放出云伙伴生态“大招”,微软为业界打了个样
    SpringCloud原理-OpenFeign篇(一、Hello OpenFeign项目示例)
    K8S-PV与PVC
    linux部署禅道
    springboot基于BS结构的企业人事管理系统的设计与实现毕业设计源码121727
    vue+vscode 快速搭建运行调试环境与发布
  • 原文地址:https://blog.csdn.net/ctu_sue/article/details/127599959