diff --git a/modules/ml/src/em.cpp b/modules/ml/src/em.cpp index 860a698c29..130c9b87ab 100644 --- a/modules/ml/src/em.cpp +++ b/modules/ml/src/em.cpp @@ -44,7 +44,7 @@ namespace cv { -const double minEigenValue = 1.e-5; +const double minEigenValue = DBL_MIN; /////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -121,7 +121,7 @@ bool EM::trainM(InputArray samples, } -int EM::predict(InputArray _sample, OutputArray _probs, double* _logLikelihood) const +int EM::predict(InputArray _sample, OutputArray _probs, double* logLikelihood) const { Mat sample = _sample.getMat(); CV_Assert(isTrained()); @@ -135,16 +135,13 @@ int EM::predict(InputArray _sample, OutputArray _probs, double* _logLikelihood) } int label; - double logLikelihood = 0.; Mat probs; if( _probs.needed() ) { _probs.create(1, nclusters, CV_64FC1); probs = _probs.getMat(); } - computeProbabilities(sample, label, !probs.empty() ? &probs : 0, _logLikelihood ? &logLikelihood : 0); - if(_logLikelihood) - *_logLikelihood = logLikelihood; + computeProbabilities(sample, label, !probs.empty() ? &probs : 0, logLikelihood); return label; } @@ -372,6 +369,7 @@ void EM::computeLogWeightDivDet() CV_Assert(!covsEigenValues.empty()); Mat logWeights; + cv::max(weights, DBL_MIN, weights); log(weights, logWeights); logWeightDivDet.create(1, nclusters, CV_64FC1); @@ -504,28 +502,24 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double* if(!probs && !logLikelihood) return; - if(probs) - { - Mat expL_Lmax; - exp(L - L.at(label), expL_Lmax); - double partSum = 0, // sum_j!=q (exp(L_ij - L_iq)) - factor; // 1/(1 + partExpSum) - for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++) - if(clusterIndex != label) - partSum += expL_Lmax.at(clusterIndex); - factor = 1./(1 + partSum); + Mat buf, *sampleProbs = probs ? probs : &buf; + Mat expL_Lmax; + exp(L - L.at(label), expL_Lmax); + double partSum = 0, // sum_j!=q (exp(L_ij - L_iq)) + factor; // 1/(1 + partExpSum) + for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++) + if(clusterIndex != label) + partSum += expL_Lmax.at(clusterIndex); + factor = 1./(1 + partSum); - probs->create(1, nclusters, CV_64FC1); - expL_Lmax *= factor; - expL_Lmax.copyTo(*probs); - } + sampleProbs->create(1, nclusters, CV_64FC1); + expL_Lmax *= factor; + expL_Lmax.copyTo(*sampleProbs); if(logLikelihood) { - Mat expL; - exp(L, expL); - // note logLikelihood = log (sum_j exp(L_ij)) - 0.5 * dims * ln2Pi - *logLikelihood = std::log(sum(expL)[0]) - (double)(0.5 * dim * CV_LOG2PI); + double logWeightProbs = std::log(std::max(DBL_MIN, sum(*sampleProbs)[0])); + *logLikelihood = logWeightProbs; } } diff --git a/modules/ml/test/test_emknearestkmeans.cpp b/modules/ml/test/test_emknearestkmeans.cpp index d9b9460c89..31a9bee097 100644 --- a/modules/ml/test/test_emknearestkmeans.cpp +++ b/modules/ml/test/test_emknearestkmeans.cpp @@ -83,7 +83,7 @@ void generateData( Mat& data, Mat& labels, const vector& sizes, const Mat& labels.create( data.rows, 1, labelType ); - randn( data, Scalar::all(0.0), Scalar::all(1.0) ); + randn( data, Scalar::all(-1.0), Scalar::all(1.0) ); vector means(sizes.size()); for(int i = 0; i < _means.rows; i++) means[i] = _means.row(i); @@ -381,7 +381,7 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params, ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex ); code = cvtest::TS::FAIL_INVALID_OUTPUT; } - else if( err > 0.006f ) + else if( err > 0.008f ) { ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on train data.\n", caseIndex, err ); code = cvtest::TS::FAIL_BAD_ACCURACY; @@ -401,7 +401,7 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params, ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex ); code = cvtest::TS::FAIL_INVALID_OUTPUT; } - else if( err > 0.006f ) + else if( err > 0.008f ) { ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on test data.\n", caseIndex, err ); code = cvtest::TS::FAIL_BAD_ACCURACY; @@ -505,7 +505,8 @@ protected: virtual void run( int /*start_from*/ ) { int code = cvtest::TS::OK; - cv::EM em(2); + const int nclusters = 2; + cv::EM em(nclusters); Mat samples = Mat(3,1,CV_64FC1); samples.at(0,0) = 1;