本节介绍如何用 C++ 和 OpenCV 实现 SVM 图像二分类的训练,下一节介绍测试:
数据集合data内容如下:
下载地址为:https://download.csdn.net/download/hgaohr1021/89506900
#include <stdio.h>
#include <time.h>
#include <io.h>

#include <cstring>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "opencv2/imgcodecs.hpp"
#include <opencv2/core/utils/logger.hpp>
#include <opencv2/ml/ml.hpp>

using namespace std;
using namespace cv;
using namespace cv::ml;

void getFiles(string path, vector<string>& files);
void get_1(Mat& trainingImages, vector<int>& trainingLabels);
void get_0(Mat& trainingImages, vector<int>& trainingLabels);int main()
{//获取训练数据Mat classes;Mat trainingData;Mat trainingImages;vector<int> trainingLabels;get_1(trainingImages, trainingLabels);//waitKey(2000);get_0(trainingImages, trainingLabels);Mat(trainingImages).copyTo(trainingData);trainingData.convertTo(trainingData, CV_32FC1);Mat(trainingLabels).copyTo(classes);//配置SVM训练器参数Ptr<SVM> svm = SVM::create();svm->setType(SVM::C_SVC);svm->setKernel(SVM::LINEAR);svm->setDegree(0);svm->setGamma(1);svm->setCoef0(0);svm->setC(1);svm->setNu(0);svm->setP(0);svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 1000, 0.01));//训练svm->train(trainingData, ROW_SAMPLE, classes);//保存模型svm->save("svm.xml");cout << "训练好了!!!" << endl;getchar();return 0;
}
/**
 * Recursively collects the paths of all regular files below a directory.
 *
 * Uses the Windows CRT _findfirst/_findnext API. Sub-directories are
 * descended into; the "." and ".." entries are skipped.
 *
 * @param path  Directory to scan (no trailing backslash).
 * @param files Receives "path\name" for every non-directory entry found;
 *              existing contents are kept and new paths are appended.
 */
void getFiles(string path, vector<string>& files)
{
    struct _finddata_t entry;
    const string pattern = path + "\\*";

    const long long handle = _findfirst(pattern.c_str(), &entry);
    if (handle == -1)
    {
        // Nothing matched the search pattern.
        return;
    }

    do
    {
        const string entryName = entry.name;
        const string fullPath = path + "\\" + entryName;
        const bool isDirectory = (entry.attrib & _A_SUBDIR) != 0;

        if (!isDirectory)
        {
            files.push_back(fullPath);
        }
        else if (entryName != "." && entryName != "..")
        {
            getFiles(fullPath, files);
        }
    } while (_findnext(handle, &entry) == 0);

    _findclose(handle);
}
/**
 * Loads every file under data\train_image\1 as a sample with label 1.
 *
 * Each readable image is resized to 60x256 (width x height), flattened into
 * a single-channel one-row matrix and appended to the training set. Files
 * that cannot be decoded as images are skipped with a warning.
 *
 * @param trainingImages Receives one row per loaded image (appended).
 * @param trainingLabels Receives the label 1 for every appended row.
 */
void get_1(Mat& trainingImages, vector<int>& trainingLabels)
{
    const string filePath = "data\\train_image\\1";
    vector<string> files;
    getFiles(filePath, files);

    for (const string& file : files)
    {
        const Mat SrcImage = imread(file);
        if (SrcImage.empty())
        {
            // resize() rejects an empty input, so skip unreadable files
            // instead of aborting the whole run.
            cerr << "Skipping unreadable image: " << file << endl;
            continue;
        }

        // Bring every image to the same size so all rows have equal length.
        Mat resized;
        resize(SrcImage, resized, cv::Size(60, 256), 0, 0, cv::INTER_LINEAR);

        // Flatten to 1 channel x 1 row: one feature vector per image.
        const Mat sample = resized.reshape(1, 1);
        trainingImages.push_back(sample);
        trainingLabels.push_back(1);
    }
}
/**
 * Loads every file under data\train_image\0 as a sample with label 0.
 *
 * Each readable image is resized to 60x256 (width x height), flattened into
 * a single-channel one-row matrix and appended to the training set. Files
 * that cannot be decoded as images are skipped with a warning.
 *
 * @param trainingImages Receives one row per loaded image (appended).
 * @param trainingLabels Receives the label 0 for every appended row.
 */
void get_0(Mat& trainingImages, vector<int>& trainingLabels)
{
    const string filePath = "data\\train_image\\0";
    vector<string> files;
    getFiles(filePath, files);

    for (const string& file : files)
    {
        const Mat SrcImage = imread(file);
        if (SrcImage.empty())
        {
            // resize() rejects an empty input, so skip unreadable files
            // instead of aborting the whole run.
            cerr << "Skipping unreadable image: " << file << endl;
            continue;
        }

        // Bring every image to the same size so all rows have equal length.
        Mat resized;
        resize(SrcImage, resized, cv::Size(60, 256), 0, 0, cv::INTER_LINEAR);

        // Flatten to 1 channel x 1 row: one feature vector per image.
        const Mat sample = resized.reshape(1, 1);
        trainingImages.push_back(sample);
        trainingLabels.push_back(0);
    }
}
运行结果为:
运行完后,根目录里会出现 svm.xml 文件,供下一节测试图片时使用。
数据集下载地址为:https://download.csdn.net/download/hgaohr1021/89506900