diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index 7acce7f33c..0b9026950e 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -224,6 +224,9 @@ public: CV_WRAP virtual void setTrainTestSplitRatio(double ratio, bool shuffle=true) = 0; CV_WRAP virtual void shuffleTrainTest() = 0; + /** @brief Returns matrix of test samples */ + CV_WRAP Mat getTestSamples() const; + CV_WRAP static Mat getSubVector(const Mat& vec, const Mat& idx); /** @brief Reads the dataset from a .csv file and returns the ready-to-use training data. diff --git a/modules/ml/src/data.cpp b/modules/ml/src/data.cpp index a1608e3984..ad652568c7 100644 --- a/modules/ml/src/data.cpp +++ b/modules/ml/src/data.cpp @@ -50,6 +50,13 @@ static const int VAR_MISSED = VAR_ORDERED; TrainData::~TrainData() {} +Mat TrainData::getTestSamples() const +{ + Mat idx = getTestSampleIdx(); + Mat samples = getSamples(); + return idx.empty() ? Mat() : getSubVector(samples, idx); +} + Mat TrainData::getSubVector(const Mat& vec, const Mat& idx) { if( idx.empty() )