From fef7fc343e1f6d2297bed2e0c33bfa3ea5f22e19 Mon Sep 17 00:00:00 2001 From: Alexander Alekhin Date: Sun, 22 Sep 2019 11:11:08 +0000 Subject: [PATCH] ml: add checks of empty train data --- modules/ml/src/ann_mlp.cpp | 2 ++ modules/ml/src/boost.cpp | 3 +++ modules/ml/src/em.cpp | 1 + modules/ml/src/inner_functions.cpp | 5 ++++- modules/ml/src/knearest.cpp | 2 ++ modules/ml/src/lr.cpp | 6 ++---- modules/ml/src/nbayes.cpp | 1 + modules/ml/src/rtrees.cpp | 3 +++ modules/ml/src/svm.cpp | 2 ++ modules/ml/src/svmsgd.cpp | 1 + modules/ml/src/tree.cpp | 3 +++ modules/ml/test/test_lr.cpp | 7 ++----- modules/ml/test/test_mltests2.cpp | 8 +++++++- 13 files changed, 33 insertions(+), 11 deletions(-) diff --git a/modules/ml/src/ann_mlp.cpp b/modules/ml/src/ann_mlp.cpp index ce6fdd877d..da5c506fcc 100644 --- a/modules/ml/src/ann_mlp.cpp +++ b/modules/ml/src/ann_mlp.cpp @@ -920,6 +920,7 @@ public: bool train( const Ptr& trainData, int flags ) CV_OVERRIDE { + CV_Assert(!trainData.empty()); const int MAX_ITER = 1000; const double DEFAULT_EPSILON = FLT_EPSILON; @@ -955,6 +956,7 @@ public: } int train_anneal(const Ptr& trainData) { + CV_Assert(!trainData.empty()); SimulatedAnnealingANN_MLP s(*this, trainData); trained = true; // Enable call to CalcError int iter = simulatedAnnealingSolver(s, params.initialT, params.finalT, params.coolingRatio, params.itePerStep, NULL, params.rEnergy); diff --git a/modules/ml/src/boost.cpp b/modules/ml/src/boost.cpp index b3e8c2724a..4b94410eeb 100644 --- a/modules/ml/src/boost.cpp +++ b/modules/ml/src/boost.cpp @@ -88,6 +88,7 @@ public: void startTraining( const Ptr& trainData, int flags ) CV_OVERRIDE { + CV_Assert(!trainData.empty()); DTreesImpl::startTraining(trainData, flags); sumResult.assign(w->sidx.size(), 0.); @@ -184,6 +185,7 @@ public: bool train( const Ptr& trainData, int flags ) CV_OVERRIDE { + CV_Assert(!trainData.empty()); startTraining(trainData, flags); int treeidx, ntrees = bparams.weakCount >= 0 ? bparams.weakCount : 10000; vector sidx = w->sidx; @@ -482,6 +484,7 @@ public: bool train( const Ptr& trainData, int flags ) CV_OVERRIDE { + CV_Assert(!trainData.empty()); return impl.train(trainData, flags); } diff --git a/modules/ml/src/em.cpp b/modules/ml/src/em.cpp index c2dfc9c523..ec73bfd1b5 100644 --- a/modules/ml/src/em.cpp +++ b/modules/ml/src/em.cpp @@ -112,6 +112,7 @@ public: bool train(const Ptr& data, int) CV_OVERRIDE { + CV_Assert(!data.empty()); Mat samples = data->getTrainSamples(), labels; return trainEM(samples, labels, noArray(), noArray()); } diff --git a/modules/ml/src/inner_functions.cpp b/modules/ml/src/inner_functions.cpp index 6f8b222d19..b823c5ba22 100644 --- a/modules/ml/src/inner_functions.cpp +++ b/modules/ml/src/inner_functions.cpp @@ -59,9 +59,10 @@ bool StatModel::empty() const { return !isTrained(); } int StatModel::getVarCount() const { return 0; } -bool StatModel::train( const Ptr&, int ) +bool StatModel::train(const Ptr& trainData, int ) { CV_TRACE_FUNCTION(); + CV_Assert(!trainData.empty()); CV_Error(CV_StsNotImplemented, ""); return false; } @@ -69,6 +70,7 @@ bool StatModel::train( const Ptr&, int ) bool StatModel::train( InputArray samples, int layout, InputArray responses ) { CV_TRACE_FUNCTION(); + CV_Assert(!samples.empty()); return train(TrainData::create(samples, layout, responses)); } @@ -134,6 +136,7 @@ public: float StatModel::calcError(const Ptr& data, bool testerr, OutputArray _resp) const { CV_TRACE_FUNCTION_SKIP_NESTED(); + CV_Assert(!data.empty()); Mat samples = data->getSamples(); Mat sidx = testerr ? data->getTestSampleIdx() : data->getTrainSampleIdx(); Mat weights = testerr ? data->getTestSampleWeights() : data->getTrainSampleWeights(); diff --git a/modules/ml/src/knearest.cpp b/modules/ml/src/knearest.cpp index dcc201158d..ca23d0f4d6 100644 --- a/modules/ml/src/knearest.cpp +++ b/modules/ml/src/knearest.cpp @@ -73,6 +73,7 @@ public: bool train( const Ptr& data, int flags ) { + CV_Assert(!data.empty()); Mat new_samples = data->getTrainSamples(ROW_SAMPLE); Mat new_responses; data->getTrainResponses().convertTo(new_responses, CV_32F); @@ -494,6 +495,7 @@ public: bool train( const Ptr& data, int flags ) CV_OVERRIDE { + CV_Assert(!data.empty()); return impl->train(data, flags); } diff --git a/modules/ml/src/lr.cpp b/modules/ml/src/lr.cpp index 166b6a39d8..ad7b8079a2 100644 --- a/modules/ml/src/lr.cpp +++ b/modules/ml/src/lr.cpp @@ -142,12 +142,10 @@ Ptr LogisticRegression::load(const String& filepath, const S bool LogisticRegressionImpl::train(const Ptr& trainData, int) { CV_TRACE_FUNCTION_SKIP_NESTED(); + CV_Assert(!trainData.empty()); + // return value bool ok = false; - - if (trainData.empty()) { - return false; - } clear(); Mat _data_i = trainData->getSamples(); Mat _labels_i = trainData->getResponses(); diff --git a/modules/ml/src/nbayes.cpp b/modules/ml/src/nbayes.cpp index baa46d8f0a..60dda0c7d4 100644 --- a/modules/ml/src/nbayes.cpp +++ b/modules/ml/src/nbayes.cpp @@ -54,6 +54,7 @@ public: bool train( const Ptr& trainData, int flags ) CV_OVERRIDE { + CV_Assert(!trainData.empty()); const float min_variation = FLT_EPSILON; Mat responses = trainData->getNormCatResponses(); Mat __cls_labels = trainData->getClassLabels(); diff --git a/modules/ml/src/rtrees.cpp b/modules/ml/src/rtrees.cpp index e10ef9bc42..f7a3a5b0fb 100644 --- a/modules/ml/src/rtrees.cpp +++ b/modules/ml/src/rtrees.cpp @@ -111,6 +111,7 @@ public: void startTraining( const Ptr& trainData, int flags ) CV_OVERRIDE { CV_TRACE_FUNCTION(); + CV_Assert(!trainData.empty()); DTreesImpl::startTraining(trainData, flags); int nvars = w->data->getNVars(); int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars)); @@ -133,6 +134,7 @@ public: bool train( const Ptr& trainData, int flags ) CV_OVERRIDE { CV_TRACE_FUNCTION(); + CV_Assert(!trainData.empty()); startTraining(trainData, flags); int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ? rparams.termCrit.maxCount : 10000; @@ -463,6 +465,7 @@ public: bool train( const Ptr& trainData, int flags ) CV_OVERRIDE { CV_TRACE_FUNCTION(); + CV_Assert(!trainData.empty()); if (impl.getCVFolds() != 0) CV_Error(Error::StsBadArg, "Cross validation for RTrees is not implemented"); return impl.train(trainData, flags); diff --git a/modules/ml/src/svm.cpp b/modules/ml/src/svm.cpp index 4a1ffd0575..4c3ff2a319 100644 --- a/modules/ml/src/svm.cpp +++ b/modules/ml/src/svm.cpp @@ -1613,6 +1613,7 @@ public: bool train( const Ptr& data, int ) CV_OVERRIDE { + CV_Assert(!data.empty()); clear(); checkParams(); @@ -1739,6 +1740,7 @@ public: ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid, bool balanced ) CV_OVERRIDE { + CV_Assert(!data.empty()); checkParams(); int svmType = params.svmType; diff --git a/modules/ml/src/svmsgd.cpp b/modules/ml/src/svmsgd.cpp index ac778f4da8..266c7cf300 100644 --- a/modules/ml/src/svmsgd.cpp +++ b/modules/ml/src/svmsgd.cpp @@ -230,6 +230,7 @@ float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const bool SVMSGDImpl::train(const Ptr& data, int) { + CV_Assert(!data.empty()); clear(); CV_Assert( isClassifier() ); //toDo: consider diff --git a/modules/ml/src/tree.cpp b/modules/ml/src/tree.cpp index 2f9dc049e1..87181b156c 100644 --- a/modules/ml/src/tree.cpp +++ b/modules/ml/src/tree.cpp @@ -98,6 +98,7 @@ DTrees::Split::Split() DTreesImpl::WorkData::WorkData(const Ptr& _data) { + CV_Assert(!_data.empty()); data = _data; vector subsampleIdx; Mat sidx0 = _data->getTrainSampleIdx(); @@ -136,6 +137,7 @@ void DTreesImpl::clear() void DTreesImpl::startTraining( const Ptr& data, int ) { + CV_Assert(!data.empty()); clear(); w = makePtr(data); @@ -223,6 +225,7 @@ void DTreesImpl::endTraining() bool DTreesImpl::train( const Ptr& trainData, int flags ) { + CV_Assert(!trainData.empty()); startTraining(trainData, flags); bool ok = addTree( w->sidx ) >= 0; w.release(); diff --git a/modules/ml/test/test_lr.cpp b/modules/ml/test/test_lr.cpp index 15d59d77fa..d57825152c 100644 --- a/modules/ml/test/test_lr.cpp +++ b/modules/ml/test/test_lr.cpp @@ -94,11 +94,7 @@ void CV_LRTest::run( int /*start_from*/ ) // initialize variables from the popular Iris Dataset string dataFileName = ts->get_data_path() + "iris.data"; Ptr tdata = TrainData::loadFromCSV(dataFileName, 0); - - if (tdata.empty()) { - ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA); - return; - } + ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << dataFileName; // run LR classifier train classifier Ptr p = LogisticRegression::create(); @@ -156,6 +152,7 @@ void CV_LRTest_SaveLoad::run( int /*start_from*/ ) // initialize variables from the popular Iris Dataset string dataFileName = ts->get_data_path() + "iris.data"; Ptr tdata = TrainData::loadFromCSV(dataFileName, 0); + ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << dataFileName; Mat responses1, responses2; Mat learnt_mat1, learnt_mat2; diff --git a/modules/ml/test/test_mltests2.cpp b/modules/ml/test/test_mltests2.cpp index 616a527bfe..f9bbf70f95 100644 --- a/modules/ml/test/test_mltests2.cpp +++ b/modules/ml/test/test_mltests2.cpp @@ -105,6 +105,7 @@ int str_to_ann_activation_function(String& str) void ann_check_data( Ptr _data ) { CV_TRACE_FUNCTION(); + CV_Assert(!_data.empty()); Mat values = _data->getSamples(); Mat var_idx = _data->getVarIdx(); int nvars = (int)var_idx.total(); @@ -118,6 +119,7 @@ void ann_check_data( Ptr _data ) Mat ann_get_new_responses( Ptr _data, map& cls_map ) { CV_TRACE_FUNCTION(); + CV_Assert(!_data.empty()); Mat train_sidx = _data->getTrainSampleIdx(); int* train_sidx_ptr = train_sidx.ptr(); Mat responses = _data->getResponses(); @@ -150,6 +152,8 @@ Mat ann_get_new_responses( Ptr _data, map& cls_map ) float ann_calc_error( Ptr ann, Ptr _data, map& cls_map, int type, vector *resp_labels ) { CV_TRACE_FUNCTION(); + CV_Assert(!ann.empty()); + CV_Assert(!_data.empty()); float err = 0; Mat samples = _data->getSamples(); Mat responses = _data->getResponses(); @@ -264,13 +268,15 @@ TEST_P(ML_ANN_METHOD, Test) String dataname = folder + "waveform" + '_' + methodName; Ptr tdata2 = TrainData::loadFromCSV(original_path, 0); + ASSERT_FALSE(tdata2.empty()) << "Could not find test data file : " << original_path; + Mat samples = tdata2->getSamples()(Range(0, N), Range::all()); Mat responses(N, 3, CV_32FC1, Scalar(0)); for (int i = 0; i < N; i++) responses.at(i, static_cast(tdata2->getResponses().at(i, 0))) = 1; Ptr tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses); + ASSERT_FALSE(tdata.empty()); - ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path; RNG& rng = theRNG(); rng.state = 0; tdata->setTrainTestSplitRatio(0.8);