From d8425d88816ef7dbc9c8e5fa4cce407c7cf7a64e Mon Sep 17 00:00:00 2001 From: mrquorr Date: Tue, 31 Jan 2017 23:31:10 -0600 Subject: [PATCH] finished for one sample Finished with several samples support, need regression testing Gave a more relevant name to function (getVotes) Finished implicit implementation Removed printf, finished regresion testing Fixed conversion warning Finished test for Rtrees Fixed documentation Initialized variable Added doxygen documentation Added parameter name --- modules/ml/include/opencv2/ml.hpp | 11 +++++ modules/ml/src/rtrees.cpp | 67 +++++++++++++++++++++++++++++++ modules/ml/test/test_mltests.cpp | 45 +++++++++++++++++++++ 3 files changed, 123 insertions(+) diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index 3614a91298..9ae14eaca9 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -1164,6 +1164,17 @@ public: */ CV_WRAP virtual Mat getVarImportance() const = 0; + /** Returns the result of each individual tree in the forest. + In case the model is a regression problem, the method will return each of the trees' + results for each of the sample cases. If the model is a classifier, it will return + a Mat with samples + 1 rows, where the first row gives the class number and the + following rows return the votes each class had for each sample. + @param samples Array containg the samples for which votes will be calculated. + @param results Array where the result of the calculation will be written. + @param flags Flags for defining the type of RTrees. + */ + CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const; + /** Creates the empty model. Use StatModel::train to train the model, StatModel::train to create and train the model, Algorithm::load to load the pre-trained model. diff --git a/modules/ml/src/rtrees.cpp b/modules/ml/src/rtrees.cpp index 65fe6827a7..fa2a23950f 100644 --- a/modules/ml/src/rtrees.cpp +++ b/modules/ml/src/rtrees.cpp @@ -349,6 +349,60 @@ public: } } + void getVotes( InputArray input, OutputArray output, int flags ) const + { + CV_Assert( !roots.empty() ); + int nclasses = (int)classLabels.size(), ntrees = (int)roots.size(); + Mat samples = input.getMat(), results; + int i, j, nsamples = samples.rows; + + int predictType = flags & PREDICT_MASK; + if( predictType == PREDICT_AUTO ) + { + predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ? + PREDICT_SUM : PREDICT_MAX_VOTE; + } + + if( predictType == PREDICT_SUM ) + { + output.create(nsamples, ntrees, CV_32F); + results = output.getMat(); + for( i = 0; i < nsamples; i++ ) + { + for( j = 0; j < ntrees; j++ ) + { + float val = predictTrees( Range(j, j+1), samples.row(i), flags); + results.at (i, j) = val; + } + } + } else + { + vector votes; + output.create(nsamples+1, nclasses, CV_32S); + results = output.getMat(); + + for ( j = 0; j < nclasses; j++) + { + results.at (0, j) = classLabels[j]; + } + + for( i = 0; i < nsamples; i++ ) + { + votes.clear(); + for( j = 0; j < ntrees; j++ ) + { + int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags); + votes.push_back(val); + } + + for ( j = 0; j < nclasses; j++) + { + results.at (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]); + } + } + } + } + RTreeParams rparams; double oobError; vector varImportance; @@ -401,6 +455,11 @@ public: impl.read(fn); } + void getVotes_( InputArray samples, OutputArray results, int flags ) const + { + impl.getVotes(samples, results, flags); + } + Mat getVarImportance() const { return Mat_(impl.varImportance, true); } int getVarCount() const { return impl.getVarCount(); } @@ -427,6 +486,14 @@ Ptr RTrees::load(const String& filepath, const String& nodeName) return Algorithm::load(filepath, nodeName); } +void RTrees::getVotes(InputArray input, OutputArray output, int flags) const +{ + const RTreesImpl* this_ = dynamic_cast(this); + if(!this_) + CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl"); + return this_->getVotes_(input, output, flags); +} + }} // End of file. diff --git a/modules/ml/test/test_mltests.cpp b/modules/ml/test/test_mltests.cpp index 70cc0f7ecb..719333140c 100644 --- a/modules/ml/test/test_mltests.cpp +++ b/modules/ml/test/test_mltests.cpp @@ -172,4 +172,49 @@ TEST(ML_NBAYES, regression_5911) EXPECT_EQ(sum(P1 == P3)[0], 255 * P3.total()); } +TEST(ML_RTrees, getVotes) +{ + int n = 12; + int count, i; + int label_size = 3; + int predicted_class = 0; + int max_votes = -1; + int val; + // RTrees for classification + Ptr rt = cv::ml::RTrees::create(); + + //data + Mat data(n, 4, CV_32F); + randu(data, 0, 10); + + //labels + Mat labels = (Mat_(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2); + + rt->train(data, ml::ROW_SAMPLE, labels); + + //run function + Mat test(1, 4, CV_32F); + Mat result; + randu(test, 0, 10); + rt->getVotes(test, result, 0); + + //count vote amount and find highest vote + count = 0; + const int* result_row = result.ptr(1); + for( i = 0; i < label_size; i++ ) + { + val = result_row[i]; + //predicted_class = max_votes < val? i; + if( max_votes < val ) + { + max_votes = val; + predicted_class = i; + } + count += val; + } + + EXPECT_EQ(count, (int)rt->getRoots().size()); + EXPECT_EQ(result.at(0, predicted_class), rt->predict(test)); +} + /* End of file. */