|
|
|
@ -572,7 +572,106 @@ protected: |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class CV_EMTest_Classification : public cvtest::BaseTest |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
CV_EMTest_Classification() {} |
|
|
|
|
protected: |
|
|
|
|
virtual void run(int) |
|
|
|
|
{ |
|
|
|
|
// This test classifies spam by the following way:
|
|
|
|
|
// 1. estimates distributions of "spam" / "not spam"
|
|
|
|
|
// 2. predict classID using Bayes classifier for estimated distributions.
|
|
|
|
|
|
|
|
|
|
CvMLData data; |
|
|
|
|
string dataFilename = string(ts->get_data_path()) + "spambase.data"; |
|
|
|
|
|
|
|
|
|
if(data.read_csv(dataFilename.c_str()) != 0) |
|
|
|
|
{ |
|
|
|
|
ts->printf(cvtest::TS::LOG, "File with spambase dataset cann't be read.\n"); |
|
|
|
|
ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
Mat values = data.get_values(); |
|
|
|
|
CV_Assert(values.cols == 58); |
|
|
|
|
int responseIndex = 57; |
|
|
|
|
|
|
|
|
|
Mat samples = values.colRange(0, responseIndex); |
|
|
|
|
Mat responses = values.col(responseIndex); |
|
|
|
|
|
|
|
|
|
vector<int> trainSamplesMask(samples.rows, 0); |
|
|
|
|
int trainSamplesCount = (int)(0.5f * samples.rows); |
|
|
|
|
for(int i = 0; i < trainSamplesCount; i++) |
|
|
|
|
trainSamplesMask[i] = 1; |
|
|
|
|
RNG rng(0); |
|
|
|
|
for(size_t i = 0; i < trainSamplesMask.size(); i++) |
|
|
|
|
{ |
|
|
|
|
int i1 = rng(trainSamplesMask.size()); |
|
|
|
|
int i2 = rng(trainSamplesMask.size()); |
|
|
|
|
std::swap(trainSamplesMask[i1], trainSamplesMask[i2]); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
EM model0(3), model1(3); |
|
|
|
|
Mat samples0, samples1; |
|
|
|
|
for(int i = 0; i < samples.rows; i++) |
|
|
|
|
{ |
|
|
|
|
if(trainSamplesMask[i]) |
|
|
|
|
{ |
|
|
|
|
Mat sample = samples.row(i); |
|
|
|
|
int resp = (int)responses.at<float>(i); |
|
|
|
|
if(resp == 0) |
|
|
|
|
samples0.push_back(sample); |
|
|
|
|
else |
|
|
|
|
samples1.push_back(sample); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
model0.train(samples0); |
|
|
|
|
model1.train(samples1); |
|
|
|
|
|
|
|
|
|
Mat trainConfusionMat(2, 2, CV_32SC1, Scalar(0)), |
|
|
|
|
testConfusionMat(2, 2, CV_32SC1, Scalar(0)); |
|
|
|
|
const double lambda = 1.; |
|
|
|
|
for(int i = 0; i < samples.rows; i++) |
|
|
|
|
{ |
|
|
|
|
double sampleLogLikelihoods0 = 0, sampleLogLikelihoods1 = 0; |
|
|
|
|
Mat sample = samples.row(i); |
|
|
|
|
model0.predict(sample, noArray(), &sampleLogLikelihoods0); |
|
|
|
|
model1.predict(sample, noArray(), &sampleLogLikelihoods1); |
|
|
|
|
|
|
|
|
|
int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1; |
|
|
|
|
|
|
|
|
|
if(trainSamplesMask[i]) |
|
|
|
|
trainConfusionMat.at<int>((int)responses.at<float>(i), classID)++; |
|
|
|
|
else |
|
|
|
|
testConfusionMat.at<int>((int)responses.at<float>(i), classID)++; |
|
|
|
|
} |
|
|
|
|
// std::cout << trainConfusionMat << std::endl;
|
|
|
|
|
// std::cout << testConfusionMat << std::endl;
|
|
|
|
|
|
|
|
|
|
double trainError = (double)(trainConfusionMat.at<int>(1,0) + trainConfusionMat.at<int>(0,1)) / trainSamplesCount; |
|
|
|
|
double testError = (double)(testConfusionMat.at<int>(1,0) + testConfusionMat.at<int>(0,1)) / (samples.rows - trainSamplesCount); |
|
|
|
|
const double maxTrainError = 0.16; |
|
|
|
|
const double maxTestError = 0.19; |
|
|
|
|
|
|
|
|
|
int code = cvtest::TS::OK; |
|
|
|
|
if(trainError > maxTrainError) |
|
|
|
|
{ |
|
|
|
|
ts->printf(cvtest::TS::LOG, "Too large train classification error (calc = %f, valid=%f).\n", trainError, maxTrainError); |
|
|
|
|
code = cvtest::TS::FAIL_INVALID_TEST_DATA; |
|
|
|
|
} |
|
|
|
|
if(testError > maxTestError) |
|
|
|
|
{ |
|
|
|
|
ts->printf(cvtest::TS::LOG, "Too large test classification error (calc = %f, valid=%f).\n", trainError, maxTrainError); |
|
|
|
|
code = cvtest::TS::FAIL_INVALID_TEST_DATA; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
ts->set_failed_test_info(code); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
// Test registrations: each case creates one instance of the corresponding
// test class and executes it through safe_run().
TEST(ML_KMeans, accuracy)
{
    CV_KMeansTest kmeansTest;
    kmeansTest.safe_run();
}

TEST(ML_KNearest, accuracy)
{
    CV_KNearestTest knearestTest;
    knearestTest.safe_run();
}

TEST(ML_EM, accuracy)
{
    CV_EMTest emTest;
    emTest.safe_run();
}

TEST(ML_EM, save_load)
{
    CV_EMTest_SaveLoad saveLoadTest;
    saveLoadTest.safe_run();
}

TEST(ML_EM, classification)
{
    CV_EMTest_Classification classificationTest;
    classificationTest.safe_run();
}
|
|
|
|
|
|
|
|
|