• 【opencv450-samples】train_svmsgd.cpp


     

    与SVM不同,SVMSGD不需要设置核函数。

    【参数】默认值见下述代码

    模型类型:SGD、ASGD(推荐)。随机梯度下降、平均随机梯度下降。
    边界类型:HARD_MARGIN、SOFT_MARGIN(推荐),前者用于线性可分,后者用于非线性可分
    边界规范化 lambda:推荐设为0.0001(对于SGD),0.00001(对于ASGD)。越小,异类被抛弃的越少。
    步长 gamma_0
    步长降低力度 c:推荐设置为1(对于SGD),0.75(对于ASGD)
    终止条件:TermCriteria::COUNT、TermCriteria::EPS、TermCriteria::COUNT + TermCriteria::EPS

    参数设置函数:

    setSvmsgdType()
    setMarginType()
    setMarginRegularization()
    setInitialStepSize()
    setStepDecreasingPower()

    【使用方式】

    cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();//创建对象
    svmsgd->train(trainData);//训练
    svmsgd->save("MySvmsgd.xml");//保存模型
    svmsgd->load("MySvmsgd.xml");//加载模型
    svmsgd->predict(samples, responses);//预测,结果保存到responses标签中

    1. #include "opencv2/core.hpp"
    2. #include "opencv2/video/tracking.hpp"
    3. #include "opencv2/imgproc.hpp"
    4. #include "opencv2/highgui.hpp"
    5. #include "opencv2/ml.hpp"
    6. using namespace cv;
    7. using namespace cv::ml;
    8. //https://www.cnblogs.com/xixixing/p/12430202.html
    9. struct Data
    10. {
    11. Mat img;
    12. Mat samples; //一组训练样本。 包含图像上的点Set of train samples. Contains points on image
    13. Mat responses; //训练样本的标签 Set of responses for train samples
    14. Data() //显示图像
    15. {
    16. const int WIDTH = 841;
    17. const int HEIGHT = 594;
    18. img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
    19. imshow("Train svmsgd", img);
    20. }
    21. };
    22. //Train with SVMSGD algorithm
    23. //(samples, responses) is a train set
    24. //weights is a required vector for decision function of SVMSGD algorithm
    25. //用SVMSGD算法训练
    26. //(samples,responses) 是一个训练集
    27. //weights 是 SVMSGD 算法决策函数所需的向量
    28. bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);
    29. //function finds two points for drawing line (wx = 0)
    30. //函数找到绘制线的两个点(wx = 0)
    31. bool findPointsForLine(const Mat &weights, float shift, Point points[], int width, int height);
    32. // function finds cross point of line (wx = 0) and segment ( (y = HEIGHT, 0 <= x <= WIDTH) or (x = WIDTH, 0 <= y <= HEIGHT) )
    33. // 函数找到线 (wx = 0) 和线段 ( (y = HEIGHT, 0 <= x <= WIDTH) 或 (x = WIDTH, 0 <= y <= HEIGHT) ) 的交叉点
    34. bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);
    35. //segments' initialization ( (y = HEIGHT, 0 <= x <= WIDTH) and (x = WIDTH, 0 <= y <= HEIGHT) )
    36. //线段的初始化 ( (y = HEIGHT, 0 <= x <= WIDTH) 和 (x = WIDTH, 0 <= y <= HEIGHT) )
    37. void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);
    38. //redraw points' set and line (wx = 0)
    39. //重绘点的集合和线(wx = 0)
    40. void redraw(Data data, const Point points[2]);
    41. //add point in train set, train SVMSGD algorithm and draw results on image
    42. //在训练集中添加点,训练SVMSGD算法并在图像上绘制结果
    43. void addPointRetrainAndRedraw(Data &data, int x, int y, int response);
    44. //训练 得到参数
    45. bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
    46. {
    47. cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
    48. //*设置参数,以下全是默认参数
    49. //svmsgd->setSvmsgdType(SVMSGD::ASGD); //模型类型
    50. //svmsgd->setMarginType(SVMSGD::SOFT_MARGIN); //边界类型
    51. //svmsgd->setMarginRegularization(0.00001); //边界规范化
    52. //svmsgd->setInitialStepSize(0.05);//步长
    53. //svmsgd->setStepDecreasingPower(0.75); //步长减弱力度
    54. //svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT,1000,1e-3));//终止条件,1000次迭代,0.001每次迭代的精度
    55. cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);//构造训练数据
    56. svmsgd->train( trainData );
    57. if (svmsgd->isTrained())
    58. {
    59. weights = svmsgd->getWeights();
    60. shift = svmsgd->getShift();
    61. //*保存模型
    62. svmsgd->save("svmsgd.xml"); //保存训练好的模型
    63. return true;
    64. }
    65. return false;
    66. }
    67. //找出边界四条直线
    68. void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
    69. {
    70. std::pair<Point,Point> currentSegment;//当前线段
    71. currentSegment.first = Point(width, 0);//右上角点
    72. currentSegment.second = Point(width, height);//右下角点
    73. segments.push_back(currentSegment);
    74. currentSegment.first = Point(0, height);//左下角点
    75. currentSegment.second = Point(width, height);//右下角点
    76. segments.push_back(currentSegment);
    77. currentSegment.first = Point(0, 0);//左上角点
    78. currentSegment.second = Point(width, 0);//右上角点
    79. segments.push_back(currentSegment);
    80. currentSegment.first = Point(0, 0);
    81. currentSegment.second = Point(0, height);
    82. segments.push_back(currentSegment);
    83. }
    84. //找到与边界框交点
    85. bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
    86. {
    87. int x = 0;
    88. int y = 0;
    89. int xMin = std::min(segment.first.x, segment.second.x);
    90. int xMax = std::max(segment.first.x, segment.second.x);
    91. int yMin = std::min(segment.first.y, segment.second.y);
    92. int yMax = std::max(segment.first.y, segment.second.y);
    93. CV_Assert(weights.type() == CV_32FC1);
    94. CV_Assert(xMin == xMax || yMin == yMax);//断言:线段为垂直或者水平
    95. //一条垂直线 边框的左侧和右侧线
    96. if (xMin == xMax && weights.at<float>(1) != 0) //AX+BY+C=0 B!=0
    97. {
    98. x = xMin;
    99. y = static_cast<int>(std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1)));
    100. if (y >= yMin && y <= yMax)
    101. { //直线与边框左右侧线条的交点
    102. crossPoint.x = x;
    103. crossPoint.y = y;
    104. return true;
    105. }
    106. }
    107. //一条水平线 边框的上侧和下侧线
    108. else if (yMin == yMax && weights.at<float>(0) != 0)//A!=0
    109. {
    110. y = yMin;
    111. x = static_cast<int>(std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0)));
    112. if (x >= xMin && x <= xMax)
    113. { //直线与边框上下端线条的交点
    114. crossPoint.x = x;
    115. crossPoint.y = y;
    116. return true;
    117. }
    118. }
    119. return false;
    120. }
    121. //根据直线找到与边界框的交点 2个
    122. bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height)
    123. {
    124. if (weights.empty())//直线权重参数非空
    125. {
    126. return false;
    127. }
    128. int foundPointsCount = 0;//找到的点数
    129. std::vector<std::pair<Point,Point> > segments;//点对集合 线段集合
    130. fillSegments(segments, width, height);//找到边界框
    131. for (uint i = 0; i < segments.size(); i++)
    132. { //找到直线与边界框的交点
    133. if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))
    134. foundPointsCount++;//直线与边界框交点数
    135. if (foundPointsCount >= 2)
    136. break;
    137. }
    138. return true;
    139. }
    140. //绘制直线
    141. void redraw(Data data, const Point points[2])
    142. {
    143. data.img.setTo(0);//黑色背景
    144. Point center;//样本中心点
    145. int radius = 3;//半径3
    146. Scalar color;
    147. CV_Assert((data.samples.type() == CV_32FC1) && (data.responses.type() == CV_32FC1));//断言:数据样本类型
    148. for (int i = 0; i < data.samples.rows; i++)//遍历样本
    149. {
    150. center.x = static_cast<int>(data.samples.at<float>(i,0));
    151. center.y = static_cast<int>(data.samples.at<float>(i,1));
    152. color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128);
    153. circle(data.img, center, radius, color, 5);//绘制样本点
    154. }
    155. line(data.img, points[0], points[1],cv::Scalar(1,255,1));//绘制直线
    156. imshow("Train svmsgd", data.img);//显示图像
    157. }
    158. //添加点 标签response:1 / -1
    159. void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
    160. {
    161. Mat currentSample(1, 2, CV_32FC1);//临时点坐标 x,y float
    162. currentSample.at<float>(0,0) = (float)x;
    163. currentSample.at<float>(0,1) = (float)y;
    164. data.samples.push_back(currentSample);//添加到数据样本中
    165. data.responses.push_back(static_cast<float>(response));//添加到数据标签中
    166. Mat weights(1, 2, CV_32FC1);//权重系数A,B 超平面: AX+BY+C=0
    167. float shift = 0;//C
    168. //训练,得到超平面即直线参数
    169. if (doTrain(data.samples, data.responses, weights, shift))
    170. {
    171. Point points[2];
    172. findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);//找到直线与边界框的交点
    173. redraw(data, points);//绘制直线和样本点
    174. }
    175. }
    176. //鼠标回调
    177. static void onMouse( int event, int x, int y, int, void* pData)
    178. {
    179. Data &data = *(Data*)pData;//数据指针
    180. switch( event )
    181. {
    182. case EVENT_LBUTTONUP:
    183. addPointRetrainAndRedraw(data, x, y, 1);//左键 添加点标签1
    184. break;
    185. case EVENT_RBUTTONDOWN:
    186. addPointRetrainAndRedraw(data, x, y, -1);//右键 添加点标签-1
    187. break;
    188. }
    189. }
    190. int main()
    191. {
    192. Data data;
    193. setMouseCallback( "Train svmsgd", onMouse, &data );
    194. waitKey();
    195. return 0;
    196. }

     svmsgd.xml

    参考:

    基于SGD、ASGD算法的SVM分类器(OpenCV案例源码train_svmsgd.cpp解读) - 夕西行 - 博客园

  • 相关阅读:
    C++设计模式-创建型设计模式:抽象工厂
    ARM汇编(gun-complier)
    QGIS展示三维DEM数据
    无需公网IP,在家SSH远程连接公司内网服务器「cpolar内网穿透」
    Spark RDD 转换算子
    如何下载并安装jdk13版本
    JavaSE - 封装、static成员和内部类
    LeetCode98题:验证二叉搜索树(python3)
    WebDAV之葫芦儿·派盘+书藏家
    使用VSCode+PlatformIO搭建ESP32开发环境
  • 原文地址:https://blog.csdn.net/cxyhjl/article/details/125557923