|
|
|
@ -158,6 +158,109 @@ TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); } |
|
|
|
|
TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); } |
|
|
|
|
TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); } |
|
|
|
|
|
|
|
|
|
class CV_LegacyTest : public cvtest::BaseTest |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
CV_LegacyTest(const std::string &_modelName, const std::string &_suffixes = std::string()) |
|
|
|
|
: cvtest::BaseTest(), modelName(_modelName), suffixes(_suffixes) |
|
|
|
|
{ |
|
|
|
|
} |
|
|
|
|
virtual ~CV_LegacyTest() {} |
|
|
|
|
protected: |
|
|
|
|
void run(int) |
|
|
|
|
{ |
|
|
|
|
unsigned int idx = 0; |
|
|
|
|
for (;;) |
|
|
|
|
{ |
|
|
|
|
if (idx >= suffixes.size()) |
|
|
|
|
break; |
|
|
|
|
int found = (int)suffixes.find(';', idx); |
|
|
|
|
string piece = suffixes.substr(idx, found - idx); |
|
|
|
|
if (piece.empty()) |
|
|
|
|
break; |
|
|
|
|
oneTest(piece); |
|
|
|
|
idx += (unsigned int)piece.size() + 1; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
void oneTest(const string & suffix) |
|
|
|
|
{ |
|
|
|
|
using namespace cv::ml; |
|
|
|
|
|
|
|
|
|
int code = cvtest::TS::OK; |
|
|
|
|
string filename = ts->get_data_path() + "legacy/" + modelName + suffix; |
|
|
|
|
bool isTree = modelName == CV_BOOST || modelName == CV_DTREE || modelName == CV_RTREES; |
|
|
|
|
Ptr<StatModel> model; |
|
|
|
|
if (modelName == CV_BOOST) |
|
|
|
|
model = StatModel::load<Boost>(filename); |
|
|
|
|
else if (modelName == CV_ANN) |
|
|
|
|
model = StatModel::load<ANN_MLP>(filename); |
|
|
|
|
else if (modelName == CV_DTREE) |
|
|
|
|
model = StatModel::load<DTrees>(filename); |
|
|
|
|
else if (modelName == CV_NBAYES) |
|
|
|
|
model = StatModel::load<NormalBayesClassifier>(filename); |
|
|
|
|
else if (modelName == CV_SVM) |
|
|
|
|
model = StatModel::load<SVM>(filename); |
|
|
|
|
else if (modelName == CV_RTREES) |
|
|
|
|
model = StatModel::load<RTrees>(filename); |
|
|
|
|
if (!model) |
|
|
|
|
{ |
|
|
|
|
code = cvtest::TS::FAIL_INVALID_TEST_DATA; |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F); |
|
|
|
|
ts->get_rng().fill(input, RNG::UNIFORM, 0, 40); |
|
|
|
|
|
|
|
|
|
if (isTree) |
|
|
|
|
randomFillCategories(filename, input); |
|
|
|
|
|
|
|
|
|
Mat output; |
|
|
|
|
model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0)); |
|
|
|
|
// just check if no internal assertions or errors thrown
|
|
|
|
|
} |
|
|
|
|
ts->set_failed_test_info(code); |
|
|
|
|
} |
|
|
|
|
void randomFillCategories(const string & filename, Mat & input) |
|
|
|
|
{ |
|
|
|
|
Mat catMap; |
|
|
|
|
Mat catCount; |
|
|
|
|
std::vector<uchar> varTypes; |
|
|
|
|
|
|
|
|
|
FileStorage fs(filename, FileStorage::READ); |
|
|
|
|
FileNode root = fs.getFirstTopLevelNode(); |
|
|
|
|
root["cat_map"] >> catMap; |
|
|
|
|
root["cat_count"] >> catCount; |
|
|
|
|
root["var_type"] >> varTypes; |
|
|
|
|
|
|
|
|
|
int offset = 0; |
|
|
|
|
int countOffset = 0; |
|
|
|
|
uint var = 0, varCount = (uint)varTypes.size(); |
|
|
|
|
for (; var < varCount; ++var) |
|
|
|
|
{ |
|
|
|
|
if (varTypes[var] == ml::VAR_CATEGORICAL) |
|
|
|
|
{ |
|
|
|
|
int size = catCount.at<int>(0, countOffset); |
|
|
|
|
for (int row = 0; row < input.rows; ++row) |
|
|
|
|
{ |
|
|
|
|
int randomChosenIndex = offset + ((uint)ts->get_rng()) % size; |
|
|
|
|
int value = catMap.at<int>(0, randomChosenIndex); |
|
|
|
|
input.at<float>(row, var) = (float)value; |
|
|
|
|
} |
|
|
|
|
offset += size; |
|
|
|
|
++countOffset; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
string modelName; |
|
|
|
|
string suffixes; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
TEST(ML_ANN, legacy_load) { CV_LegacyTest test(CV_ANN, "_waveform.xml"); test.safe_run(); } |
|
|
|
|
TEST(ML_Boost, legacy_load) { CV_LegacyTest test(CV_BOOST, "_adult.xml;_1.xml;_2.xml;_3.xml"); test.safe_run(); } |
|
|
|
|
TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushroom.xml"); test.safe_run(); } |
|
|
|
|
TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); } |
|
|
|
|
TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); } |
|
|
|
|
TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); } |
|
|
|
|
|
|
|
|
|
/*TEST(ML_SVM, throw_exception_when_save_untrained_model)
|
|
|
|
|
{ |
|
|
|
|