modified EM interface; updated tests & samples

pull/13383/head
Vadim Pisarevsky 13 years ago
parent 1c1c6b98f6
commit b8c310065c
  1. 3
      modules/contrib/include/opencv2/contrib/hybridtracker.hpp
  2. 22
      modules/contrib/src/hybridtracker.cpp
  3. 6
      modules/legacy/include/opencv2/legacy/legacy.hpp
  4. 121
      modules/legacy/src/em.cpp
  5. 15
      modules/legacy/test/test_em.cpp
  6. 118
      modules/ml/include/opencv2/ml/ml.hpp
  7. 342
      modules/ml/src/em.cpp
  8. 48
      modules/ml/test/test_emknearestkmeans.cpp

@ -66,7 +66,6 @@ struct CV_EXPORTS CvMotionModel
} }
float low_pass_gain; // low pass gain float low_pass_gain; // low pass gain
cv::EM::Params em_params; // EM parameters
}; };
// Mean Shift Tracker parameters for specifying use of HSV channel and CamShift parameters. // Mean Shift Tracker parameters for specifying use of HSV channel and CamShift parameters.
@ -109,7 +108,6 @@ struct CV_EXPORTS CvHybridTrackerParams
float ms_tracker_weight; float ms_tracker_weight;
CvFeatureTrackerParams ft_params; CvFeatureTrackerParams ft_params;
CvMeanShiftTrackerParams ms_params; CvMeanShiftTrackerParams ms_params;
cv::EM::Params em_params;
int motion_model; int motion_model;
float low_pass_gain; float low_pass_gain;
}; };
@ -182,7 +180,6 @@ private:
CvMat* samples; CvMat* samples;
CvMat* labels; CvMat* labels;
cv::EM em_model;
Rect prev_window; Rect prev_window;
Point2f prev_center; Point2f prev_center;

@ -132,17 +132,6 @@ void CvHybridTracker::newTracker(Mat image, Rect selection) {
mstracker->newTrackingWindow(image, selection); mstracker->newTrackingWindow(image, selection);
fttracker->newTrackingWindow(image, selection); fttracker->newTrackingWindow(image, selection);
params.em_params.covs = NULL;
params.em_params.means = NULL;
params.em_params.probs = NULL;
params.em_params.nclusters = 1;
params.em_params.weights = NULL;
params.em_params.covMatType = cv::EM::COV_MAT_SPHERICAL;
params.em_params.startStep = cv::EM::START_AUTO_STEP;
params.em_params.termCrit.maxCount = 10000;
params.em_params.termCrit.epsilon = 0.001;
params.em_params.termCrit.type = cv::TermCriteria::COUNT + cv::TermCriteria::EPS;
samples = cvCreateMat(2, 1, CV_32FC1); samples = cvCreateMat(2, 1, CV_32FC1);
labels = cvCreateMat(2, 1, CV_32SC1); labels = cvCreateMat(2, 1, CV_32SC1);
@ -222,12 +211,15 @@ void CvHybridTracker::updateTrackerWithEM(Mat image) {
} }
cv::Mat lbls; cv::Mat lbls;
em_model.train(samples, cv::Mat(), params.em_params, &lbls);
EM em_model(1, EM::COV_MAT_SPHERICAL, TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.001));
em_model.train(cvarrToMat(samples), lbls);
if(labels) if(labels)
*labels = lbls; lbls.copyTo(cvarrToMat(labels));
curr_center.x = (float)em_model.getMeans().at<double> (0, 0); Mat em_means = em_model.get<Mat>("means");
curr_center.y = (float)em_model.getMeans().at<double> (0, 1); curr_center.x = (float)em_means.at<float>(0, 0);
curr_center.y = (float)em_means.at<float>(0, 1);
} }
void CvHybridTracker::updateTrackerWithLowPassFilter(Mat image) { void CvHybridTracker::updateTrackerWithLowPassFilter(Mat image) {

@ -1821,10 +1821,10 @@ public:
CV_WRAP virtual double calcLikelihood( const cv::Mat &sample ) const; CV_WRAP virtual double calcLikelihood( const cv::Mat &sample ) const;
CV_WRAP int getNClusters() const; CV_WRAP int getNClusters() const;
CV_WRAP const cv::Mat& getMeans() const; CV_WRAP cv::Mat getMeans() const;
CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs) const; CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs) const;
CV_WRAP const cv::Mat& getWeights() const; CV_WRAP cv::Mat getWeights() const;
CV_WRAP const cv::Mat& getProbs() const; CV_WRAP cv::Mat getProbs() const;
CV_WRAP inline double getLikelihood() const { return emObj.isTrained() ? likelihood : DBL_MAX; } CV_WRAP inline double getLikelihood() const { return emObj.isTrained() ? likelihood : DBL_MAX; }
#endif #endif

@ -41,6 +41,8 @@
#include "precomp.hpp" #include "precomp.hpp"
using namespace cv;
CvEMParams::CvEMParams() : nclusters(10), cov_mat_type(CvEM::COV_MAT_DIAGONAL), CvEMParams::CvEMParams() : nclusters(10), cov_mat_type(CvEM::COV_MAT_DIAGONAL),
start_step(CvEM::START_AUTO_STEP), probs(0), weights(0), means(0), covs(0) start_step(CvEM::START_AUTO_STEP), probs(0), weights(0), means(0), covs(0)
{ {
@ -76,38 +78,44 @@ void CvEM::clear()
void CvEM::read( CvFileStorage* fs, CvFileNode* node ) void CvEM::read( CvFileStorage* fs, CvFileNode* node )
{ {
cv::FileNode fn(fs, node); FileNode fn(fs, node);
emObj.read(fn); emObj.read(fn);
set_mat_hdrs(); set_mat_hdrs();
} }
void CvEM::write( CvFileStorage* _fs, const char* name ) const void CvEM::write( CvFileStorage* _fs, const char* name ) const
{ {
cv::FileStorage fs = _fs; FileStorage fs = _fs;
if(name) if(name)
fs << name << "{"; fs << name << "{";
emObj.write(fs); emObj.write(fs);
if(name) if(name)
fs << "}"; fs << "}";
fs.fs.obj = 0;
} }
double CvEM::calcLikelihood( const cv::Mat &input_sample ) const double CvEM::calcLikelihood( const Mat &input_sample ) const
{ {
double likelihood; double likelihood;
emObj.predict(input_sample, 0, &likelihood); emObj.predict(input_sample, noArray(), &likelihood);
return likelihood; return likelihood;
} }
float float
CvEM::predict( const CvMat* _sample, CvMat* _probs, bool isNormalize ) const CvEM::predict( const CvMat* _sample, CvMat* _probs, bool isNormalize ) const
{ {
cv::Mat prbs; Mat prbs0 = cvarrToMat(_probs), prbs = prbs0, sample = cvarrToMat(_sample);
int cls = emObj.predict(_sample, _probs ? &prbs : 0); int cls = emObj.predict(sample, _probs ? _OutputArray(prbs) : _OutputArray::_OutputArray());
if(_probs) if(_probs)
{ {
if(isNormalize) if(isNormalize)
cv::normalize(prbs, prbs, 1, 0, cv::NORM_L1); normalize(prbs, prbs, 1, 0, NORM_L1);
*_probs = prbs;
if( prbs.data != prbs0.data )
{
CV_Assert( prbs.size == prbs0.size );
prbs.convertTo(prbs0, prbs0.type());
}
} }
return (float)cls; return (float)cls;
} }
@ -116,73 +124,55 @@ void CvEM::set_mat_hdrs()
{ {
if(emObj.isTrained()) if(emObj.isTrained())
{ {
meansHdr = emObj.getMeans(); meansHdr = emObj.get<Mat>("means");
covsHdrs.resize(emObj.getNClusters()); int K = emObj.get<int>("nclusters");
covsPtrs.resize(emObj.getNClusters()); covsHdrs.resize(K);
const std::vector<cv::Mat>& covs = emObj.getCovs(); covsPtrs.resize(K);
const std::vector<Mat>& covs = emObj.get<vector<Mat> >("covs");
for(size_t i = 0; i < covsHdrs.size(); i++) for(size_t i = 0; i < covsHdrs.size(); i++)
{ {
covsHdrs[i] = covs[i]; covsHdrs[i] = covs[i];
covsPtrs[i] = &covsHdrs[i]; covsPtrs[i] = &covsHdrs[i];
} }
weightsHdr = emObj.getWeights(); weightsHdr = emObj.get<Mat>("weights");
probsHdr = probs; probsHdr = probs;
} }
} }
static static
void init_params(const CvEMParams& src, cv::EM::Params& dst, void init_params(const CvEMParams& src,
cv::Mat& prbs, cv::Mat& weights, Mat& prbs, Mat& weights,
cv::Mat& means, cv::vector<cv::Mat>& covsHdrs) Mat& means, vector<Mat>& covsHdrs)
{ {
dst.nclusters = src.nclusters;
dst.covMatType = src.cov_mat_type;
dst.startStep = src.start_step;
dst.termCrit = src.term_crit;
prbs = src.probs; prbs = src.probs;
dst.probs = &prbs;
weights = src.weights; weights = src.weights;
dst.weights = &weights;
means = src.means; means = src.means;
dst.means = &means;
if(src.covs) if(src.covs)
{ {
covsHdrs.resize(src.nclusters); covsHdrs.resize(src.nclusters);
for(size_t i = 0; i < covsHdrs.size(); i++) for(size_t i = 0; i < covsHdrs.size(); i++)
covsHdrs[i] = src.covs[i]; covsHdrs[i] = src.covs[i];
dst.covs = &covsHdrs;
} }
} }
bool CvEM::train( const CvMat* _samples, const CvMat* _sample_idx, bool CvEM::train( const CvMat* _samples, const CvMat* _sample_idx,
CvEMParams _params, CvMat* _labels ) CvEMParams _params, CvMat* _labels )
{ {
cv::EM::Params params; CV_Assert(_sample_idx == 0);
cv::Mat prbs, weights, means; Mat samples = cvarrToMat(_samples), labels0, labels;
std::vector<cv::Mat> covsHdrs; if( _labels )
init_params(_params, params, prbs, weights, means, covsHdrs); labels0 = labels = cvarrToMat(_labels);
cv::Mat lbls; bool isOk = train(samples, Mat(), _params, _labels ? &labels : 0);
cv::Mat likelihoods; CV_Assert( labels0.data == labels.data );
bool isOk = emObj.train(_samples, _sample_idx, params, _labels ? &lbls : 0, &probs, &likelihoods );
if(isOk)
{
if(_labels)
*_labels = lbls;
likelihood = cv::sum(likelihoods)[0];
set_mat_hdrs();
}
return isOk; return isOk;
} }
int CvEM::get_nclusters() const int CvEM::get_nclusters() const
{ {
return emObj.getNClusters(); return emObj.get<int>("nclusters");
} }
const CvMat* CvEM::get_means() const const CvMat* CvEM::get_means() const
@ -215,16 +205,29 @@ CvEM::CvEM( const Mat& samples, const Mat& sample_idx, CvEMParams params )
bool CvEM::train( const Mat& _samples, const Mat& _sample_idx, bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
CvEMParams _params, Mat* _labels ) CvEMParams _params, Mat* _labels )
{ {
cv::EM::Params params; Mat prbs, weights, means, likelihoods;
cv::Mat prbs, weights, means; std::vector<Mat> covsHdrs;
std::vector<cv::Mat> covsHdrs; init_params(_params, prbs, weights, means, covsHdrs);
init_params(_params, params, prbs, weights, means, covsHdrs);
emObj = EM(_params.nclusters, _params.cov_mat_type, _params.term_crit);
bool isOk = false;
if( _params.start_step == EM::START_AUTO_STEP )
isOk = emObj.train(_samples, _labels ? _OutputArray(*_labels) : _OutputArray::_OutputArray(),
probs, likelihoods);
else if( _params.start_step == EM::START_E_STEP )
isOk = emObj.trainE(_samples, means, covsHdrs, weights,
_labels ? _OutputArray(*_labels) : _OutputArray::_OutputArray(),
probs, likelihoods);
else if( _params.start_step == EM::START_M_STEP )
isOk = emObj.trainM(_samples, prbs,
_labels ? _OutputArray(*_labels) : _OutputArray::_OutputArray(),
probs, likelihoods);
else
CV_Error(CV_StsBadArg, "Bad start type of EM algorithm");
cv::Mat likelihoods;
bool isOk = emObj.train(_samples, _sample_idx, params, _labels, &probs, &likelihoods);
if(isOk) if(isOk)
{ {
likelihoods = cv::sum(likelihoods).val[0]; likelihoods = sum(likelihoods).val[0];
set_mat_hdrs(); set_mat_hdrs();
} }
@ -234,34 +237,34 @@ bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
float float
CvEM::predict( const Mat& _sample, Mat* _probs, bool isNormalize ) const CvEM::predict( const Mat& _sample, Mat* _probs, bool isNormalize ) const
{ {
int cls = emObj.predict(_sample, _probs); int cls = emObj.predict(_sample, _probs ? _OutputArray(*_probs) : _OutputArray::_OutputArray());
if(_probs && isNormalize) if(_probs && isNormalize)
cv::normalize(*_probs, *_probs, 1, 0, cv::NORM_L1); normalize(*_probs, *_probs, 1, 0, NORM_L1);
return (float)cls; return (float)cls;
} }
int CvEM::getNClusters() const int CvEM::getNClusters() const
{ {
return emObj.getNClusters(); return emObj.get<int>("nclusters");
} }
const Mat& CvEM::getMeans() const Mat CvEM::getMeans() const
{ {
return emObj.getMeans(); return emObj.get<Mat>("means");
} }
void CvEM::getCovs(vector<Mat>& _covs) const void CvEM::getCovs(vector<Mat>& _covs) const
{ {
_covs = emObj.getCovs(); _covs = emObj.get<vector<Mat> >("covs");
} }
const Mat& CvEM::getWeights() const Mat CvEM::getWeights() const
{ {
return emObj.getWeights(); return emObj.get<Mat>("weights");
} }
const Mat& CvEM::getProbs() const Mat CvEM::getProbs() const
{ {
return probs; return probs;
} }

@ -371,19 +371,20 @@ protected:
virtual void run( int /*start_from*/ ) virtual void run( int /*start_from*/ )
{ {
int code = cvtest::TS::OK; int code = cvtest::TS::OK;
cv::EM em;
Mat samples = Mat(3,1,CV_32F); Mat samples = Mat(3,1,CV_32F);
samples.at<float>(0,0) = 1; samples.at<float>(0,0) = 1;
samples.at<float>(1,0) = 2; samples.at<float>(1,0) = 2;
samples.at<float>(2,0) = 3; samples.at<float>(2,0) = 3;
cv::EM::Params params; Mat labels(samples.rows, 1, CV_32S);
CvEMParams params;
params.nclusters = 2; params.nclusters = 2;
Mat labels; CvMat samples_c = samples, labels_c = labels;
em.train(samples, Mat(), params, &labels); CvEM em(&samples_c, 0, params, &labels_c);
Mat firstResult(samples.rows, 1, CV_32FC1); Mat firstResult(samples.rows, 1, CV_32FC1);
for( int i = 0; i < samples.rows; i++) for( int i = 0; i < samples.rows; i++)
@ -396,9 +397,7 @@ protected:
FileStorage fs = FileStorage(filename, FileStorage::WRITE); FileStorage fs = FileStorage(filename, FileStorage::WRITE);
try try
{ {
fs << "em" << "{"; em.write(fs.fs, "em");
em.write(fs);
fs << "}";
} }
catch(...) catch(...)
{ {
@ -416,7 +415,7 @@ protected:
FileNode fn = fs["em"]; FileNode fn = fs["em"];
try try
{ {
em.read(fn); em.read(fs.fs, (CvFileNode*)fn.node);
} }
catch(...) catch(...)
{ {

@ -555,61 +555,66 @@ protected:
\****************************************************************************************/ \****************************************************************************************/
namespace cv namespace cv
{ {
class CV_EXPORTS EM : public Algorithm class CV_EXPORTS_W EM : public Algorithm
{ {
public: public:
// Type of covariation matrices // Type of covariation matrices
enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2}; enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2, COV_MAT_DEFAULT=COV_MAT_DIAGONAL};
// Default parameters
enum {DEFAULT_NCLUSTERS=10, DEFAULT_MAX_ITERS=100};
// The initial step // The initial step
enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0}; enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0};
class CV_EXPORTS Params CV_WRAP EM(int nclusters=EM::DEFAULT_NCLUSTERS, int covMatType=EM::COV_MAT_DIAGONAL,
{ const TermCriteria& termcrit=TermCriteria(TermCriteria::COUNT+
public: TermCriteria::EPS,
Params(int nclusters=10, int covMatType=EM::COV_MAT_DIAGONAL, int startStep=EM::START_AUTO_STEP, EM::DEFAULT_MAX_ITERS, FLT_EPSILON));
const cv::TermCriteria& termCrit=cv::TermCriteria(cv::TermCriteria::COUNT+cv::TermCriteria::EPS, 100, FLT_EPSILON),
const cv::Mat* probs=0, const cv::Mat* weights=0,
const cv::Mat* means=0, const std::vector<cv::Mat>* covs=0);
int nclusters;
int covMatType;
int startStep;
// all 4 following matrices should have type CV_32FC1
const cv::Mat* probs;
const cv::Mat* weights;
const cv::Mat* means;
const std::vector<cv::Mat>* covs;
cv::TermCriteria termCrit;
};
EM();
EM(const cv::Mat& samples, const cv::Mat samplesMask=cv::Mat(),
const EM::Params& params=EM::Params(), cv::Mat* labels=0, cv::Mat* probs=0, cv::Mat* likelihoods=0);
virtual ~EM(); virtual ~EM();
virtual void clear(); CV_WRAP virtual void clear();
CV_WRAP virtual bool train(InputArray samples,
OutputArray labels=noArray(),
OutputArray probs=noArray(),
OutputArray likelihoods=noArray());
virtual bool train(const cv::Mat& samples, const cv::Mat& samplesMask=cv::Mat(), CV_WRAP virtual bool trainE(InputArray samples,
const EM::Params& params=EM::Params(), cv::Mat* labels=0, cv::Mat* probs=0, cv::Mat* likelihoods=0); InputArray means0,
int predict(const cv::Mat& sample, cv::Mat* probs=0, double* likelihood=0) const; InputArray covs0=noArray(),
InputArray weights0=noArray(),
OutputArray labels=noArray(),
OutputArray probs=noArray(),
OutputArray likelihoods=noArray());
bool isTrained() const; CV_WRAP virtual bool trainM(InputArray samples,
int getNClusters() const; InputArray probs0,
int getCovMatType() const; OutputArray labels=noArray(),
OutputArray probs=noArray(),
OutputArray likelihoods=noArray());
const cv::Mat& getWeights() const; CV_WRAP int predict(InputArray sample,
const cv::Mat& getMeans() const; OutputArray probs=noArray(),
const std::vector<cv::Mat>& getCovs() const; CV_OUT double* likelihood=0) const;
CV_WRAP bool isTrained() const;
AlgorithmInfo* info() const; AlgorithmInfo* info() const;
virtual void read(const FileNode& fn); virtual void read(const FileNode& fn);
protected: protected:
virtual void setTrainData(const cv::Mat& samples, const cv::Mat& samplesMask, const EM::Params& params);
bool doTrain(const cv::TermCriteria& termCrit); virtual void setTrainData(int startStep, const Mat& samples,
const Mat* probs0,
const Mat* means0,
const vector<Mat>* covs0,
const Mat* weights0);
bool doTrain(int startStep,
OutputArray labels,
OutputArray probs,
OutputArray likelihoods);
virtual void eStep(); virtual void eStep();
virtual void mStep(); virtual void mStep();
@ -617,27 +622,28 @@ protected:
void decomposeCovs(); void decomposeCovs();
void computeLogWeightDivDet(); void computeLogWeightDivDet();
void computeProbabilities(const cv::Mat& sample, int& label, cv::Mat* probs, float* likelihood) const; void computeProbabilities(const Mat& sample, int& label, Mat* probs, float* likelihood) const;
// all inner matrices have type CV_32FC1 // all inner matrices have type CV_32FC1
int nclusters; CV_PROP_RW int nclusters;
int covMatType; CV_PROP_RW int covMatType;
int startStep; CV_PROP_RW int maxIters;
CV_PROP_RW double epsilon;
cv::Mat trainSamples;
cv::Mat trainProbs; Mat trainSamples;
cv::Mat trainLikelihoods; Mat trainProbs;
cv::Mat trainLabels; Mat trainLikelihoods;
cv::Mat trainCounts; Mat trainLabels;
Mat trainCounts;
cv::Mat weights;
cv::Mat means; CV_PROP Mat weights;
std::vector<cv::Mat> covs; CV_PROP Mat means;
CV_PROP vector<Mat> covs;
std::vector<cv::Mat> covsEigenValues;
std::vector<cv::Mat> covsRotateMats; vector<Mat> covsEigenValues;
std::vector<cv::Mat> invCovsEigenValues; vector<Mat> covsRotateMats;
cv::Mat logWeightDivDet; vector<Mat> invCovsEigenValues;
Mat logWeightDivDet;
}; };
} // namespace cv } // namespace cv

@ -46,22 +46,14 @@ namespace cv
const float minEigenValue = 1.e-3f; const float minEigenValue = 1.e-3f;
EM::Params::Params( int nclusters, int covMatType, int startStep, const cv::TermCriteria& termCrit,
const cv::Mat* probs, const cv::Mat* weights,
const cv::Mat* means, const std::vector<cv::Mat>* covs )
: nclusters(nclusters), covMatType(covMatType), startStep(startStep),
probs(probs), weights(weights), means(means), covs(covs), termCrit(termCrit)
{}
/////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////////
EM::EM() EM::EM(int _nclusters, int _covMatType, const TermCriteria& _criteria)
{}
EM::EM(const cv::Mat& samples, const cv::Mat samplesMask,
const EM::Params& params, cv::Mat* labels, cv::Mat* probs, cv::Mat* likelihoods)
{ {
train(samples, samplesMask, params, labels, probs, likelihoods); nclusters = _nclusters;
covMatType = _covMatType;
maxIters = (_criteria.type & TermCriteria::MAX_ITER) ? _criteria.maxCount : DEFAULT_MAX_ITERS;
epsilon = (_criteria.type & TermCriteria::EPS) ? _criteria.epsilon : 0;
} }
EM::~EM() EM::~EM()
@ -88,36 +80,50 @@ void EM::clear()
logWeightDivDet.release(); logWeightDivDet.release();
} }
bool EM::train(const cv::Mat& samples, const cv::Mat& samplesMask,
const EM::Params& params, cv::Mat* labels, cv::Mat* probs, cv::Mat* likelihoods) bool EM::train(InputArray samples,
OutputArray labels,
OutputArray probs,
OutputArray likelihoods)
{ {
setTrainData(samples, samplesMask, params); setTrainData(START_AUTO_STEP, samples.getMat(), 0, 0, 0, 0);
return doTrain(START_AUTO_STEP, labels, probs, likelihoods);
}
bool isOk = doTrain(params.termCrit); bool EM::trainE(InputArray samples,
InputArray _means0,
InputArray _covs0,
InputArray _weights0,
OutputArray labels,
OutputArray probs,
OutputArray likelihoods)
{
vector<Mat> covs0;
_covs0.getMatVector(covs0);
if(isOk) Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();
{
if(labels) setTrainData(START_E_STEP, samples.getMat(), 0, !_means0.empty() ? &means0 : 0,
cv::swap(*labels, trainLabels); !_covs0.empty() ? &covs0 : 0, _weights0.empty() ? &weights0 : 0);
if(probs) return doTrain(START_E_STEP, labels, probs, likelihoods);
cv::swap(*probs, trainProbs); }
if(likelihoods)
cv::swap(*likelihoods, trainLikelihoods); bool EM::trainM(InputArray samples,
InputArray _probs0,
trainSamples.release(); OutputArray labels,
trainProbs.release(); OutputArray probs,
trainLabels.release(); OutputArray likelihoods)
trainLikelihoods.release(); {
trainCounts.release(); Mat probs0 = _probs0.getMat();
}
else
clear();
return isOk; setTrainData(START_M_STEP, samples.getMat(), !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
return doTrain(START_M_STEP, labels, probs, likelihoods);
} }
int EM::predict(const cv::Mat& sample, cv::Mat* _probs, double* _likelihood) const
int EM::predict(InputArray _sample, OutputArray _probs, double* _likelihood) const
{ {
Mat sample = _sample.getMat();
CV_Assert(isTrained()); CV_Assert(isTrained());
CV_Assert(!sample.empty()); CV_Assert(!sample.empty());
@ -125,7 +131,13 @@ int EM::predict(const cv::Mat& sample, cv::Mat* _probs, double* _likelihood) con
int label; int label;
float likelihood = 0.f; float likelihood = 0.f;
computeProbabilities(sample, label, _probs, _likelihood ? &likelihood : 0); Mat probs;
if( _probs.needed() )
{
_probs.create(1, nclusters, CV_32FC1);
probs = _probs.getMat();
}
computeProbabilities(sample, label, !probs.empty() ? &probs : 0, _likelihood ? &likelihood : 0);
if(_likelihood) if(_likelihood)
*_likelihood = static_cast<double>(likelihood); *_likelihood = static_cast<double>(likelihood);
@ -137,36 +149,11 @@ bool EM::isTrained() const
return !means.empty(); return !means.empty();
} }
int EM::getNClusters() const
{
return isTrained() ? nclusters : -1;
}
int EM::getCovMatType() const
{
return isTrained() ? covMatType : -1;
}
const cv::Mat& EM::getWeights() const
{
CV_Assert((isTrained() && !weights.empty()) || (!isTrained() && weights.empty()));
return weights;
}
const cv::Mat& EM::getMeans() const
{
CV_Assert((isTrained() && !means.empty()) || (!isTrained() && means.empty()));
return means;
}
const std::vector<cv::Mat>& EM::getCovs() const
{
CV_Assert((isTrained() && !covs.empty()) || (!isTrained() && covs.empty()));
return covs;
}
static static
void checkTrainData(const cv::Mat& samples, const cv::Mat& samplesMask, const EM::Params& params) void checkTrainData(int startStep, const Mat& samples,
int nclusters, int covMatType, const Mat* probs, const Mat* means,
const vector<Mat>* covs, const Mat* weights)
{ {
// Check samples. // Check samples.
CV_Assert(!samples.empty()); CV_Assert(!samples.empty());
@ -175,138 +162,117 @@ void checkTrainData(const cv::Mat& samples, const cv::Mat& samplesMask, const EM
int nsamples = samples.rows; int nsamples = samples.rows;
int dim = samples.cols; int dim = samples.cols;
// Check samples indices.
CV_Assert(samplesMask.empty() ||
((samplesMask.rows == 1 || samplesMask.cols == 1) &&
static_cast<int>(samplesMask.total()) == nsamples && samplesMask.type() == CV_8UC1));
// Check training params. // Check training params.
CV_Assert(params.nclusters > 0); CV_Assert(nclusters > 0);
CV_Assert(params.nclusters <= nsamples); CV_Assert(nclusters <= nsamples);
CV_Assert(params.startStep == EM::START_AUTO_STEP || params.startStep == EM::START_E_STEP || params.startStep == EM::START_M_STEP); CV_Assert(startStep == EM::START_AUTO_STEP ||
startStep == EM::START_E_STEP ||
CV_Assert(!params.probs || startStep == EM::START_M_STEP);
(!params.probs->empty() &&
params.probs->rows == nsamples && params.probs->cols == params.nclusters && CV_Assert(!probs ||
params.probs->type() == CV_32FC1)); (!probs->empty() &&
probs->rows == nsamples && probs->cols == nclusters &&
CV_Assert(!params.weights || probs->type() == CV_32FC1));
(!params.weights->empty() &&
(params.weights->cols == 1 || params.weights->rows == 1) && static_cast<int>(params.weights->total()) == params.nclusters && CV_Assert(!weights ||
params.weights->type() == CV_32FC1)); (!weights->empty() &&
(weights->cols == 1 || weights->rows == 1) && static_cast<int>(weights->total()) == nclusters &&
CV_Assert(!params.means || weights->type() == CV_32FC1));
(!params.means->empty() &&
params.means->rows == params.nclusters && params.means->cols == dim && CV_Assert(!means ||
params.means->type() == CV_32FC1)); (!means->empty() &&
means->rows == nclusters && means->cols == dim &&
CV_Assert(!params.covs || means->type() == CV_32FC1));
(!params.covs->empty() &&
static_cast<int>(params.covs->size()) == params.nclusters)); CV_Assert(!covs ||
if(params.covs) (!covs->empty() &&
static_cast<int>(covs->size()) == nclusters));
if(covs)
{ {
const cv::Size covSize(dim, dim); const Size covSize(dim, dim);
for(size_t i = 0; i < params.covs->size(); i++) for(size_t i = 0; i < covs->size(); i++)
{ {
const cv::Mat& m = (*params.covs)[i]; const Mat& m = (*covs)[i];
CV_Assert(!m.empty() && m.size() == covSize && (m.type() == CV_32FC1)); CV_Assert(!m.empty() && m.size() == covSize && (m.type() == CV_32FC1));
} }
} }
if(params.startStep == EM::START_E_STEP) if(startStep == EM::START_E_STEP)
{ {
CV_Assert(params.means); CV_Assert(means);
} }
else if(params.startStep == EM::START_M_STEP) else if(startStep == EM::START_M_STEP)
{ {
CV_Assert(params.probs); CV_Assert(probs);
} }
} }
static static
void preprocessSampleData(const cv::Mat& src, cv::Mat& dst, int dstType, const cv::Mat& samplesMask, bool isAlwaysClone) void preprocessSampleData(const Mat& src, Mat& dst, int dstType, bool isAlwaysClone)
{ {
if(samplesMask.empty() || cv::countNonZero(samplesMask) == src.rows) if(src.type() == dstType && !isAlwaysClone)
{ dst = src;
if(src.type() == dstType && !isAlwaysClone)
dst = src;
else
src.convertTo(dst, dstType);
}
else else
{ src.convertTo(dst, dstType);
dst.release();
for(int sampleIndex = 0; sampleIndex < src.rows; sampleIndex++)
{
if(samplesMask.at<uchar>(sampleIndex))
{
cv::Mat sample = src.row(sampleIndex);
cv::Mat sample_dbl;
sample.convertTo(sample_dbl, dstType);
dst.push_back(sample_dbl);
}
}
}
} }
static static
void preprocessProbability(cv::Mat& probs) void preprocessProbability(Mat& probs)
{ {
cv::max(probs, 0., probs); max(probs, 0., probs);
const float uniformProbability = (float)(1./probs.cols); const float uniformProbability = (float)(1./probs.cols);
for(int y = 0; y < probs.rows; y++) for(int y = 0; y < probs.rows; y++)
{ {
cv::Mat sampleProbs = probs.row(y); Mat sampleProbs = probs.row(y);
double maxVal = 0; double maxVal = 0;
cv::minMaxLoc(sampleProbs, 0, &maxVal); minMaxLoc(sampleProbs, 0, &maxVal);
if(maxVal < FLT_EPSILON) if(maxVal < FLT_EPSILON)
sampleProbs.setTo(uniformProbability); sampleProbs.setTo(uniformProbability);
else else
cv::normalize(sampleProbs, sampleProbs, 1, 0, cv::NORM_L1); normalize(sampleProbs, sampleProbs, 1, 0, NORM_L1);
} }
} }
void EM::setTrainData(const cv::Mat& samples, const cv::Mat& samplesMask, const EM::Params& params) void EM::setTrainData(int startStep, const Mat& samples,
const Mat* probs0,
const Mat* means0,
const vector<Mat>* covs0,
const Mat* weights0)
{ {
clear(); clear();
checkTrainData(samples, samplesMask, params); checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);
// Set checked data // Set checked data
preprocessSampleData(samples, trainSamples, CV_32FC1, false);
nclusters = params.nclusters;
covMatType = params.covMatType;
startStep = params.startStep;
preprocessSampleData(samples, trainSamples, CV_32FC1, samplesMask, false);
// set probs // set probs
if(params.probs && startStep == EM::START_M_STEP) if(probs0 && startStep == EM::START_M_STEP)
{ {
preprocessSampleData(*params.probs, trainProbs, CV_32FC1, samplesMask, true); preprocessSampleData(*probs0, trainProbs, CV_32FC1, true);
preprocessProbability(trainProbs); preprocessProbability(trainProbs);
} }
// set weights // set weights
if(params.weights && (startStep == EM::START_E_STEP && params.covs)) if(weights0 && (startStep == EM::START_E_STEP && covs0))
{ {
params.weights->convertTo(weights, CV_32FC1); weights0->convertTo(weights, CV_32FC1);
weights.reshape(1,1); weights.reshape(1,1);
preprocessProbability(weights); preprocessProbability(weights);
} }
// set means // set means
if(params.means && (startStep == EM::START_E_STEP || startStep == EM::START_AUTO_STEP)) if(means0 && (startStep == EM::START_E_STEP || startStep == EM::START_AUTO_STEP))
params.means->convertTo(means, CV_32FC1); means0->convertTo(means, CV_32FC1);
// set covs // set covs
if(params.covs && (startStep == EM::START_E_STEP && params.weights)) if(covs0 && (startStep == EM::START_E_STEP && weights0))
{ {
covs.resize(nclusters); covs.resize(nclusters);
for(size_t i = 0; i < params.covs->size(); i++) for(size_t i = 0; i < covs0->size(); i++)
(*params.covs)[i].convertTo(covs[i], CV_32FC1); (*covs0)[i].convertTo(covs[i], CV_32FC1);
} }
} }
@ -321,14 +287,14 @@ void EM::decomposeCovs()
{ {
CV_Assert(!covs[clusterIndex].empty()); CV_Assert(!covs[clusterIndex].empty());
cv::SVD svd(covs[clusterIndex], cv::SVD::MODIFY_A + cv::SVD::FULL_UV); SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
CV_DbgAssert(svd.w.rows == 1 || svd.w.cols == 1); CV_DbgAssert(svd.w.rows == 1 || svd.w.cols == 1);
CV_DbgAssert(svd.w.type() == CV_32FC1 && svd.u.type() == CV_32FC1); CV_DbgAssert(svd.w.type() == CV_32FC1 && svd.u.type() == CV_32FC1);
if(covMatType == EM::COV_MAT_SPHERICAL) if(covMatType == EM::COV_MAT_SPHERICAL)
{ {
float maxSingularVal = svd.w.at<float>(0); float maxSingularVal = svd.w.at<float>(0);
covsEigenValues[clusterIndex] = cv::Mat(1, 1, CV_32FC1, cv::Scalar(maxSingularVal)); covsEigenValues[clusterIndex] = Mat(1, 1, CV_32FC1, Scalar(maxSingularVal));
} }
else if(covMatType == EM::COV_MAT_DIAGONAL) else if(covMatType == EM::COV_MAT_DIAGONAL)
{ {
@ -339,7 +305,7 @@ void EM::decomposeCovs()
covsEigenValues[clusterIndex] = svd.w; covsEigenValues[clusterIndex] = svd.w;
covsRotateMats[clusterIndex] = svd.u; covsRotateMats[clusterIndex] = svd.u;
} }
cv::max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]); max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex]; invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
} }
} }
@ -349,29 +315,29 @@ void EM::clusterTrainSamples()
int nsamples = trainSamples.rows; int nsamples = trainSamples.rows;
// Cluster samples, compute/update means // Cluster samples, compute/update means
cv::Mat labels; Mat labels;
cv::kmeans(trainSamples, nclusters, labels, kmeans(trainSamples, nclusters, labels,
cv::TermCriteria(cv::TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5), TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5),
10, cv::KMEANS_PP_CENTERS, means); 10, KMEANS_PP_CENTERS, means);
CV_Assert(means.type() == CV_32FC1); CV_Assert(means.type() == CV_32FC1);
// Compute weights and covs // Compute weights and covs
weights = cv::Mat(1, nclusters, CV_32FC1, cv::Scalar(0)); weights = Mat(1, nclusters, CV_32FC1, Scalar(0));
covs.resize(nclusters); covs.resize(nclusters);
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++) for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
{ {
cv::Mat clusterSamples; Mat clusterSamples;
for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++) for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
{ {
if(labels.at<int>(sampleIndex) == clusterIndex) if(labels.at<int>(sampleIndex) == clusterIndex)
{ {
const cv::Mat sample = trainSamples.row(sampleIndex); const Mat sample = trainSamples.row(sampleIndex);
clusterSamples.push_back(sample); clusterSamples.push_back(sample);
} }
} }
CV_Assert(!clusterSamples.empty()); CV_Assert(!clusterSamples.empty());
cv::calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex), calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_32FC1); CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_32FC1);
weights.at<float>(clusterIndex) = static_cast<float>(clusterSamples.rows)/static_cast<float>(nsamples); weights.at<float>(clusterIndex) = static_cast<float>(clusterSamples.rows)/static_cast<float>(nsamples);
} }
@ -383,8 +349,8 @@ void EM::computeLogWeightDivDet()
{ {
CV_Assert(!covsEigenValues.empty()); CV_Assert(!covsEigenValues.empty());
cv::Mat logWeights; Mat logWeights;
cv::log(weights, logWeights); log(weights, logWeights);
logWeightDivDet.create(1, nclusters, CV_32FC1); logWeightDivDet.create(1, nclusters, CV_32FC1);
// note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|) // note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)
@ -399,7 +365,7 @@ void EM::computeLogWeightDivDet()
} }
} }
bool EM::doTrain(const cv::TermCriteria& termCrit) bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArray likelihoods)
{ {
int dim = trainSamples.cols; int dim = trainSamples.cols;
// Precompute the empty initial train data in the cases of EM::START_E_STEP and START_AUTO_STEP // Precompute the empty initial train data in the cases of EM::START_E_STEP and START_AUTO_STEP
@ -425,15 +391,15 @@ bool EM::doTrain(const cv::TermCriteria& termCrit)
for(int iter = 0; ; iter++) for(int iter = 0; ; iter++)
{ {
eStep(); eStep();
trainLikelihood = cv::sum(trainLikelihoods)[0]; trainLikelihood = sum(trainLikelihoods)[0];
if(iter >= termCrit.maxCount - 1) if(iter >= maxIters - 1)
break; break;
double trainLikelihoodDelta = trainLikelihood - (iter > 0 ? prevTrainLikelihood : 0); double trainLikelihoodDelta = trainLikelihood - (iter > 0 ? prevTrainLikelihood : 0);
if( iter != 0 && if( iter != 0 &&
(trainLikelihoodDelta < -DBL_EPSILON || (trainLikelihoodDelta < -DBL_EPSILON ||
trainLikelihoodDelta < termCrit.epsilon * std::fabs(trainLikelihood))) trainLikelihoodDelta < epsilon * std::fabs(trainLikelihood)))
break; break;
mStep(); mStep();
@ -442,7 +408,10 @@ bool EM::doTrain(const cv::TermCriteria& termCrit)
} }
if( trainLikelihood <= -DBL_MAX/10000. ) if( trainLikelihood <= -DBL_MAX/10000. )
{
clear();
return false; return false;
}
// postprocess covs // postprocess covs
covs.resize(nclusters); covs.resize(nclusters);
@ -451,16 +420,29 @@ bool EM::doTrain(const cv::TermCriteria& termCrit)
if(covMatType == EM::COV_MAT_SPHERICAL) if(covMatType == EM::COV_MAT_SPHERICAL)
{ {
covs[clusterIndex].create(dim, dim, CV_32FC1); covs[clusterIndex].create(dim, dim, CV_32FC1);
cv::setIdentity(covs[clusterIndex], cv::Scalar(covsEigenValues[clusterIndex].at<float>(0))); setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<float>(0)));
} }
else if(covMatType == EM::COV_MAT_DIAGONAL) else if(covMatType == EM::COV_MAT_DIAGONAL)
covs[clusterIndex] = cv::Mat::diag(covsEigenValues[clusterIndex].t()); covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex].t());
} }
if(labels.needed())
trainLabels.copyTo(labels);
if(probs.needed())
trainProbs.copyTo(probs);
if(likelihoods.needed())
trainLikelihoods.copyTo(likelihoods);
trainSamples.release();
trainProbs.release();
trainLabels.release();
trainLikelihoods.release();
trainCounts.release();
return true; return true;
} }
void EM::computeProbabilities(const cv::Mat& sample, int& label, cv::Mat* probs, float* likelihood) const void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, float* likelihood) const
{ {
// L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)] // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
// q = arg(max_k(L_ik)) // q = arg(max_k(L_ik))
@ -470,15 +452,15 @@ void EM::computeProbabilities(const cv::Mat& sample, int& label, cv::Mat* probs,
int dim = sample.cols; int dim = sample.cols;
cv::Mat L(1, nclusters, CV_32FC1); Mat L(1, nclusters, CV_32FC1);
cv::Mat expL(1, nclusters, CV_32FC1); Mat expL(1, nclusters, CV_32FC1);
label = 0; label = 0;
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++) for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
{ {
const cv::Mat centeredSample = sample - means.row(clusterIndex); const Mat centeredSample = sample - means.row(clusterIndex);
cv::Mat rotatedCenteredSample = covMatType != EM::COV_MAT_GENERIC ? Mat rotatedCenteredSample = covMatType != EM::COV_MAT_GENERIC ?
centeredSample : centeredSample * covsRotateMats[clusterIndex]; centeredSample : centeredSample * covsRotateMats[clusterIndex];
float Lval = 0; float Lval = 0;
@ -500,7 +482,7 @@ void EM::computeProbabilities(const cv::Mat& sample, int& label, cv::Mat* probs,
return; return;
// TODO maybe without finding max L value // TODO maybe without finding max L value
cv::exp(L, expL); exp(L, expL);
float partExpSum = 0, // sum_j!=q (exp(L_jk) float partExpSum = 0, // sum_j!=q (exp(L_jk)
factor; // 1/(1 + sum_j!=q (exp(L_jk)) factor; // 1/(1 + sum_j!=q (exp(L_jk))
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++) for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
@ -510,7 +492,7 @@ void EM::computeProbabilities(const cv::Mat& sample, int& label, cv::Mat* probs,
} }
factor = 1.f/(1 + partExpSum); factor = 1.f/(1 + partExpSum);
cv::exp(L - L.at<float>(label), expL); exp(L - L.at<float>(label), expL);
if(probs) if(probs)
{ {
@ -537,7 +519,7 @@ void EM::eStep()
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++) for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
{ {
cv::Mat sampleProbs = trainProbs.row(sampleIndex); Mat sampleProbs = trainProbs.row(sampleIndex);
computeProbabilities(trainSamples.row(sampleIndex), trainLabels.at<int>(sampleIndex), computeProbabilities(trainSamples.row(sampleIndex), trainLabels.at<int>(sampleIndex),
&sampleProbs, &trainLikelihoods.at<float>(sampleIndex)); &sampleProbs, &trainLikelihoods.at<float>(sampleIndex));
} }
@ -546,12 +528,12 @@ void EM::eStep()
void EM::mStep() void EM::mStep()
{ {
trainCounts.create(1, nclusters, CV_32SC1); trainCounts.create(1, nclusters, CV_32SC1);
trainCounts = cv::Scalar(0); trainCounts = Scalar(0);
for(int sampleIndex = 0; sampleIndex < trainLabels.rows; sampleIndex++) for(int sampleIndex = 0; sampleIndex < trainLabels.rows; sampleIndex++)
trainCounts.at<int>(trainLabels.at<int>(sampleIndex))++; trainCounts.at<int>(trainLabels.at<int>(sampleIndex))++;
if(cv::countNonZero(trainCounts) != (int)trainCounts.total()) if(countNonZero(trainCounts) != (int)trainCounts.total())
{ {
clusterTrainSamples(); clusterTrainSamples();
} }
@ -562,14 +544,14 @@ void EM::mStep()
// Update weights // Update weights
// not normalized first // not normalized first
cv::reduce(trainProbs, weights, 0, CV_REDUCE_SUM); reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
// Update means // Update means
means.create(nclusters, dim, CV_32FC1); means.create(nclusters, dim, CV_32FC1);
means = cv::Scalar(0); means = Scalar(0);
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++) for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
{ {
cv::Mat clusterMean = means.row(clusterIndex); Mat clusterMean = means.row(clusterIndex);
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++) for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
clusterMean += trainProbs.at<float>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex); clusterMean += trainProbs.at<float>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
clusterMean /= weights.at<float>(clusterIndex); clusterMean /= weights.at<float>(clusterIndex);
@ -591,12 +573,12 @@ void EM::mStep()
if(covMatType == EM::COV_MAT_GENERIC) if(covMatType == EM::COV_MAT_GENERIC)
covs[clusterIndex].create(dim, dim, CV_32FC1); covs[clusterIndex].create(dim, dim, CV_32FC1);
cv::Mat clusterCov = covMatType != EM::COV_MAT_GENERIC ? Mat clusterCov = covMatType != EM::COV_MAT_GENERIC ?
covsEigenValues[clusterIndex] : covs[clusterIndex]; covsEigenValues[clusterIndex] : covs[clusterIndex];
clusterCov = cv::Scalar(0); clusterCov = Scalar(0);
cv::Mat centeredSample; Mat centeredSample;
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++) for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
{ {
centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex); centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
@ -622,12 +604,12 @@ void EM::mStep()
// Update covsRotateMats for EM::COV_MAT_GENERIC only // Update covsRotateMats for EM::COV_MAT_GENERIC only
if(covMatType == EM::COV_MAT_GENERIC) if(covMatType == EM::COV_MAT_GENERIC)
{ {
cv::SVD svd(covs[clusterIndex], cv::SVD::MODIFY_A + cv::SVD::FULL_UV); SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
covsEigenValues[clusterIndex] = svd.w; covsEigenValues[clusterIndex] = svd.w;
covsRotateMats[clusterIndex] = svd.u; covsRotateMats[clusterIndex] = svd.u;
} }
cv::max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]); max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
// update invCovsEigenValues // update invCovsEigenValues
invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex]; invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];

@ -320,6 +320,30 @@ void CV_KNearestTest::run( int /*start_from*/ )
ts->set_failed_test_info( code ); ts->set_failed_test_info( code );
} }
class EM_Params
{
public:
EM_Params(int nclusters=10, int covMatType=EM::COV_MAT_DIAGONAL, int startStep=EM::START_AUTO_STEP,
const cv::TermCriteria& termCrit=cv::TermCriteria(cv::TermCriteria::COUNT+cv::TermCriteria::EPS, 100, FLT_EPSILON),
const cv::Mat* probs=0, const cv::Mat* weights=0,
const cv::Mat* means=0, const std::vector<cv::Mat>* covs=0)
: nclusters(nclusters), covMatType(covMatType), startStep(startStep),
probs(probs), weights(weights), means(means), covs(covs), termCrit(termCrit)
{}
int nclusters;
int covMatType;
int startStep;
// all 4 following matrices should have type CV_32FC1
const cv::Mat* probs;
const cv::Mat* weights;
const cv::Mat* means;
const std::vector<cv::Mat>* covs;
cv::TermCriteria termCrit;
};
//-------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------
class CV_EMTest : public cvtest::BaseTest class CV_EMTest : public cvtest::BaseTest
{ {
@ -327,13 +351,13 @@ public:
CV_EMTest() {} CV_EMTest() {}
protected: protected:
virtual void run( int start_from ); virtual void run( int start_from );
int runCase( int caseIndex, const cv::EM::Params& params, int runCase( int caseIndex, const EM_Params& params,
const cv::Mat& trainData, const cv::Mat& trainLabels, const cv::Mat& trainData, const cv::Mat& trainLabels,
const cv::Mat& testData, const cv::Mat& testLabels, const cv::Mat& testData, const cv::Mat& testLabels,
const vector<int>& sizes); const vector<int>& sizes);
}; };
int CV_EMTest::runCase( int caseIndex, const cv::EM::Params& params, int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
const cv::Mat& trainData, const cv::Mat& trainLabels, const cv::Mat& trainData, const cv::Mat& trainLabels,
const cv::Mat& testData, const cv::Mat& testLabels, const cv::Mat& testData, const cv::Mat& testLabels,
const vector<int>& sizes ) const vector<int>& sizes )
@ -343,8 +367,13 @@ int CV_EMTest::runCase( int caseIndex, const cv::EM::Params& params,
cv::Mat labels; cv::Mat labels;
float err; float err;
cv::EM em; cv::EM em(params.nclusters, params.covMatType, params.termCrit);
em.train( trainData, Mat(), params, &labels ); if( params.startStep == EM::START_AUTO_STEP )
em.train( trainData, labels );
else if( params.startStep == EM::START_E_STEP )
em.trainE( trainData, *params.means, *params.covs, *params.weights, labels );
else if( params.startStep == EM::START_M_STEP )
em.trainM( trainData, *params.probs, labels );
// check train error // check train error
if( !calcErr( labels, trainLabels, sizes, err , false ) ) if( !calcErr( labels, trainLabels, sizes, err , false ) )
@ -363,7 +392,7 @@ int CV_EMTest::runCase( int caseIndex, const cv::EM::Params& params,
for( int i = 0; i < testData.rows; i++ ) for( int i = 0; i < testData.rows; i++ )
{ {
Mat sample = testData.row(i); Mat sample = testData.row(i);
labels.at<int>(i,0) = (int)em.predict( sample, 0 ); labels.at<int>(i,0) = (int)em.predict( sample, noArray() );
} }
if( !calcErr( labels, testLabels, sizes, err, false ) ) if( !calcErr( labels, testLabels, sizes, err, false ) )
{ {
@ -398,7 +427,7 @@ void CV_EMTest::run( int /*start_from*/ )
Mat testData( pointsCount, 2, CV_32FC1 ), testLabels; Mat testData( pointsCount, 2, CV_32FC1 ), testLabels;
generateData( testData, testLabels, sizes, means, covs, CV_32SC1 ); generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
cv::EM::Params params; EM_Params params;
params.nclusters = 3; params.nclusters = 3;
Mat probs(trainData.rows, params.nclusters, CV_32FC1, cv::Scalar(1)); Mat probs(trainData.rows, params.nclusters, CV_32FC1, cv::Scalar(1));
params.probs = &probs; params.probs = &probs;
@ -474,19 +503,16 @@ protected:
virtual void run( int /*start_from*/ ) virtual void run( int /*start_from*/ )
{ {
int code = cvtest::TS::OK; int code = cvtest::TS::OK;
cv::EM em; cv::EM em(2);
Mat samples = Mat(3,1,CV_32F); Mat samples = Mat(3,1,CV_32F);
samples.at<float>(0,0) = 1; samples.at<float>(0,0) = 1;
samples.at<float>(1,0) = 2; samples.at<float>(1,0) = 2;
samples.at<float>(2,0) = 3; samples.at<float>(2,0) = 3;
cv::EM::Params params;
params.nclusters = 2;
Mat labels; Mat labels;
em.train(samples, Mat(), params, &labels); em.train(samples, labels);
Mat firstResult(samples.rows, 1, CV_32FC1); Mat firstResult(samples.rows, 1, CV_32FC1);
for( int i = 0; i < samples.rows; i++) for( int i = 0; i < samples.rows; i++)

Loading…
Cancel
Save