2013-11-09 08:42:39 +08:00
|
|
|
#include "opencv2/opencv_modules.hpp"
|
2011-05-01 00:44:34 +08:00
|
|
|
#include "opencv2/core/core.hpp"
|
|
|
|
#include "opencv2/ml/ml.hpp"
|
|
|
|
#include "opencv2/highgui/highgui.hpp"
|
2013-11-09 08:42:39 +08:00
|
|
|
#ifdef HAVE_OPENCV_OCL
|
2013-11-12 13:13:25 +08:00
|
|
|
#define _OCL_KNN_ 1 // select whether using ocl::KNN method or not, default is using
|
|
|
|
#define _OCL_SVM_ 1 // select whether using ocl::svm method or not, default is using
|
2013-11-09 08:42:39 +08:00
|
|
|
#include "opencv2/ocl/ocl.hpp"
|
|
|
|
#endif
|
2011-05-01 00:44:34 +08:00
|
|
|
|
|
|
|
#include <stdio.h>
|
|
|
|
|
|
|
|
using namespace std;
|
|
|
|
using namespace cv;
|
|
|
|
|
2013-10-30 20:34:27 +08:00
|
|
|
const Scalar WHITE_COLOR = Scalar(255,255,255);
|
2011-05-01 00:44:34 +08:00
|
|
|
const string winName = "points";
|
|
|
|
const int testStep = 5;
|
|
|
|
|
2011-05-01 02:04:33 +08:00
|
|
|
Mat img, imgDst;
|
2011-05-01 00:44:34 +08:00
|
|
|
RNG rng;
|
|
|
|
|
|
|
|
vector<Point> trainedPoints;
|
|
|
|
vector<int> trainedPointsMarkers;
|
|
|
|
vector<Scalar> classColors;
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Compile-time switches selecting which classifiers are trained and displayed
// when 'r' is pressed; set a switch to 1 to enable that classifier.
#define _NBC_ 0 // normal Bayesian classifier
#define _KNN_ 0 // k-nearest neighbours classifier
#define _SVM_ 0 // support vector machine
#define _DT_  1 // decision tree
#define _BT_  0 // AdaBoost (CvBoost, 2-class only)
#define _GBT_ 0 // gradient boosted trees
#define _RF_  0 // random forest
#define _ERT_ 0 // extremely randomized trees
#define _ANN_ 0 // artificial neural network (multi-layer perceptron)
#define _EM_  0 // expectation-maximization
|
2011-05-01 00:44:34 +08:00
|
|
|
|
2012-06-15 21:04:17 +08:00
|
|
|
static void on_mouse( int event, int x, int y, int /*flags*/, void* )
|
2011-05-01 00:44:34 +08:00
|
|
|
{
|
|
|
|
if( img.empty() )
|
|
|
|
return;
|
|
|
|
|
|
|
|
int updateFlag = 0;
|
|
|
|
|
|
|
|
if( event == CV_EVENT_LBUTTONUP )
|
|
|
|
{
|
|
|
|
if( classColors.empty() )
|
|
|
|
return;
|
|
|
|
|
|
|
|
trainedPoints.push_back( Point(x,y) );
|
2012-03-17 05:21:04 +08:00
|
|
|
trainedPointsMarkers.push_back( (int)(classColors.size()-1) );
|
2011-05-01 00:44:34 +08:00
|
|
|
updateFlag = true;
|
|
|
|
}
|
|
|
|
else if( event == CV_EVENT_RBUTTONUP )
|
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _BT_
|
2011-05-01 02:04:33 +08:00
|
|
|
if( classColors.size() < 2 )
|
|
|
|
{
|
|
|
|
#endif
|
|
|
|
classColors.push_back( Scalar((uchar)rng(256), (uchar)rng(256), (uchar)rng(256)) );
|
|
|
|
updateFlag = true;
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _BT_
|
2011-05-01 02:04:33 +08:00
|
|
|
}
|
|
|
|
else
|
|
|
|
cout << "New class can not be added, because CvBoost can only be used for 2-class classification" << endl;
|
|
|
|
#endif
|
|
|
|
|
2011-05-01 00:44:34 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//draw
|
|
|
|
if( updateFlag )
|
|
|
|
{
|
|
|
|
img = Scalar::all(0);
|
|
|
|
|
|
|
|
// put the text
|
|
|
|
stringstream text;
|
|
|
|
text << "current class " << classColors.size()-1;
|
2013-10-30 20:34:27 +08:00
|
|
|
putText( img, text.str(), Point(10,25), FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
|
2011-05-01 00:44:34 +08:00
|
|
|
|
|
|
|
text.str("");
|
|
|
|
text << "total classes " << classColors.size();
|
2013-10-30 20:34:27 +08:00
|
|
|
putText( img, text.str(), Point(10,50), FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
|
2011-05-01 00:44:34 +08:00
|
|
|
|
|
|
|
text.str("");
|
|
|
|
text << "total points " << trainedPoints.size();
|
2013-10-30 20:34:27 +08:00
|
|
|
putText(img, text.str(), Point(10,75), FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
|
2011-05-01 00:44:34 +08:00
|
|
|
|
|
|
|
// draw points
|
|
|
|
for( size_t i = 0; i < trainedPoints.size(); i++ )
|
|
|
|
circle( img, trainedPoints[i], 5, classColors[trainedPointsMarkers[i]], -1 );
|
|
|
|
|
|
|
|
imshow( winName, img );
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2012-06-15 21:04:17 +08:00
|
|
|
// Converts the global training set into the matrices the ML classes expect.
//   samples - receives an N x 2 CV_32FC1 matrix, one (x, y) row per point.
//   classes - receives an N x 1 matrix of the class index of each point.
static void prepare_train_data( Mat& samples, Mat& classes )
{
    // Copy the N x 1 two-channel point column, view it as N x 2 single-channel,
    // and convert the coordinates to float.
    Mat pointColumn;
    Mat( trainedPoints ).copyTo( pointColumn );
    Mat pointRows = pointColumn.reshape( 1, pointColumn.rows );
    pointRows.convertTo( samples, CV_32FC1 );

    Mat( trainedPointsMarkers ).copyTo( classes );
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _NBC_
// Trains a normal Bayes classifier on the entered points and paints, into
// imgDst, the predicted class of every testStep-th pixel as a small circle.
static void find_decision_boundary_NBC()
{
    img.copyTo( imgDst );

    Mat trainSamples, trainClasses;
    prepare_train_data( trainSamples, trainClasses );

    // learn classifier
    CvNormalBayesClassifier normalBayesClassifier( trainSamples, trainClasses );

    Mat testSample( 1, 2, CV_32FC1 );
    float* const xy = testSample.ptr<float>(0);
    for( int row = 0; row < img.rows; row += testStep )
    {
        xy[1] = (float)row;
        for( int col = 0; col < img.cols; col += testStep )
        {
            xy[0] = (float)col;
            const int label = (int)normalBayesClassifier.predict( testSample );
            circle( imgDst, Point(col, row), 1, classColors[label] );
        }
    }
}
#endif
|
|
|
|
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _KNN_
// Trains a k-nearest-neighbours classifier and paints, into imgDst, the
// predicted class of every testStep-th pixel. Uses cv::ocl::KNearestNeighbour
// when HAVE_OPENCV_OCL and _OCL_KNN_ are set, CvKNearest otherwise.
//   K - number of neighbours, used for both training (max_k) and prediction.
static void find_decision_boundary_KNN( int K )
{
    img.copyTo( imgDst );

    Mat trainSamples, trainClasses;
    prepare_train_data( trainSamples, trainClasses );

    // learn classifier
#if defined HAVE_OPENCV_OCL && _OCL_KNN_
    cv::ocl::KNearestNeighbour knnClassifier;
    Mat sampleIdx, result;
    knnClassifier.train( trainSamples, trainClasses, sampleIdx, false, K );
    cv::ocl::oclMat testSampleOcl, resultOcl;
#else
    CvKNearest knnClassifier( trainSamples, trainClasses, Mat(), false, K );
#endif

    Mat testSample( 1, 2, CV_32FC1 );
    for( int row = 0; row < img.rows; row += testStep )
    {
        for( int col = 0; col < img.cols; col += testStep )
        {
            testSample.at<float>(0) = (float)col;
            testSample.at<float>(1) = (float)row;

#if defined HAVE_OPENCV_OCL && _OCL_KNN_
            testSampleOcl.upload( testSample );
            knnClassifier.find_nearest( testSampleOcl, K, resultOcl );
            resultOcl.download( result );
            const int label = saturate_cast<int>( result.at<float>(0) );
#else
            const int label = (int)knnClassifier.find_nearest( testSample, K );
#endif
            circle( imgDst, Point(col, row), 1, classColors[label] );
        }
    }
}
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _SVM_
// Trains a support vector machine with the given parameters and paints, into
// imgDst, the predicted class of every testStep-th pixel. Each support vector
// is then marked with a filled white disc. Uses cv::ocl::CvSVM_OCL when
// HAVE_OPENCV_OCL and _OCL_SVM_ are set, CvSVM otherwise.
static void find_decision_boundary_SVM( CvSVMParams params )
{
    img.copyTo( imgDst );

    Mat trainSamples, trainClasses;
    prepare_train_data( trainSamples, trainClasses );

    // learn classifier
#if defined HAVE_OPENCV_OCL && _OCL_SVM_
    cv::ocl::CvSVM_OCL svmClassifier( trainSamples, trainClasses, Mat(), Mat(), params );
#else
    CvSVM svmClassifier( trainSamples, trainClasses, Mat(), Mat(), params );
#endif

    // Decision regions.
    Mat testSample( 1, 2, CV_32FC1 );
    float* const sampleXY = testSample.ptr<float>(0);
    for( int row = 0; row < img.rows; row += testStep )
    {
        for( int col = 0; col < img.cols; col += testStep )
        {
            sampleXY[0] = (float)col;
            sampleXY[1] = (float)row;
            const int label = (int)svmClassifier.predict( testSample );
            circle( imgDst, Point(col, row), 2, classColors[label], 1 );
        }
    }

    // Support vectors.
    const int svCount = svmClassifier.get_support_vector_count();
    for( int sv = 0; sv < svCount; sv++ )
    {
        const float* vec = svmClassifier.get_support_vector(sv);
        const Point center( saturate_cast<int>(vec[0]), saturate_cast<int>(vec[1]) );
        circle( imgDst, center, 5, WHITE_COLOR, -1 );
    }
}
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _DT_
// Trains a single decision tree (CvDTree, depth <= 8, no pruning) on the
// entered points and paints, into imgDst, the predicted class of every
// testStep-th pixel.
static void find_decision_boundary_DT()
{
    img.copyTo( imgDst );

    Mat trainSamples, trainClasses;
    prepare_train_data( trainSamples, trainClasses );

    // Input features are ordered; the last entry (the response) is categorical,
    // which makes the tree a classifier rather than a regressor.
    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;

    CvDTreeParams params;
    params.max_depth            = 8;
    params.min_sample_count     = 2;
    params.use_surrogates       = false;
    params.cv_folds             = 0; // the number of cross-validation folds
    params.use_1se_rule         = false;
    params.truncate_pruned_tree = false;

    // learn classifier
    CvDTree dtree;
    dtree.train( trainSamples, CV_ROW_SAMPLE, trainClasses,
                 Mat(), Mat(), var_types, Mat(), params );

    Mat testSample( 1, 2, CV_32FC1 );
    for( int row = 0; row < img.rows; row += testStep )
    {
        for( int col = 0; col < img.cols; col += testStep )
        {
            testSample.at<float>(0) = (float)col;
            testSample.at<float>(1) = (float)row;

            const CvDTreeNode* leaf = dtree.predict( testSample );
            const int label = (int)leaf->value;
            circle( imgDst, Point(col, row), 2, classColors[label], 1 );
        }
    }
}
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _BT_
// Trains a discrete AdaBoost ensemble (CvBoost, 100 weak trees of depth 2) on
// the entered points and paints, into imgDst, the predicted class of every
// testStep-th pixel. CvBoost handles two classes only; on_mouse refuses to
// create a third class when _BT_ is enabled.
// Declared static like the other find_decision_boundary_* helpers: it is only
// used by main() in this file (avoids -Wmissing-declarations when enabled).
static void find_decision_boundary_BT()
{
    img.copyTo( imgDst );

    Mat trainSamples, trainClasses;
    prepare_train_data( trainSamples, trainClasses );

    // learn classifier
    CvBoost boost;

    // Ordered input features, categorical response => classification.
    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;

    CvBoostParams params( CvBoost::DISCRETE, // boost_type
                          100,               // weak_count
                          0.95,              // weight_trim_rate
                          2,                 // max_depth
                          false,             // use_surrogates
                          0                  // priors
                        );

    boost.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );

    Mat testSample( 1, 2, CV_32FC1 );
    for( int y = 0; y < img.rows; y += testStep )
    {
        for( int x = 0; x < img.cols; x += testStep )
        {
            testSample.at<float>(0) = (float)x;
            testSample.at<float>(1) = (float)y;

            int response = (int)boost.predict( testSample );
            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
        }
    }
}
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _GBT_
// Trains gradient boosted trees (CvGBTrees, deviance loss, 100 trees of depth
// 2, shrinkage 0.1) on the entered points and paints, into imgDst, the
// predicted class of every testStep-th pixel.
// Declared static like the other find_decision_boundary_* helpers: it is only
// used by main() in this file (avoids -Wmissing-declarations when enabled).
static void find_decision_boundary_GBT()
{
    img.copyTo( imgDst );

    Mat trainSamples, trainClasses;
    prepare_train_data( trainSamples, trainClasses );

    // learn classifier
    CvGBTrees gbtrees;

    // Ordered input features, categorical response => classification.
    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;

    CvGBTreesParams params( CvGBTrees::DEVIANCE_LOSS, // loss_function_type
                            100,                      // weak_count
                            0.1f,                     // shrinkage
                            1.0f,                     // subsample_portion
                            2,                        // max_depth
                            false                     // use_surrogates
                          );

    gbtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );

    Mat testSample( 1, 2, CV_32FC1 );
    for( int y = 0; y < img.rows; y += testStep )
    {
        for( int x = 0; x < img.cols; x += testStep )
        {
            testSample.at<float>(0) = (float)x;
            testSample.at<float>(1) = (float)y;

            int response = (int)gbtrees.predict( testSample );
            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
        }
    }
}
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _RF_
// Trains a random forest (CvRTrees, 5 trees of depth <= 4, one random feature
// per split) on the entered points and paints, into imgDst, the predicted
// class of every testStep-th pixel.
// Declared static like the other find_decision_boundary_* helpers: it is only
// used by main() in this file (avoids -Wmissing-declarations when enabled).
static void find_decision_boundary_RF()
{
    img.copyTo( imgDst );

    Mat trainSamples, trainClasses;
    prepare_train_data( trainSamples, trainClasses );

    // Ordered input features, categorical response => classification. This is
    // stated explicitly, as in the DT/BT/GBT/ERT variants, instead of relying
    // on the library's default response type when var_type is empty.
    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;

    // learn classifier
    CvRTrees rtrees;
    CvRTParams params( 4,                // max_depth,
                       2,                // min_sample_count,
                       0.f,              // regression_accuracy,
                       false,            // use_surrogates,
                       16,               // max_categories,
                       0,                // priors,
                       false,            // calc_var_importance,
                       1,                // nactive_vars,
                       5,                // max_num_of_trees_in_the_forest,
                       0,                // forest_accuracy,
                       CV_TERMCRIT_ITER  // termcrit_type
                     );

    rtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );

    Mat testSample( 1, 2, CV_32FC1 );
    for( int y = 0; y < img.rows; y += testStep )
    {
        for( int x = 0; x < img.cols; x += testStep )
        {
            testSample.at<float>(0) = (float)x;
            testSample.at<float>(1) = (float)y;

            int response = (int)rtrees.predict( testSample );
            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
        }
    }
}
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _ERT_
// Trains extremely randomized trees (CvERTrees, 5 trees of depth <= 4) on the
// entered points and paints, into imgDst, the predicted class of every
// testStep-th pixel.
// Declared static like the other find_decision_boundary_* helpers: it is only
// used by main() in this file (avoids -Wmissing-declarations when enabled).
static void find_decision_boundary_ERT()
{
    img.copyTo( imgDst );

    Mat trainSamples, trainClasses;
    prepare_train_data( trainSamples, trainClasses );

    // learn classifier
    CvERTrees ertrees;

    // Ordered input features, categorical response => classification.
    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;

    CvRTParams params( 4,                // max_depth,
                       2,                // min_sample_count,
                       0.f,              // regression_accuracy,
                       false,            // use_surrogates,
                       16,               // max_categories,
                       0,                // priors,
                       false,            // calc_var_importance,
                       1,                // nactive_vars,
                       5,                // max_num_of_trees_in_the_forest,
                       0,                // forest_accuracy,
                       CV_TERMCRIT_ITER  // termcrit_type
                     );

    ertrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );

    Mat testSample( 1, 2, CV_32FC1 );
    for( int y = 0; y < img.rows; y += testStep )
    {
        for( int x = 0; x < img.cols; x += testStep )
        {
            testSample.at<float>(0) = (float)x;
            testSample.at<float>(1) = (float)y;

            int response = (int)ertrees.predict( testSample );
            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
        }
    }
}
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _ANN_
// Trains a multi-layer perceptron (CvANN_MLP, symmetric sigmoid) on the
// entered points, using one-hot target vectors, and paints, into imgDst, the
// class with the largest network output for every testStep-th pixel.
//   layer_sizes - 1 x L CV_32SC1 row of layer widths; the first entry must be 2
//                 (x, y) and the last must equal the number of classes.
// Declared static like the other find_decision_boundary_* helpers: it is only
// used by main() in this file (avoids -Wmissing-declarations when enabled).
static void find_decision_boundary_ANN( const Mat& layer_sizes )
{
    img.copyTo( imgDst );

    Mat trainSamples, trainClasses;
    prepare_train_data( trainSamples, trainClasses );

    const int sampleCount = (int)trainedPoints.size();
    const int classCount  = (int)classColors.size();

    // One-hot responses: row i has a 1 in the column of point i's class.
    Mat responses = Mat::zeros( sampleCount, classCount, CV_32FC1 );
    for( int i = 0; i < sampleCount; i++ )
        responses.at<float>( i, trainedPointsMarkers[i] ) = 1.f;

    Mat weights( 1, sampleCount, CV_32FC1, Scalar::all(1) );

    // learn classifier
    CvANN_MLP ann( layer_sizes, CvANN_MLP::SIGMOID_SYM, 1, 1 );
    ann.train( trainSamples, responses, weights );

    Mat testSample( 1, 2, CV_32FC1 );
    // Dedicated output buffer. It must NOT wrap testSample's data: that buffer
    // holds only 2 floats, so predict() would write past its end for more than
    // two classes, and the network input would alias its output.
    Mat outputs( 1, classCount, CV_32FC1 );
    for( int y = 0; y < img.rows; y += testStep )
    {
        for( int x = 0; x < img.cols; x += testStep )
        {
            testSample.at<float>(0) = (float)x;
            testSample.at<float>(1) = (float)y;

            ann.predict( testSample, outputs );
            Point maxLoc;
            minMaxLoc( outputs, 0, 0, 0, &maxLoc );
            circle( imgDst, Point(x,y), 2, classColors[maxLoc.x], 1 );
        }
    }
}
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _EM_
// Fits one Gaussian mixture (cv::EM, diagonal covariances) per class to that
// class's points. It then paints, into imgDst, every testStep-th pixel with the
// class whose mixture gives the highest log-likelihood, i.e.
//   y(x) = arg max_i likelihood_i(x)
// Classes without any points keep an untrained model and are never predicted.
// Declared static like the other find_decision_boundary_* helpers: it is only
// used by main() in this file (avoids -Wmissing-declarations when enabled).
static void find_decision_boundary_EM()
{
    img.copyTo( imgDst );

    Mat trainSamples, trainClasses;
    prepare_train_data( trainSamples, trainClasses );

    const int modelCount = (int)classColors.size();
    vector<cv::EM> em_models( modelCount );

    CV_Assert((int)trainClasses.total() == trainSamples.rows);
    CV_Assert((int)trainClasses.type() == CV_32SC1);

    // Upper bound on mixture components per class. EM initialises with
    // k-means, which requires at least as many samples as clusters, so a class
    // with fewer points than this gets one component per point instead of
    // aborting the program.
    const int maxComponentCount = 3;
    for( int modelIndex = 0; modelIndex < modelCount; modelIndex++ )
    {
        Mat modelSamples;
        for( int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++ )
        {
            if( trainClasses.at<int>(sampleIndex) == modelIndex )
                modelSamples.push_back( trainSamples.row(sampleIndex) );
        }

        if( modelSamples.empty() )
            continue;

        const int componentCount = modelSamples.rows < maxComponentCount ? modelSamples.rows
                                                                         : maxComponentCount;
        // learn model
        em_models[modelIndex] = EM( componentCount, cv::EM::COV_MAT_DIAGONAL );
        em_models[modelIndex].train( modelSamples );
    }

    Mat testSample( 1, 2, CV_32FC1 );
    Mat logLikelihoods( 1, modelCount, CV_64FC1 );
    for( int y = 0; y < img.rows; y += testStep )
    {
        for( int x = 0; x < img.cols; x += testStep )
        {
            testSample.at<float>(0) = (float)x;
            testSample.at<float>(1) = (float)y;

            // Untrained models keep -DBL_MAX so they never win the arg max.
            logLikelihoods.setTo( Scalar(-DBL_MAX) );
            for( int modelIndex = 0; modelIndex < modelCount; modelIndex++ )
            {
                if( em_models[modelIndex].isTrained() )
                    logLikelihoods.at<double>(modelIndex) = em_models[modelIndex].predict(testSample)[0];
            }

            Point maxLoc;
            minMaxLoc( logLikelihoods, 0, 0, 0, &maxLoc );

            int response = maxLoc.x;
            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
        }
    }
}
#endif
|
|
|
|
|
|
|
|
int main()
|
|
|
|
{
|
2011-05-01 02:04:33 +08:00
|
|
|
cout << "Use:" << endl
|
|
|
|
<< " right mouse button - to add new class;" << endl
|
|
|
|
<< " left mouse button - to add new point;" << endl
|
|
|
|
<< " key 'r' - to run the ML model;" << endl
|
|
|
|
<< " key 'i' - to init (clear) the data." << endl << endl;
|
|
|
|
|
2011-05-01 00:44:34 +08:00
|
|
|
cv::namedWindow( "points", 1 );
|
|
|
|
img.create( 480, 640, CV_8UC3 );
|
2011-05-01 02:04:33 +08:00
|
|
|
imgDst.create( 480, 640, CV_8UC3 );
|
2011-05-01 00:44:34 +08:00
|
|
|
|
|
|
|
imshow( "points", img );
|
|
|
|
cvSetMouseCallback( "points", on_mouse );
|
|
|
|
|
|
|
|
for(;;)
|
|
|
|
{
|
2011-05-07 20:06:58 +08:00
|
|
|
uchar key = (uchar)waitKey();
|
2011-05-01 00:44:34 +08:00
|
|
|
|
|
|
|
if( key == 27 ) break;
|
|
|
|
|
|
|
|
if( key == 'i' ) // init
|
|
|
|
{
|
|
|
|
img = Scalar::all(0);
|
|
|
|
|
|
|
|
classColors.clear();
|
|
|
|
trainedPoints.clear();
|
|
|
|
trainedPointsMarkers.clear();
|
|
|
|
|
|
|
|
imshow( winName, img );
|
|
|
|
}
|
|
|
|
|
|
|
|
if( key == 'r' ) // run
|
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _NBC_
|
2011-05-01 02:04:33 +08:00
|
|
|
find_decision_boundary_NBC();
|
2013-10-30 20:34:27 +08:00
|
|
|
namedWindow( "NormalBayesClassifier", WINDOW_AUTOSIZE );
|
2011-05-01 02:04:33 +08:00
|
|
|
imshow( "NormalBayesClassifier", imgDst );
|
|
|
|
#endif
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _KNN_
|
2011-05-01 00:44:34 +08:00
|
|
|
int K = 3;
|
|
|
|
find_decision_boundary_KNN( K );
|
|
|
|
namedWindow( "kNN", WINDOW_AUTOSIZE );
|
2011-05-01 02:04:33 +08:00
|
|
|
imshow( "kNN", imgDst );
|
2011-05-01 00:44:34 +08:00
|
|
|
|
|
|
|
K = 15;
|
|
|
|
find_decision_boundary_KNN( K );
|
|
|
|
namedWindow( "kNN2", WINDOW_AUTOSIZE );
|
2011-05-01 02:04:33 +08:00
|
|
|
imshow( "kNN2", imgDst );
|
2011-05-01 00:44:34 +08:00
|
|
|
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _SVM_
|
2011-05-01 00:44:34 +08:00
|
|
|
//(1)-(2)separable and not sets
|
|
|
|
CvSVMParams params;
|
|
|
|
params.svm_type = CvSVM::C_SVC;
|
|
|
|
params.kernel_type = CvSVM::POLY; //CvSVM::LINEAR;
|
|
|
|
params.degree = 0.5;
|
|
|
|
params.gamma = 1;
|
|
|
|
params.coef0 = 1;
|
|
|
|
params.C = 1;
|
|
|
|
params.nu = 0.5;
|
|
|
|
params.p = 0;
|
|
|
|
params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01);
|
|
|
|
|
|
|
|
find_decision_boundary_SVM( params );
|
|
|
|
namedWindow( "classificationSVM1", WINDOW_AUTOSIZE );
|
2011-05-01 02:04:33 +08:00
|
|
|
imshow( "classificationSVM1", imgDst );
|
2011-05-01 00:44:34 +08:00
|
|
|
|
|
|
|
params.C = 10;
|
|
|
|
find_decision_boundary_SVM( params );
|
2013-10-30 20:34:27 +08:00
|
|
|
namedWindow( "classificationSVM2", WINDOW_AUTOSIZE );
|
2011-05-01 02:04:33 +08:00
|
|
|
imshow( "classificationSVM2", imgDst );
|
2011-05-01 00:44:34 +08:00
|
|
|
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _DT_
|
2011-05-01 00:44:34 +08:00
|
|
|
find_decision_boundary_DT();
|
2011-05-01 17:01:57 +08:00
|
|
|
namedWindow( "DT", WINDOW_AUTOSIZE );
|
2011-05-01 02:04:33 +08:00
|
|
|
imshow( "DT", imgDst );
|
|
|
|
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _BT_
|
2011-05-01 02:04:33 +08:00
|
|
|
find_decision_boundary_BT();
|
2011-05-01 17:01:57 +08:00
|
|
|
namedWindow( "BT", WINDOW_AUTOSIZE );
|
2011-05-01 02:04:33 +08:00
|
|
|
imshow( "BT", imgDst);
|
|
|
|
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _GBT_
|
2011-05-01 02:04:33 +08:00
|
|
|
find_decision_boundary_GBT();
|
2011-05-01 17:01:57 +08:00
|
|
|
namedWindow( "GBT", WINDOW_AUTOSIZE );
|
2011-05-01 02:04:33 +08:00
|
|
|
imshow( "GBT", imgDst);
|
2011-05-01 00:44:34 +08:00
|
|
|
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _RF_
|
2011-05-01 00:44:34 +08:00
|
|
|
find_decision_boundary_RF();
|
2011-05-01 17:01:57 +08:00
|
|
|
namedWindow( "RF", WINDOW_AUTOSIZE );
|
2011-05-01 02:04:33 +08:00
|
|
|
imshow( "RF", imgDst);
|
|
|
|
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _ERT_
|
2011-05-01 02:04:33 +08:00
|
|
|
find_decision_boundary_ERT();
|
2011-05-01 17:01:57 +08:00
|
|
|
namedWindow( "ERT", WINDOW_AUTOSIZE );
|
2011-05-01 02:04:33 +08:00
|
|
|
imshow( "ERT", imgDst);
|
2011-05-01 00:44:34 +08:00
|
|
|
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _ANN_
|
2011-05-01 00:44:34 +08:00
|
|
|
Mat layer_sizes1( 1, 3, CV_32SC1 );
|
|
|
|
layer_sizes1.at<int>(0) = 2;
|
|
|
|
layer_sizes1.at<int>(1) = 5;
|
|
|
|
layer_sizes1.at<int>(2) = classColors.size();
|
|
|
|
find_decision_boundary_ANN( layer_sizes1 );
|
|
|
|
namedWindow( "ANN", WINDOW_AUTOSIZE );
|
2011-05-01 02:04:33 +08:00
|
|
|
imshow( "ANN", imgDst );
|
2011-05-01 00:44:34 +08:00
|
|
|
#endif
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
#if _EM_
|
2011-05-01 02:04:33 +08:00
|
|
|
find_decision_boundary_EM();
|
|
|
|
namedWindow( "EM", WINDOW_AUTOSIZE );
|
|
|
|
imshow( "EM", imgDst );
|
2011-05-01 00:44:34 +08:00
|
|
|
#endif
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return 1;
|
|
|
|
}
|