diff --git a/modules/ml/src/rtrees.cpp b/modules/ml/src/rtrees.cpp
index 4da34992dd..f1f122ebf9 100644
--- a/modules/ml/src/rtrees.cpp
+++ b/modules/ml/src/rtrees.cpp
@@ -187,7 +187,7 @@ public:
             oobidx.clear();
             for( i = 0; i < n; i++ )
             {
-                if( !oobmask[i] )
+                if( oobmask[i] )
                     oobidx.push_back(i);
             }
             int n_oob = (int)oobidx.size();
@@ -217,6 +217,7 @@ public:
                 else
                 {
                     int ival = cvRound(val);
+                    // Voting scheme to combine OOB errors of each tree
                     int* votes = &oobvotes[j*nclasses];
                     votes[ival]++;
                     int best_class = 0;
@@ -235,35 +236,35 @@ public:
                 oobperm.resize(n_oob);
                 for( i = 0; i < n_oob; i++ )
                     oobperm[i] = oobidx[i];
+                for( i = n_oob - 1; i > 0; --i )  // Randomly shuffle the indices so we can permute features
+                {
+                    int r_i = rng.uniform(0, i + 1);
+                    std::swap(oobperm[i], oobperm[r_i]);
+                }
 
                 for( vi_ = 0; vi_ < nvars; vi_++ )
                 {
-                    vi = vidx ? vidx[vi_] : vi_;
+                    vi = vidx ? vidx[vi_] : vi_;  // Ensure that only the user-specified predictors are used for training
                     double ncorrect_responses_permuted = 0;
-                    for( i = 0; i < n_oob; i++ )
-                    {
-                        int i1 = rng.uniform(0, n_oob);
-                        int i2 = rng.uniform(0, n_oob);
-                        std::swap(i1, i2);
-                    }
 
                     for( i = 0; i < n_oob; i++ )
                     {
                         j = oobidx[i];
                         int vj = oobperm[i];
                         sample0 = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
-                        for( k = 0; k < nallvars; k++ )
-                            sample.at<float>(k) = sample0.at<float>(k);
-                        sample.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];
+                        Mat sample_clone = sample0.clone();  // Create a copy so we don't modify the original data
+                        sample_clone.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];
 
-                        double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
+                        double val = predictTrees(Range(treeidx, treeidx+1), sample_clone, predictFlags);
                         if( !_isClassifier )
                         {
                             val = (val - w->ord_responses[w->sidx[j]])/max_response;
                             ncorrect_responses_permuted += exp( -val*val );
                         }
                         else
+                        {
                             ncorrect_responses_permuted += cvRound(val) == w->cat_responses[w->sidx[j]];
+                        }
                     }
                     varImportance[vi] += (float)(ncorrect_responses - ncorrect_responses_permuted);
                 }
diff --git a/samples/cpp/tree_engine.cpp b/samples/cpp/tree_engine.cpp
index 2d6824d24d..d9fbb96788 100644
--- a/samples/cpp/tree_engine.cpp
+++ b/samples/cpp/tree_engine.cpp
@@ -63,7 +63,6 @@ int main(int argc, char** argv)
     const double train_test_split_ratio = 0.5;
 
     Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
-
     if( data.empty() )
     {
         printf("ERROR: File %s can not be read\n", filename);
@@ -71,6 +70,7 @@ int main(int argc, char** argv)
     }
 
     data->setTrainTestSplitRatio(train_test_split_ratio);
+    std::cout << "Test/Train: " << data->getNTestSamples() << "/" << data->getNTrainSamples() << std::endl;
 
     printf("======DTREE=====\n");
     Ptr<DTrees> dtree = DTrees::create();
@@ -106,10 +106,19 @@ int main(int argc, char** argv)
     rtrees->setUseSurrogates(false);
    rtrees->setMaxCategories(16);
     rtrees->setPriors(Mat());
-    rtrees->setCalculateVarImportance(false);
+    rtrees->setCalculateVarImportance(true);
     rtrees->setActiveVarCount(0);
     rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
     train_and_print_errs(rtrees, data);
+    cv::Mat ref_labels = data->getClassLabels();
+    cv::Mat test_data = data->getTestSampleIdx();
+    cv::Mat predict_labels;
+    rtrees->predict(data->getSamples(), predict_labels);
+    cv::Mat variable_importance = rtrees->getVarImportance();
+    std::cout << "Estimated variable importance" << std::endl;
+    for (int i = 0; i < variable_importance.rows; i++) {
+        std::cout << "Variable " << i << ": " << variable_importance.at<float>(i, 0) << std::endl;
+    }
 
     return 0;
 }
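
Note on the rtrees.cpp hunks: the removed loop "shuffled" two throwaway local ints (a no-op), so the supposedly permuted feature was never actually permuted and the reported importances were meaningless. The patch replaces it with a real Fisher-Yates shuffle of the OOB indices and, per variable, reads the permuted feature value through oobperm while cloning the sample so the training data stays intact. The standalone sketch below shows the same permutation-importance scheme in isolation; the predict() stub and the toy dataset are made-up illustrations, not part of the patch or of OpenCV.

// Minimal sketch of permutation-based variable importance.
// predict() and the dataset are hypothetical stand-ins for a trained tree's
// per-tree OOB prediction; illustrative only.
#include <cstdio>
#include <random>
#include <utility>
#include <vector>

// Hypothetical model stub: "predicts" from the sign of feature 0.
static int predict(const std::vector<float>& sample)
{
    return sample[0] > 0.f ? 1 : 0;
}

int main()
{
    // Toy OOB set: the label depends only on feature 0; feature 1 is noise.
    std::mt19937 rng(42);
    std::normal_distribution<float> noise(0.f, 1.f);
    std::vector<std::vector<float>> oob;
    std::vector<int> labels;
    for (int i = 0; i < 200; i++)
    {
        float f0 = noise(rng), f1 = noise(rng);
        oob.push_back({f0, f1});
        labels.push_back(f0 > 0.f ? 1 : 0);
    }

    const int n_oob = (int)oob.size();
    int ncorrect = 0;
    for (int i = 0; i < n_oob; i++)
        ncorrect += predict(oob[i]) == labels[i];

    // Fisher-Yates shuffle of indices, as in the corrected patch above:
    // OOB sample i will read the permuted feature from sample perm[i].
    std::vector<int> perm(n_oob);
    for (int i = 0; i < n_oob; i++) perm[i] = i;
    for (int i = n_oob - 1; i > 0; --i)
    {
        std::uniform_int_distribution<int> u(0, i);
        std::swap(perm[i], perm[u(rng)]);
    }

    for (int vi = 0; vi < 2; vi++)
    {
        int ncorrect_permuted = 0;
        for (int i = 0; i < n_oob; i++)
        {
            std::vector<float> clone = oob[i];  // copy; original stays intact
            clone[vi] = oob[perm[i]][vi];       // permute feature vi only
            ncorrect_permuted += predict(clone) == labels[i];
        }
        // Importance = accuracy lost when feature vi is permuted.
        printf("Variable %d importance: %d\n", vi, ncorrect - ncorrect_permuted);
    }
    return 0;
}

Permuting exactly one feature while leaving the rest of the sample untouched is what makes the accuracy drop attributable to that single variable.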
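
Note on the tree_engine.cpp changes: they enable setCalculateVarImportance(true) and print the per-variable scores after training. Below is a minimal self-contained usage sketch of that API; the synthetic data stands in for the sample's CSV input, and the dataset shape and seed are illustrative assumptions, not taken from the sample.

// Minimal sketch of the variable-importance API exercised by the sample,
// assuming OpenCV's ml module is available. Synthetic data replaces the CSV.
#include <opencv2/core.hpp>
#include <opencv2/ml.hpp>
#include <iostream>

int main()
{
    using namespace cv;
    using namespace cv::ml;

    // 100 samples, 2 features; the class depends on feature 0 only.
    RNG rng(0xffffffff);
    Mat samples(100, 2, CV_32F), responses(100, 1, CV_32S);
    for (int i = 0; i < samples.rows; i++)
    {
        samples.at<float>(i, 0) = (float)rng.uniform(-1., 1.);
        samples.at<float>(i, 1) = (float)rng.uniform(-1., 1.);
        responses.at<int>(i, 0) = samples.at<float>(i, 0) > 0.f ? 1 : 0;
    }

    Ptr<TrainData> data = TrainData::create(samples, ROW_SAMPLE, responses);
    Ptr<RTrees> rtrees = RTrees::create();
    rtrees->setCalculateVarImportance(true);  // enables the code path fixed above
    rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
    rtrees->train(data);

    Mat importance = rtrees->getVarImportance();
    for (int i = 0; i < importance.rows; i++)
        std::cout << "Variable " << i << ": " << importance.at<float>(i, 0) << std::endl;
    return 0;
}

Since the label is tied to feature 0 alone, getVarImportance() should rank variable 0 well above the pure-noise variable 1.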