• C/C++开发,opencv-ml库学习,支持向量机(SVM)应用


    目录

    一、OpenCV支持向量机(SVM)模块

    1.1 openCV的机器学习库

    1.2 SVM(支持向量机)模块

    1.3 支持向量机(SVM)应用步骤

    二、支持向量机(SVM)应用示例

     2.1  训练及验证数据获取

    2.2 训练及验证数据加载

    2.3 SVM(支持向量机)训练及验证,输出svm模型

    2.4 SVM(支持向量机)实时识别应用

    三、完整代码编译

    3.1 OpenCV+MinGW的MakeFile编译

    3.2 OpenCV+vc2015+cmake编译

    3.3 执行效果

    3.4 附件,main.cpp全文


    一、OpenCV支持向量机(SVM)模块

    1.1 openCV的机器学习库

            OpenCV-ml库是OpenCV(开放源代码计算机视觉库)中的机器学习模块,常用于分类和回归问题,它是 OpenCV 众多modules下的一个模块。

            该模块提供了一系列常见的统计模型和分类算法,用于进行各种机器学习任务。以下是关于OpenCV-ml库的一些主要功能和特点:

    1. 丰富的算法支持:OpenCV-ml库包含了多种机器学习算法,如支持向量机(SVM)、决策树、Boosting方法、K近邻(KNN)、随机森林等。这些算法可以用于分类、回归、聚类等多种任务。
    2. 易于使用:OpenCV-ml库提供了简洁的API接口,使得开发者能够方便地调用各种机器学习算法。同时,它也支持多种数据格式,方便用户导入和处理数据。
    3. 高效性:OpenCV-ml库经过优化,能够高效地处理大规模数据集,并且具有较快的运算速度。这使得它能够满足实时处理和分析的需求。
    4. 与OpenCV其他模块的集成:OpenCV-ml库与OpenCV的其他模块(如imgproc、features2d等)紧密集成,可以方便地进行图像处理和特征提取,然后将提取的特征用于机器学习任务。
    1.2 SVM(支持向量机)模块

            OpenCV 的 SVM(支持向量机)模块是 OpenCV 机器学习库中的一个重要组成部分,它实现了支持向量机算法,用于解决分类和回归问题。支持向量机是一种监督学习模型,广泛应用于各种领域,特别是在图像分类和识别任务中。

            OpenCV 的 SVM 模块提供了灵活的参数设置和多种核函数选择,以适应不同的数据集和问题。以下是一些关于 OpenCV SVM 模块的主要特点:

    1. 多种核函数:支持线性核、多项式核、径向基函数(RBF)核和 Sigmoid 核等,可以根据问题的特性选择合适的核函数。

    2. 参数调整:可以通过调整 SVM 的参数,如 C 值(错误项的惩罚系数)和 gamma 值(对于 RBF、Poly 和 Sigmoid 核函数),来优化模型的性能。

    3. 多类分类支持:通过“一对一”或“一对多”的方式,可以处理多类分类问题。

    4. 概率估计:SVM 可以输出类别的概率估计,这对于某些应用(如置信度评估)非常有用。

    5. 易于使用:OpenCV 提供了简洁的 API,使得 SVM 的训练和测试过程相对简单。

    1.3 支持向量机(SVM)应用步骤

            在OpenCV中,使用支持向量机(SVM)进行预测涉及几个步骤。首先,获得训练数据,用于训练一个SVM模型,然后使用该模型对新的、未见过的数据进行预测。

        使用svm模型,包含必要的头文件:

    1. #include
    2. #include

       1) 准备训练和测试数据:

        你需要为SVM准备训练和测试数据。这些数据通常是特征向量,存储在cv::Mat对象中。每个特征向量对应一个标签(分类的类别)。
        2)创建和训练SVM模型:
        使用OpenCV的cv::ml::SVM类来创建SVM模型。然后,使用train方法来训练模型。
       3) 进行预测:
        使用训练好的模型对新数据进行预测。这通常涉及将新数据作为输入传递给模型的predict方法。

    二、支持向量机(SVM)应用示例

     2.1  训练及验证数据获取

            以下展示如何使用OpenCV的机器学习模块来实现一个基于SVM的手写数字识别器。首先前往网站:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges,下载MNIST database,用于实现一个SVM的手写数字识别模型训练及验证。

            下载完成后,进行解压操作:

            解压后是idx1-ubyteidx3-ubyte 是两种常见的标签编码格式,主要用于图像分割任务中。它们都是用来表示图像中每个像素所属类别的标签图像(也称为掩码或mask)。

    1. idx1-ubyte:

      • idx: 表示这是一个索引图像。
      • 1: 表示每个像素用一个字节(8位)来表示,且这些值从0开始,通常是连续的整数。
      • ubyte: 表示无符号字节类型,其值的范围是0到255。在idx1-ubyte格式中,通常会将0用作背景或未标记的类别,而其他值则用于表示不同的分割区域或类别。
    2. idx3-ubyte:

      • idx: 同样表示这是一个索引图像。
      • 3: 这里并不是指每个像素用3个字节来表示,而是指每个像素用一个字节来表示,但值的范围是从0到255,通常用来表示256个不同的类别(包括0作为背景或未标记的类别)。注意,虽然名为idx3,但实际上它并不是用3个字节来存储每个像素的值。
      • ubyte: 同样表示无符号字节类型。

            在图像分割任务中,这些标签图像通常与原始RGB图像一起使用。RGB图像用于显示给人类观察者或作为模型的输入,而标签图像则用于训练模型或评估模型的性能。

    2.2 训练及验证数据加载

            idx3-ubyte 文件通常与 MNIST 数据集相关联,这是一个大型的手写数字数据库,经常用于机器学习和深度学习中的图像识别任务。MNIST 数据集包含两个文件:train-images-idx3-ubytetrain-labels-idx1-ubyte(用于训练),以及 t10k-images-idx3-ubytet10k-labels-idx1-ubyte(用于测试)。这些文件使用特定的二进制格式存储图像和标签。

            通过两个函数来读取手写图像数据集和手写图像数据对应的标签(每个标签都是一个 0 到 9 之间的整数,表示对应图像中的手写数字)。

    1. //大小端转换
    2. int intReverse(int num)
    3. {
    4. return (num>>24|((num&0xFF0000)>>8)|((num&0xFF00)<<8)|((num&0xFF)<<24));
    5. }
    6. //读取手写图像数据集
    7. cv::Mat read_mnist_image(const std::string fileName) {
    8. int magic_number = 0;
    9. int number_of_images = 0;
    10. int img_rows = 0;
    11. int img_cols = 0;
    12. cv::Mat DataMat;
    13. std::ifstream file(fileName, std::ios::binary);
    14. if (file.is_open())
    15. {
    16. std::cout << "open images file: "<< fileName << std::endl;
    17. file.read((char*)&magic_number, sizeof(magic_number));//format
    18. file.read((char*)&number_of_images, sizeof(number_of_images));//images number
    19. file.read((char*)&img_rows, sizeof(img_rows));//img rows
    20. file.read((char*)&img_cols, sizeof(img_cols));//img cols
    21. magic_number = intReverse(magic_number);
    22. number_of_images = intReverse(number_of_images);
    23. img_rows = intReverse(img_rows);
    24. img_cols = intReverse(img_cols);
    25. std::cout << "format:" << magic_number
    26. << " img num:" << number_of_images
    27. << " img row:" << img_rows
    28. << " img col:" << img_cols << std::endl;
    29. std::cout << "read img data" << std::endl;
    30. DataMat = cv::Mat::zeros(number_of_images, img_rows * img_cols, CV_32FC1);
    31. unsigned char temp = 0;
    32. for (int i = 0; i < number_of_images; i++) {
    33. for (int j = 0; j < img_rows * img_cols; j++) {
    34. file.read((char*)&temp, sizeof(temp));
    35. //svm data is CV_32FC1
    36. float pixel_value = float(temp);
    37. DataMat.at<float>(i, j) = pixel_value;
    38. }
    39. }
    40. std::cout << "read img data finish!" << std::endl;
    41. }
    42. file.close();
    43. return DataMat;
    44. }
    45. //读取手写标签
    46. cv::Mat read_mnist_label(const std::string fileName) {
    47. int magic_number;
    48. int number_of_items;
    49. cv::Mat LabelMat;
    50. std::ifstream file(fileName, std::ios::binary);
    51. if (file.is_open())
    52. {
    53. std::cout << "open label file: "<< fileName << std::endl;
    54. file.read((char*)&magic_number, sizeof(magic_number));
    55. file.read((char*)&number_of_items, sizeof(number_of_items));
    56. magic_number = intReverse(magic_number);
    57. number_of_items = intReverse(number_of_items);
    58. std::cout << "format:" << magic_number << " ;label_num:" << number_of_items << std::endl;
    59. std::cout << "read Label data" << std::endl;
    60. //data type:CV_32SC1,channel:1
    61. LabelMat = cv::Mat::zeros(number_of_items, 1, CV_32SC1);
    62. for (int i = 0; i < number_of_items; i++) {
    63. unsigned char temp = 0;
    64. file.read((char*)&temp, sizeof(temp));
    65. LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
    66. }
    67. std::cout << "read label data finish!" << std::endl;
    68. }
    69. file.close();
    70. return LabelMat;
    71. }
    2.3 SVM(支持向量机)训练及验证,输出svm模型

            1)加载训练图像数据和标签数据,采用cv::Mat存储,图像数据虚归一化;

            2)创建svm模型,设置svm模型的各关联参数,不同参数设置,对应模型精度有较大影响;

            3)加载测试图像数据和标签数据,采用cv::Mat存储,图像数据虚归一化;

            4)采用测试图像数据验证已经训练好的svm模型,获得测试推演结果;

            5)通过测试结果和已有的标签数据进行校对,验证该模型精度。

            6)将训练好的模型保持输出。便于后续用于实时识别应用。

    1. //change path for real paths
    2. std::string trainImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-images.idx3-ubyte";
    3. std::string trainLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-labels.idx1-ubyte";
    4. std::string testImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images.idx3-ubyte";
    5. std::string testLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-labels.idx1-ubyte";
    6. void train_SVM()
    7. {
    8. //read train images, data type CV_32FC1
    9. cv::Mat trainingData = read_mnist_image(trainImgFile);
    10. //images data normalization
    11. trainingData = trainingData/255.0;
    12. std::cout << "trainingData.size() = " << trainingData.size() << std::endl;
    13. //read train label, data type CV_32SC1
    14. cv::Mat labelsMat = read_mnist_label(trainLabeFile);
    15. std::cout << "labelsMat.size() = " << labelsMat.size() << std::endl;
    16. std::cout << "trainingData & labelsMat finish!" << std::endl;
    17. //create SVM model
    18. cv::Ptr svm = cv::ml::SVM::create();
    19. //set svm args,type and KernelTypes
    20. svm->setType(cv::ml::SVM::C_SVC);
    21. svm->setKernel(cv::ml::SVM::POLY);
    22. //KernelTypes POLY is need set gamma and degree
    23. svm->setGamma(3.0);
    24. svm->setDegree(2.0);
    25. //Set iteration termination conditions, maxCount is importance
    26. svm->setTermCriteria(cv::TermCriteria(cv::TermCriteria::EPS | cv::TermCriteria::COUNT, 1000, 1e-8));
    27. std::cout << "create SVM object finish!" << std::endl;
    28. std::cout << "trainingData.rows = " << trainingData.rows << std::endl;
    29. std::cout << "trainingData.cols = " << trainingData.cols << std::endl;
    30. std::cout << "trainingData.type() = " << trainingData.type() << std::endl;
    31. // svm model train
    32. svm->train(trainingData, cv::ml::ROW_SAMPLE, labelsMat);
    33. std::cout << "SVM training finish!" << std::endl;
    34. // svm model test
    35. cv::Mat testData = read_mnist_image(testImgFile);
    36. //images data normalization
    37. testData = testData/255.0;
    38. std::cout << "testData.rows = " << testData.rows << std::endl;
    39. std::cout << "testData.cols = " << testData.cols << std::endl;
    40. std::cout << "testData.type() = " << testData.type() << std::endl;
    41. //read test label, data type CV_32SC1
    42. cv::Mat testlabel = read_mnist_label(testLabeFile);
    43. cv::Mat testResp;
    44. float response = svm->predict(testData,testResp);
    45. // std::cout << "response = " << response << std::endl;
    46. testResp.convertTo(testResp,CV_32SC1);
    47. int map_num = 0;
    48. for (int i = 0; i
    49. {
    50. if (testResp.at<int>(i, 0) == testlabel.at<int>(i, 0))
    51. {
    52. map_num++;
    53. }
    54. // else{
    55. // std::cout << "testResp.at(i, 0) " << testResp.at(i, 0) << std::endl;
    56. // std::cout << "testlabel.at(i, 0) " << testlabel.at(i, 0) << std::endl;
    57. // }
    58. }
    59. float proportion = float(map_num) / float(testResp.rows);
    60. std::cout << "map rate: " << proportion * 100 << "%" << std::endl;
    61. std::cout << "SVM testing finish!" << std::endl;
    62. //save svm model
    63. svm->save("mnist_svm.xml");
    64. }
    2.4 SVM(支持向量机)实时识别应用

            将t10k-images.idx3-ubyte处理成图片数据,用于svm模型调用示例,本文主要是通过一段python代码,将t10k-images.idx3-ubyte另存为一张张手写图片。

    1. import numpy as np
    2. import os
    3. from PIL import Image
    4. from struct import unpack
    5. def read_idx3_ubyte(filename):
    6. with open(filename, 'rb') as f:
    7. magic, num_images, rows, cols = unpack('>IIII', f.read(16))
    8. buf = f.read()
    9. data = np.frombuffer(buf, dtype=np.uint8).reshape((num_images, rows, cols))
    10. return data
    11. def save_images_as_png(idx3_file, output_dir, prefix='image'):
    12. images = read_idx3_ubyte(idx3_file)
    13. for i, image in enumerate(images):
    14. image_pil = Image.fromarray(image, 'L') # 'L' 表示灰度模式
    15. filename = f"{output_dir}/{prefix}_{i}.png"
    16. image_pil.save(filename)
    17. # 使用示例
    18. # idx3_file = 'train-images.idx3-ubyte'
    19. # output_dir = 'train-images'
    20. # if not os.path.exists(output_dir):#检查目录是否存在
    21. # os.makedirs(output_dir)#如果不存在则创建目录
    22. # save_images_as_png(idx3_file, output_dir)
    23. idx3_file = 't10k-images.idx3-ubyte'
    24. output_dir = 't10k-images'
    25. if not os.path.exists(output_dir):#检查目录是否存在
    26. os.makedirs(output_dir)#如果不存在则创建目录
    27. save_images_as_png(idx3_file, output_dir)

            在获得图片数据后,将加载这些图片,和上述已保存的svm模型(mnist_svm.xml),实现模型调用验证。

    1. void prediction(const std::string fileName,cv::Ptr svm)
    2. {
    3. //read img 28*28 size
    4. cv::Mat image = cv::imread(fileName, cv::IMREAD_GRAYSCALE);
    5. //uchar->float32
    6. image.convertTo(image, CV_32F);
    7. //image data normalization
    8. image = image / 255.0;
    9. //28*28 -> 1*784
    10. image = image.reshape(1, 1);
    11. //预测图片
    12. float ret = svm->predict(image);
    13. std::cout << "predict val = "<< ret << std::endl;
    14. }
    15. std::string imgDir = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images\\";
    16. std::string ImgFiles[5] = {"image_0.png","image_10.png","image_20.png","image_30.png","image_40.png",};
    17. void predictimgs()
    18. {
    19. //load svm model
    20. cv::Ptr svm = cv::ml::StatModel::load("mnist_svm.xml");
    21. for (size_t i = 0; i < 5; i++)
    22. {
    23. prediction(imgDir+ImgFiles[i],svm);
    24. }
    25. }

    三、完整代码编译

    3.1 OpenCV+MinGW的MakeFile编译

            本文是采用win系统下,opencv采用MinGW编译的静态库(C/C++开发,win下OpenCV+MinGW编译环境搭建_opencv mingw-CSDN博客),建立makefile:

    1. #/bin/sh
    2. #win32
    3. CX= g++ -DWIN32
    4. #linux
    5. #CX= g++ -Dlinux
    6. BIN := ./
    7. TARGET := opencv_ml01.exe
    8. FLAGS := -std=c++11 -static
    9. SRCDIR := ./
    10. #INCLUDES
    11. INCLUDEDIR := -I"../../opencv_MinGW/include" -I"./"
    12. #-I"$(SRCDIR)"
    13. staticDir := ../../opencv_MinGW/x64/mingw/staticlib/
    14. #LIBDIR := $(staticDir)/libopencv_world460.a\
    15. # $(staticDir)/libade.a \
    16. # $(staticDir)/libIlmImf.a \
    17. # $(staticDir)/libquirc.a \
    18. # $(staticDir)/libzlib.a \
    19. # $(wildcard $(staticDir)/liblib*.a) \
    20. # -lgdi32 -lComDlg32 -lOleAut32 -lOle32 -luuid
    21. #opencv_world放弃前,然后是opencv依赖的第三方库,后面的库是MinGW编译工具的库
    22. LIBDIR := -L $(staticDir) -lopencv_world460 -lade -lIlmImf -lquirc -lzlib \
    23. -llibjpeg-turbo -llibopenjp2 -llibpng -llibprotobuf -llibtiff -llibwebp \
    24. -lgdi32 -lComDlg32 -lOleAut32 -lOle32 -luuid
    25. source := $(wildcard $(SRCDIR)/*.cpp)
    26. $(TARGET) :
    27. $(CX) $(FLAGS) $(INCLUDEDIR) $(source) -o $(BIN)/$(TARGET) $(LIBDIR)
    28. clean:
    29. rm $(BIN)/$(TARGET)

            编译如下:

    3.2 OpenCV+vc2015+cmake编译

            第二种编译,本文采用了vs2015 x64编译了opencv库C/C++开发,opencv在win下安装及应用_windows安装opencv c++-CSDN博客)。

            建立cmake文件:

    1. # CMake 最低版本号要求
    2. cmake_minimum_required (VERSION 2.8)
    3. # 项目信息
    4. project (opencv_test)
    5. #
    6. message(STATUS "windows compiling...")
    7. add_definitions(-D_PLATFORM_IS_WINDOWS_)
    8. set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT")
    9. set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd")
    10. set(WIN_OS true)
    11. #
    12. set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/bin)
    13. # 指定源文件的目录,并将名称保存到变量
    14. SET(source_h
    15. #
    16. )
    17. SET(source_cpp
    18. #
    19. ${PROJECT_SOURCE_DIR}/main.cpp
    20. )
    21. #头文件目录
    22. include_directories(${PROJECT_SOURCE_DIR}/../../opencv_VC/include)
    23. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4819")
    24. add_definitions(
    25. "-D_CRT_SECURE_NO_WARNINGS"
    26. "-D_WINSOCK_DEPRECATED_NO_WARNINGS"
    27. "-DNO_WARN_MBCS_MFC_DEPRECATION"
    28. "-DWIN32_LEAN_AND_MEAN"
    29. )
    30. link_directories(
    31. ${PROJECT_SOURCE_DIR}/../../opencv_VC/x64/vc14/bin
    32. ${PROJECT_SOURCE_DIR}/../../opencv_VC/x64/vc14/lib
    33. )
    34. if (CMAKE_BUILD_TYPE STREQUAL "Debug")
    35. set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR})
    36. # 指定生成目标
    37. add_executable(opencv_testd ${source_h} ${source_cpp})
    38. else(CMAKE_BUILD_TYPE)
    39. set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR})
    40. # 指定生成目标
    41. add_executable(opencv_test ${source_h} ${source_cpp})
    42. target_link_libraries(opencv_test opencv_world460.lib opencv_img_hash460.lib)
    43. endif (CMAKE_BUILD_TYPE)
    44. # mkdir build_win
    45. # cd build_win
    46. # cmake -G "Visual Studio 14 2015 Win64" -DCMAKE_BUILD_TYPE=Release ..
    47. # msbuild opencv_test.sln /p:Configuration="Release" /p:Platform="x64"

    启动vs2015 x64的命令工具(使前面配置的环境变量生效),进入main.cpp文件目录,编译如下:

    1. mkdir build_win
    2. cd build_win
    3. cmake -G "Visual Studio 14 2015 Win64" -DCMAKE_BUILD_TYPE=Release ..
    4. msbuild opencv_test.sln /p:Configuration="Release" /p:Platform="x64"

            编译输出大致如下:

    3.3 执行效果

            【1】OpenCV+MinGW+makefile编译程序执行输出,准确率达到98%以上(PS,大家可尝试去调设SVM模型的参数设置,看怎样设置可以获得更高的准确率)

            通过模型调用识别图片全OK(呵呵,毕竟是测试集内的图片数据)

    【2】opencv+vc2015+cmake编译程序执行输出,同样能到达效果。

    3.4 附件,main.cpp全文
    1. #include
    2. #include
    3. #include
    4. #include
    5. #include
    6. #include
    7. #include
    8. int intReverse(int num)
    9. {
    10. return (num>>24|((num&0xFF0000)>>8)|((num&0xFF00)<<8)|((num&0xFF)<<24));
    11. }
    12. std::string intToString(int num)
    13. {
    14. char buf[32]={0};
    15. itoa(num,buf,10);
    16. return std::string(buf);
    17. }
    18. cv::Mat read_mnist_image(const std::string fileName) {
    19. int magic_number = 0;
    20. int number_of_images = 0;
    21. int img_rows = 0;
    22. int img_cols = 0;
    23. cv::Mat DataMat;
    24. std::ifstream file(fileName, std::ios::binary);
    25. if (file.is_open())
    26. {
    27. std::cout << "open images file: "<< fileName << std::endl;
    28. file.read((char*)&magic_number, sizeof(magic_number));//format
    29. file.read((char*)&number_of_images, sizeof(number_of_images));//images number
    30. file.read((char*)&img_rows, sizeof(img_rows));//img rows
    31. file.read((char*)&img_cols, sizeof(img_cols));//img cols
    32. magic_number = intReverse(magic_number);
    33. number_of_images = intReverse(number_of_images);
    34. img_rows = intReverse(img_rows);
    35. img_cols = intReverse(img_cols);
    36. std::cout << "format:" << magic_number
    37. << " img num:" << number_of_images
    38. << " img row:" << img_rows
    39. << " img col:" << img_cols << std::endl;
    40. std::cout << "read img data" << std::endl;
    41. DataMat = cv::Mat::zeros(number_of_images, img_rows * img_cols, CV_32FC1);
    42. unsigned char temp = 0;
    43. for (int i = 0; i < number_of_images; i++) {
    44. for (int j = 0; j < img_rows * img_cols; j++) {
    45. file.read((char*)&temp, sizeof(temp));
    46. //svm data is CV_32FC1
    47. float pixel_value = float(temp);
    48. DataMat.at<float>(i, j) = pixel_value;
    49. }
    50. }
    51. std::cout << "read img data finish!" << std::endl;
    52. }
    53. file.close();
    54. return DataMat;
    55. }
    56. cv::Mat read_mnist_label(const std::string fileName) {
    57. int magic_number;
    58. int number_of_items;
    59. cv::Mat LabelMat;
    60. std::ifstream file(fileName, std::ios::binary);
    61. if (file.is_open())
    62. {
    63. std::cout << "open label file: "<< fileName << std::endl;
    64. file.read((char*)&magic_number, sizeof(magic_number));
    65. file.read((char*)&number_of_items, sizeof(number_of_items));
    66. magic_number = intReverse(magic_number);
    67. number_of_items = intReverse(number_of_items);
    68. std::cout << "format:" << magic_number << " ;label_num:" << number_of_items << std::endl;
    69. std::cout << "read Label data" << std::endl;
    70. //data type:CV_32SC1,channel:1
    71. LabelMat = cv::Mat::zeros(number_of_items, 1, CV_32SC1);
    72. for (int i = 0; i < number_of_items; i++) {
    73. unsigned char temp = 0;
    74. file.read((char*)&temp, sizeof(temp));
    75. LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
    76. }
    77. std::cout << "read label data finish!" << std::endl;
    78. }
    79. file.close();
    80. return LabelMat;
    81. }
    82. //change path for real paths
    83. std::string trainImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-images.idx3-ubyte";
    84. std::string trainLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-labels.idx1-ubyte";
    85. std::string testImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images.idx3-ubyte";
    86. std::string testLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-labels.idx1-ubyte";
    87. void train_SVM()
    88. {
    89. //read train images, data type CV_32FC1
    90. cv::Mat trainingData = read_mnist_image(trainImgFile);
    91. //images data normalization
    92. trainingData = trainingData/255.0;
    93. std::cout << "trainingData.size() = " << trainingData.size() << std::endl;
    94. //read train label, data type CV_32SC1
    95. cv::Mat labelsMat = read_mnist_label(trainLabeFile);
    96. std::cout << "labelsMat.size() = " << labelsMat.size() << std::endl;
    97. std::cout << "trainingData & labelsMat finish!" << std::endl;
    98. //create SVM model
    99. cv::Ptr svm = cv::ml::SVM::create();
    100. //set svm args,type and KernelTypes
    101. svm->setType(cv::ml::SVM::C_SVC);
    102. svm->setKernel(cv::ml::SVM::POLY);
    103. //KernelTypes POLY is need set gamma and degree
    104. svm->setGamma(3.0);
    105. svm->setDegree(2.0);
    106. //Set iteration termination conditions, maxCount is importance
    107. svm->setTermCriteria(cv::TermCriteria(cv::TermCriteria::EPS | cv::TermCriteria::COUNT, 1000, 1e-8));
    108. std::cout << "create SVM object finish!" << std::endl;
    109. std::cout << "trainingData.rows = " << trainingData.rows << std::endl;
    110. std::cout << "trainingData.cols = " << trainingData.cols << std::endl;
    111. std::cout << "trainingData.type() = " << trainingData.type() << std::endl;
    112. // svm model train
    113. svm->train(trainingData, cv::ml::ROW_SAMPLE, labelsMat);
    114. std::cout << "SVM training finish!" << std::endl;
    115. // svm model test
    116. cv::Mat testData = read_mnist_image(testImgFile);
    117. //images data normalization
    118. testData = testData/255.0;
    119. std::cout << "testData.rows = " << testData.rows << std::endl;
    120. std::cout << "testData.cols = " << testData.cols << std::endl;
    121. std::cout << "testData.type() = " << testData.type() << std::endl;
    122. //read test label, data type CV_32SC1
    123. cv::Mat testlabel = read_mnist_label(testLabeFile);
    124. cv::Mat testResp;
    125. float response = svm->predict(testData,testResp);
    126. // std::cout << "response = " << response << std::endl;
    127. testResp.convertTo(testResp,CV_32SC1);
    128. int map_num = 0;
    129. for (int i = 0; i
    130. {
    131. if (testResp.at<int>(i, 0) == testlabel.at<int>(i, 0))
    132. {
    133. map_num++;
    134. }
    135. // else{
    136. // std::cout << "testResp.at(i, 0) " << testResp.at(i, 0) << std::endl;
    137. // std::cout << "testlabel.at(i, 0) " << testlabel.at(i, 0) << std::endl;
    138. // }
    139. }
    140. float proportion = float(map_num) / float(testResp.rows);
    141. std::cout << "map rate: " << proportion * 100 << "%" << std::endl;
    142. std::cout << "SVM testing finish!" << std::endl;
    143. //save svm model
    144. svm->save("mnist_svm.xml");
    145. }
    146. void prediction(const std::string fileName,cv::Ptr svm)
    147. {
    148. //read img 28*28 size
    149. cv::Mat image = cv::imread(fileName, cv::IMREAD_GRAYSCALE);
    150. //uchar->float32
    151. image.convertTo(image, CV_32F);
    152. //image data normalization
    153. image = image / 255.0;
    154. //28*28 -> 1*784
    155. image = image.reshape(1, 1);
    156. //预测图片
    157. float ret = svm->predict(image);
    158. std::cout << "predict val = "<< ret << std::endl;
    159. }
    160. std::string imgDir = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images\\";
    161. std::string ImgFiles[5] = {"image_0.png","image_10.png","image_20.png","image_30.png","image_40.png",};
    162. void predictimgs()
    163. {
    164. //load svm model
    165. cv::Ptr svm = cv::ml::StatModel::load("mnist_svm.xml");
    166. for (size_t i = 0; i < 5; i++)
    167. {
    168. prediction(imgDir+ImgFiles[i],svm);
    169. }
    170. }
    171. int main()
    172. {
    173. train_SVM();
    174. predictimgs();
    175. return 0;
    176. }

  • 相关阅读:
    实现延迟队列的几种途径
    代码随想录笔记_动态规划_718最长重复子数组
    linux下golang环境安装教程(学习笔记)
    02-分布式协调服务ZooKeeper
    交易日均千万订单的存储架构设计与实践 | 京东物流技术团队
    机器学习算法详解3:逻辑回归
    UFC765AE102 ABB数据密集型边缘人工智能
    vr电力作业安全培训覆盖三大板块,为学员提供高仿真的技能培训
    华为低代码TinyEngine ——flow 图元编排设计器
    npm install 卡在reify:rxjs: timing reifyNode的解决办法
  • 原文地址:https://blog.csdn.net/py8105/article/details/138042907