diff --git a/modules/ml/include/opencv2/ml/ml.hpp b/modules/ml/include/opencv2/ml/ml.hpp index 6612d2ea13..4a928c7305 100644 --- a/modules/ml/include/opencv2/ml/ml.hpp +++ b/modules/ml/include/opencv2/ml/ml.hpp @@ -1253,6 +1253,8 @@ public: protected: + void update_weights_impl( CvBoostTree* tree, double initial_weights[2] ); + virtual bool set_params( const CvBoostParams& params ); virtual void update_weights( CvBoostTree* tree ); virtual void trim_weights(); diff --git a/modules/ml/src/boost.cpp b/modules/ml/src/boost.cpp index 8db94bd713..d8e5c0d1d2 100644 --- a/modules/ml/src/boost.cpp +++ b/modules/ml/src/boost.cpp @@ -1117,9 +1117,9 @@ bool CvBoost::train( CvMLData* _data, } void -CvBoost::update_weights( CvBoostTree* tree ) +CvBoost::update_weights_impl( CvBoostTree* tree, double initial_weights[2] ) { - CV_FUNCNAME( "CvBoost::update_weights" ); + CV_FUNCNAME( "CvBoost::update_weights_impl" ); __BEGIN__; @@ -1161,7 +1161,7 @@ CvBoost::update_weights( CvBoostTree* tree ) // so we need to convert class labels to floating-point values double w0 = 1./n; - double p[2] = { 1, 1 }; + double p[2] = { initial_weights[0], initial_weights[1] }; cvReleaseMat( &orig_response ); cvReleaseMat( &sum_response ); @@ -1414,6 +1414,11 @@ CvBoost::update_weights( CvBoostTree* tree ) __END__; } +void +CvBoost::update_weights( CvBoostTree* tree ) { + double initial_weights[2] = { 1, 1 }; + update_weights_impl( tree, initial_weights ); +} static CV_IMPLEMENT_QSORT_EX( icvSort_64f, double, CV_LT, int )