added new sample on ML models

pull/13383/head
Maria Dimashova 14 years ago
parent e762f2a33c
commit 428aef522b
  1. 411
      samples/cpp/points_classifier.cpp

@ -0,0 +1,411 @@
#include "opencv2/core/core.hpp"
#include "opencv2/ml/ml.hpp"
#include "opencv2/highgui/highgui.hpp"
#include <stdio.h>
using namespace std;
using namespace cv;
const Scalar WHITE_COLOR = CV_RGB(255,255,255);
const string winName = "points";
const int testStep = 5;
Mat img, img_dst;
RNG rng;
vector<Point> trainedPoints;
vector<int> trainedPointsMarkers;
vector<Scalar> classColors;
#define KNN 0
#define SVM 0
#define DT 1
#define RF 0
#define ANN 0
#define GMM 0
void on_mouse( int event, int x, int y, int /*flags*/, void* )
{
if( img.empty() )
return;
int updateFlag = 0;
if( event == CV_EVENT_LBUTTONUP )
{
if( classColors.empty() )
return;
trainedPoints.push_back( Point(x,y) );
trainedPointsMarkers.push_back( classColors.size()-1 );
updateFlag = true;
}
else if( event == CV_EVENT_RBUTTONUP )
{
classColors.push_back( Scalar((uchar)rng(256), (uchar)rng(256), (uchar)rng(256)) );
updateFlag = true;
}
//draw
if( updateFlag )
{
img = Scalar::all(0);
// put the text
stringstream text;
text << "current class " << classColors.size()-1;
putText( img, text.str(), Point(10,25), CV_FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
text.str("");
text << "total classes " << classColors.size();
putText( img, text.str(), Point(10,50), CV_FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
text.str("");
text << "total points " << trainedPoints.size();
putText(img, text.str(), cvPoint(10,75), CV_FONT_HERSHEY_SIMPLEX, 0.8f, WHITE_COLOR, 2 );
// draw points
for( size_t i = 0; i < trainedPoints.size(); i++ )
circle( img, trainedPoints[i], 5, classColors[trainedPointsMarkers[i]], -1 );
imshow( winName, img );
}
}
void prepare_train_data( Mat& samples, Mat& classes )
{
Mat( trainedPoints ).copyTo( samples );
Mat( trainedPointsMarkers ).copyTo( classes );
// reshape trainData and change its type
samples = samples.reshape( 1, samples.rows );
samples.convertTo( samples, CV_32FC1 );
}
#if KNN
void find_decision_boundary_KNN( int K )
{
img.copyTo( img_dst );
Mat trainSamples, trainClasses;
prepare_train_data( trainSamples, trainClasses );
// learn classifier
CvKNearest knnClassifier( trainSamples, trainClasses, Mat(), false, K );
Mat testSample( 1, 2, CV_32FC1 );
for( int y = 0; y < img.rows; y += testStep )
{
for( int x = 0; x < img.cols; x += testStep )
{
testSample.at<float>(0) = (float)x;
testSample.at<float>(1) = (float)y;
int response = (int)knnClassifier.find_nearest( testSample, K );
circle( img_dst, Point(x,y), 1, classColors[response] );
}
}
}
#endif
#if SVM
void find_decision_boundary_SVM( CvSVMParams params )
{
img.copyTo( img_dst );
Mat trainSamples, trainClasses;
prepare_train_data( trainSamples, trainClasses );
// learn classifier
CvSVM svmClassifier( trainSamples, trainClasses, Mat(), Mat(), params );
Mat testSample( 1, 2, CV_32FC1 );
for( int y = 0; y < img.rows; y += testStep )
{
for( int x = 0; x < img.cols; x += testStep )
{
testSample.at<float>(0) = (float)x;
testSample.at<float>(1) = (float)y;
int response = (int)svmClassifier.predict( testSample );
circle( img_dst, Point(x,y), 2, classColors[response], 1 );
}
}
for( int i = 0; i < svmClassifier.get_support_vector_count(); i++ )
{
const float* supportVector = svmClassifier.get_support_vector(i);
circle( img_dst, Point(supportVector[0],supportVector[1]), 5, CV_RGB(255,255,255), -1 );
}
}
#endif
#if DT
void find_decision_boundary_DT()
{
img.copyTo( img_dst );
Mat trainSamples, trainClasses;
prepare_train_data( trainSamples, trainClasses );
// learn classifier
CvDTree dtree;
Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
CvDTreeParams params;
params.max_depth = 8;
params.min_sample_count = 2;
params.use_surrogates = false;
params.cv_folds = 0; // the number of cross-validation folds
params.use_1se_rule = false;
params.truncate_pruned_tree = false;
dtree.train( trainSamples, CV_ROW_SAMPLE, trainClasses,
Mat(), Mat(), var_types, Mat(), params );
Mat testSample(1, 2, CV_32FC1 );
for( int y = 0; y < img.rows; y += testStep )
{
for( int x = 0; x < img.cols; x += testStep )
{
testSample.at<float>(0) = (float)x;
testSample.at<float>(1) = (float)y;
int response = (int)dtree.predict( testSample )->value;
circle( img_dst, Point(x,y), 2, classColors[response], 1 );
}
}
}
#endif
#if RF
void find_decision_boundary_RF()
{
img.copyTo( img_dst );
Mat trainSamples, trainClasses;
prepare_train_data( trainSamples, trainClasses );
// learn classifier
CvRTrees rtrees;
Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;
CvRTParams params( 4, // max_depth,
2, // min_sample_count,
0.f, // regression_accuracy,
false, // use_surrogates,
16, // max_categories,
0, // priors,
false, // calc_var_importance,
1, // nactive_vars,
5, // max_num_of_trees_in_the_forest,
0, // forest_accuracy,
CV_TERMCRIT_ITER // termcrit_type
);
rtrees.train( trainSamples, CV_ROW_SAMPLE, trainClasses, Mat(), Mat(), var_types, Mat(), params );
Mat testSample(1, 2, CV_32FC1 );
for( int y = 0; y < img.rows; y += testStep )
{
for( int x = 0; x < img.cols; x += testStep )
{
testSample.at<float>(0) = (float)x;
testSample.at<float>(1) = (float)y;
int response = (int)rtrees.predict( testSample );
circle( img_dst, Point(x,y), 2, classColors[response], 1 );
}
}
}
#endif
#if ANN
void find_decision_boundary_ANN( const Mat& layer_sizes )
{
img.copyTo( img_dst );
Mat trainSamples, trainClasses;
prepare_train_data( trainSamples, trainClasses );
// prerare trainClasses
trainClasses.create( trainedPoints.size(), classColors.size(), CV_32FC1 );
for( int i = 0; i < trainClasses.rows; i++ )
{
for( int k = 0; k < trainClasses.cols; k++ )
{
if( k == trainedPointsMarkers[i] )
trainClasses.at<float>(i,k) = 1;
else
trainClasses.at<float>(i,k) = 0;
}
}
Mat weights( 1, trainedPoints.size(), CV_32FC1, Scalar::all(1) );
// learn classifier
CvANN_MLP ann( layer_sizes, CvANN_MLP::SIGMOID_SYM, 1, 1 );
ann.train( trainSamples, trainClasses, weights );
Mat testSample( 1, 2, CV_32FC1 );
for( int y = 0; y < img.rows; y += testStep )
{
for( int x = 0; x < img.cols; x += testStep )
{
testSample.at<float>(0) = (float)x;
testSample.at<float>(1) = (float)y;
Mat outputs( 1, classColors.size(), CV_32FC1, testSample.data );
ann.predict( testSample, outputs );
Point maxLoc;
minMaxLoc( outputs, 0, 0, 0, &maxLoc );
circle( img_dst, Point(x,y), 2, classColors[maxLoc.x], 1 );
}
}
}
#endif
#if GMM
void find_decision_boundary_GMM()
{
img.copyTo( img_dst );
Mat trainSamples, trainClasses;
prepare_train_data( trainSamples, trainClasses );
CvEM em;
CvEMParams params;
params.covs = NULL;
params.means = NULL;
params.weights = NULL;
params.probs = NULL;
params.nclusters = classColors.size();
params.cov_mat_type = CvEM::COV_MAT_GENERIC;
params.start_step = CvEM::START_AUTO_STEP;
params.term_crit.max_iter = 10;
params.term_crit.epsilon = 0.1;
params.term_crit.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
// learn classifier
em.train( trainSamples, Mat(), params, &trainClasses );
Mat testSample(1, 2, CV_32FC1 );
for( int y = 0; y < img.rows; y += testStep )
{
for( int x = 0; x < img.cols; x += testStep )
{
testSample.at<float>(0) = (float)x;
testSample.at<float>(1) = (float)y;
int response = (int)em.predict( testSample );
circle( img_dst, Point(x,y), 2, classColors[response], 1 );
}
}
}
#endif
int main()
{
cv::namedWindow( "points", 1 );
img.create( 480, 640, CV_8UC3 );
img_dst.create( 480, 640, CV_8UC3 );
imshow( "points", img );
cvSetMouseCallback( "points", on_mouse );
for(;;)
{
uchar key = waitKey();
if( key == 27 ) break;
if( key == 'i' ) // init
{
img = Scalar::all(0);
classColors.clear();
trainedPoints.clear();
trainedPointsMarkers.clear();
imshow( winName, img );
}
if( key == 'r' ) // run
{
#if KNN
int K = 3;
find_decision_boundary_KNN( K );
namedWindow( "kNN", WINDOW_AUTOSIZE );
imshow( "kNN", img_dst );
K = 15;
find_decision_boundary_KNN( K );
namedWindow( "kNN2", WINDOW_AUTOSIZE );
imshow( "kNN2", img_dst );
#endif
#if SVM
//(1)-(2)separable and not sets
CvSVMParams params;
params.svm_type = CvSVM::C_SVC;
params.kernel_type = CvSVM::POLY; //CvSVM::LINEAR;
params.degree = 0.5;
params.gamma = 1;
params.coef0 = 1;
params.C = 1;
params.nu = 0.5;
params.p = 0;
params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01);
find_decision_boundary_SVM( params );
namedWindow( "classificationSVM1", WINDOW_AUTOSIZE );
imshow( "classificationSVM1", img_dst );
params.C = 10;
find_decision_boundary_SVM( params );
cvNamedWindow( "classificationSVM2", WINDOW_AUTOSIZE );
imshow( "classificationSVM2", img_dst );
#endif
#if DT
find_decision_boundary_DT();
namedWindow( "DT", 1 );
imshow( "DT", img_dst );
#endif
#if RF
find_decision_boundary_RF();
namedWindow( "RF", 1 );
imshow( "RF", img_dst);
#endif
#if ANN
Mat layer_sizes1( 1, 3, CV_32SC1 );
layer_sizes1.at<int>(0) = 2;
layer_sizes1.at<int>(1) = 5;
layer_sizes1.at<int>(2) = classColors.size();
find_decision_boundary_ANN( layer_sizes1 );
namedWindow( "ANN", WINDOW_AUTOSIZE );
imshow( "ANN", img_dst );
#endif
#if GMM
find_decision_boundary_GMM();
namedWindow( "GMM", WINDOW_AUTOSIZE );
imshow( "GMM", img_dst );
#endif
}
}
return 1;
}
Loading…
Cancel
Save