|
|
|
@ -86,7 +86,8 @@ bool EM::train(InputArray samples, |
|
|
|
|
OutputArray probs, |
|
|
|
|
OutputArray logLikelihoods) |
|
|
|
|
{ |
|
|
|
|
setTrainData(START_AUTO_STEP, samples.getMat(), 0, 0, 0, 0); |
|
|
|
|
Mat samplesMat = samples.getMat(); |
|
|
|
|
setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0); |
|
|
|
|
return doTrain(START_AUTO_STEP, labels, probs, logLikelihoods); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -98,12 +99,13 @@ bool EM::trainE(InputArray samples, |
|
|
|
|
OutputArray probs, |
|
|
|
|
OutputArray logLikelihoods) |
|
|
|
|
{ |
|
|
|
|
Mat samplesMat = samples.getMat(); |
|
|
|
|
vector<Mat> covs0; |
|
|
|
|
_covs0.getMatVector(covs0); |
|
|
|
|
|
|
|
|
|
Mat means0 = _means0.getMat(), weights0 = _weights0.getMat(); |
|
|
|
|
|
|
|
|
|
setTrainData(START_E_STEP, samples.getMat(), 0, !_means0.empty() ? &means0 : 0, |
|
|
|
|
setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0, |
|
|
|
|
!_covs0.empty() ? &covs0 : 0, _weights0.empty() ? &weights0 : 0); |
|
|
|
|
return doTrain(START_E_STEP, labels, probs, logLikelihoods); |
|
|
|
|
} |
|
|
|
@ -114,9 +116,10 @@ bool EM::trainM(InputArray samples, |
|
|
|
|
OutputArray probs, |
|
|
|
|
OutputArray logLikelihoods) |
|
|
|
|
{ |
|
|
|
|
Mat samplesMat = samples.getMat(); |
|
|
|
|
Mat probs0 = _probs0.getMat(); |
|
|
|
|
|
|
|
|
|
setTrainData(START_M_STEP, samples.getMat(), !_probs0.empty() ? &probs0 : 0, 0, 0, 0); |
|
|
|
|
setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0); |
|
|
|
|
return doTrain(START_M_STEP, labels, probs, logLikelihoods); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -337,7 +340,11 @@ void EM::clusterTrainSamples() |
|
|
|
|
|
|
|
|
|
CV_Assert(meansFlt.type() == CV_32FC1); |
|
|
|
|
if(trainSamples.type() != CV_64FC1) |
|
|
|
|
trainSamplesFlt.convertTo(trainSamples, CV_64FC1); |
|
|
|
|
{ |
|
|
|
|
Mat trainSamplesBuffer; |
|
|
|
|
trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1); |
|
|
|
|
trainSamples = trainSamplesBuffer; |
|
|
|
|
} |
|
|
|
|
meansFlt.convertTo(means, CV_64FC1); |
|
|
|
|
|
|
|
|
|
// Compute weights and covs
|
|
|
|
|