mirror of https://github.com/opencv/opencv.git
Open Source Computer Vision Library
https://opencv.org/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
186 lines
6.5 KiB
186 lines
6.5 KiB
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
|
|
|
#include "test_precomp.hpp" |
|
|
|
namespace opencv_test { namespace { |
|
|
|
// Parameter types for the EM accuracy test: one lists the EM start-step constants,
// the other the EM covariance-matrix-type constants.
CV_ENUM(EM_START_STEP,
        EM::START_AUTO_STEP,
        EM::START_M_STEP,
        EM::START_E_STEP)
CV_ENUM(EM_COV_MAT,
        EM::COV_MAT_GENERIC,
        EM::COV_MAT_DIAGONAL,
        EM::COV_MAT_SPHERICAL)

// Test fixture parameterized by a (start step, covariance matrix type) pair.
typedef testing::TestWithParam< tuple<EM_START_STEP, EM_COV_MAT> > ML_EM_Params;
|
|
|
// Trains an EM model on generated 2-column data using the start step and covariance
// matrix type supplied by the test parameter, then checks the labelling error on the
// training set and on a second, separately generated set.
TEST_P(ML_EM_Params, accuracy)
{
    const int nclusters = 3;
    const int sizesArr[] = { 500, 700, 800 };
    const size_t sizesCount = sizeof(sizesArr) / sizeof(sizesArr[0]);
    const vector<int> sizes(sizesArr, sizesArr + sizesCount);
    const int pointsCount = sizesArr[0] + sizesArr[1] + sizesArr[2];

    // Distribution parameters shared by both generated data sets; they are also the
    // initial means/covariances handed to trainE() below.
    Mat means;
    vector<Mat> covs;
    defaultDistribs(means, covs, CV_64FC1);

    // Keep the generation order: training set first, evaluation set second.
    Mat trainData(pointsCount, 2, CV_64FC1);
    Mat trainLabels;
    generateData(trainData, trainLabels, sizes, means, covs, CV_64FC1, CV_32SC1);

    Mat testData(pointsCount, 2, CV_64FC1);
    Mat testLabels;
    generateData(testData, testLabels, sizes, means, covs, CV_64FC1, CV_32SC1);

    const int startStep = get<0>(GetParam());
    const int covMatType = get<1>(GetParam());
    const TermCriteria criteria(cv::TermCriteria::COUNT + cv::TermCriteria::EPS, 100, FLT_EPSILON);

    Ptr<EM> model = EM::create();
    model->setClustersNumber(nclusters);
    model->setCovarianceMatrixType(covMatType);
    model->setTermCriteria(criteria);

    // Labels assigned by the model: filled by training, then reused for the evaluation set.
    cv::Mat labels;
    switch (startStep)
    {
    case EM::START_AUTO_STEP:
        model->trainEM(trainData, noArray(), labels, noArray());
        break;
    case EM::START_E_STEP:
    {
        // Uniform (all ones) initial weights together with the reference means/covs.
        Mat initialWeights(1, nclusters, CV_64FC1, cv::Scalar(1));
        model->trainE(trainData, means, covs, initialWeights, noArray(), labels, noArray());
        break;
    }
    case EM::START_M_STEP:
    {
        // Uniform (all ones) initial per-sample probabilities.
        Mat initialProbs(trainData.rows, nclusters, CV_64FC1, cv::Scalar(1));
        model->trainM(trainData, initialProbs, noArray(), labels, noArray());
        break;
    }
    }

    {
        SCOPED_TRACE("Train");
        float err = 1000;
        EXPECT_TRUE(calcErr( labels, trainLabels, sizes, err , false, false ));
        EXPECT_LE(err, 0.008f);
    }

    {
        SCOPED_TRACE("Test");
        float err = 1000;
        labels.create(testData.rows, 1, CV_32SC1);
        for (int row = 0; row < testData.rows; ++row)
        {
            Mat posteriors;
            const Vec2d prediction = model->predict2(testData.row(row), posteriors);
            // Element 1 of the prediction is stored as the sample's label.
            labels.at<int>(row) = static_cast<int>(prediction[1]);
        }
        EXPECT_TRUE(calcErr( labels, testLabels, sizes, err, false, false ));
        EXPECT_LE(err, 0.008f);
    }
}
|
|
|
// Instantiate ML_EM_Params for every combination of the three start steps
// with the three covariance matrix types.
INSTANTIATE_TEST_CASE_P(/**/, ML_EM_Params,
    testing::Combine(
        testing::Values(EM::START_AUTO_STEP, EM::START_M_STEP, EM::START_E_STEP),
        testing::Values(EM::COV_MAT_GENERIC, EM::COV_MAT_DIAGONAL, EM::COV_MAT_SPHERICAL)
    ));
|
|
|
//================================================================================================== |
|
|
|
// Round-trip test: trains a small EM model, records element 1 of predict2() for every
// sample, writes the model to a temporary XML file, loads it back and expects the
// reloaded model to reproduce the recorded values.
TEST(ML_EM, save_load)
{
    // Removes the temporary model file when the test body is left. A failed ASSERT_*
    // returns from the test early, so a remove() call at the end of the body would be
    // skipped and the file left behind.
    struct TempFileGuard
    {
        explicit TempFileGuard(const string& path_) : path(path_) {}
        ~TempFileGuard() { remove(path.c_str()); }
        const string path;
    };

    const int nclusters = 2;
    Mat_<double> samples(3, 1);
    samples << 1., 2., 3.;

    std::vector<double> firstResult;
    const string filename = cv::tempfile(".xml");
    // Declared before any FileStorage so the file is deleted only after all of them are closed.
    const TempFileGuard cleanup(filename);

    {
        Mat labels;
        Ptr<EM> em = EM::create();
        em->setClustersNumber(nclusters);
        em->trainEM(samples, noArray(), labels, noArray());
        for( int i = 0; i < samples.rows; i++)
        {
            Vec2d res = em->predict2(samples.row(i), noArray());
            firstResult.push_back(res[1]);
        }
        {
            FileStorage fs(filename, FileStorage::WRITE);
            // Fail with a clear message instead of writing into a storage that never opened.
            ASSERT_TRUE(fs.isOpened()) << "Can't open " << filename << " for writing";
            ASSERT_NO_THROW(fs << "em" << "{");
            ASSERT_NO_THROW(em->write(fs));
            ASSERT_NO_THROW(fs << "}");
        }
    }
    {
        Ptr<EM> em;
        ASSERT_NO_THROW(em = Algorithm::load<EM>(filename));
        // Stop here rather than dereference a null model in the loop below.
        ASSERT_FALSE(em.empty());
        for( int i = 0; i < samples.rows; i++)
        {
            SCOPED_TRACE(i);
            Vec2d res = em->predict2(samples.row(i), noArray());
            EXPECT_DOUBLE_EQ(firstResult[i], res[1]);
        }
    }
}
|
|
|
//================================================================================================== |
|
|
|
// This test classifies spam in the following way:
// 1. estimates distributions of "spam" / "not spam": samples with response 0 train
//    model0, all other samples train model1 (only a random half of the data is used);
// 2. predicts classID with a Bayes classifier over the estimated distributions, i.e.
//    picks the model with the larger log-likelihood (element 0 of predict2()'s result).
// The misclassification rate is then bounded separately for the training half and
// for the held-out half.
TEST(ML_EM, classification)
{
    string dataFilename = findDataFile("spambase.data");
    Ptr<TrainData> data = TrainData::loadFromCSV(dataFilename, 0);
    ASSERT_FALSE(data.empty());

    Mat samples = data->getSamples();
    ASSERT_EQ(samples.cols, 57);
    Mat responses = data->getResponses();

    // Mark the first half of the samples as training samples, then scatter the marks
    // with random pairwise swaps.
    vector<int> trainSamplesMask(samples.rows, 0);
    const int trainSamplesCount = static_cast<int>(0.5f * samples.rows);
    const int testSamplesCount = samples.rows - trainSamplesCount;
    for(int i = 0; i < trainSamplesCount; i++)
        trainSamplesMask[i] = 1;
    RNG &rng = cv::theRNG();
    for(size_t i = 0; i < trainSamplesMask.size(); i++)
    {
        int i1 = rng(static_cast<unsigned>(trainSamplesMask.size()));
        int i2 = rng(static_cast<unsigned>(trainSamplesMask.size()));
        std::swap(trainSamplesMask[i1], trainSamplesMask[i2]);
    }

    // Split the training samples by class.
    Mat samples0, samples1;
    for(int i = 0; i < samples.rows; i++)
    {
        if(trainSamplesMask[i])
        {
            Mat sample = samples.row(i);
            int resp = static_cast<int>(responses.at<float>(i));
            if(resp == 0)
                samples0.push_back(sample);
            else
                samples1.push_back(sample);
        }
    }

    Ptr<EM> model0 = EM::create();
    model0->setClustersNumber(3);
    model0->trainEM(samples0, noArray(), noArray(), noArray());

    Ptr<EM> model1 = EM::create();
    model1->setClustersNumber(3);
    model1->trainEM(samples1, noArray(), noArray(), noArray());

    // confusion matrices, indexed as (actual response, predicted classID)
    Mat_<int> trainCM(2, 2, 0);
    Mat_<int> testCM(2, 2, 0);
    const double lambda = 1.;
    for(int i = 0; i < samples.rows; i++)
    {
        // The response indexes the 2x2 confusion matrices below, so an unexpected
        // label must abort the test (ASSERT) instead of merely being reported (EXPECT).
        int resp = static_cast<int>(responses.at<float>(i));
        ASSERT_TRUE(resp == 0 || resp == 1);

        Mat sample = samples.row(i);
        double sampleLogLikelihoods0 = model0->predict2(sample, noArray())[0];
        double sampleLogLikelihoods1 = model1->predict2(sample, noArray())[0];
        int classID = (sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1) ? 0 : 1;

        if(trainSamplesMask[i])
            trainCM(resp, classID)++;
        else
            testCM(resp, classID)++;
    }

    // Off-diagonal entries are the misclassified samples.
    EXPECT_LE(static_cast<double>(trainCM(1,0) + trainCM(0,1)) / trainSamplesCount, 0.23);
    EXPECT_LE(static_cast<double>(testCM(1,0) + testCM(0,1)) / testSamplesCount, 0.26);
}
|
|
|
}} // namespace
|
|
|