From 40b870704edf3c89288a5cc7d81507e1103684c9 Mon Sep 17 00:00:00 2001 From: Vadim Pisarevsky Date: Fri, 26 Aug 2016 16:25:46 +0400 Subject: [PATCH] add 2 extra methods to ml::TrainData (#7169) * expose 2 extra methods from ml::TrainData: getNames() and getVarSymbolFlags(). The first one returns text labels from CSV (if the data has been loaded from CSV); the second one returns a matrix of boolean values; its n-th element is 1 iff the corresponding column in the CSV uses symbolic names, not numbers. * check that the dynamic_cast succeeds --- modules/ml/include/opencv2/ml.hpp | 4 ++++ modules/ml/src/data.cpp | 40 ++++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index d016810874..ea9c89e4e6 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -190,6 +190,7 @@ public: CV_WRAP virtual Mat getTestSampleWeights() const = 0; CV_WRAP virtual Mat getVarIdx() const = 0; CV_WRAP virtual Mat getVarType() const = 0; + CV_WRAP Mat getVarSymbolFlags() const; CV_WRAP virtual int getResponseType() const = 0; CV_WRAP virtual Mat getTrainSampleIdx() const = 0; CV_WRAP virtual Mat getTestSampleIdx() const = 0; @@ -227,6 +228,9 @@ public: /** @brief Returns matrix of test samples */ CV_WRAP Mat getTestSamples() const; + /** @brief Returns vector of symbolic names captured in loadFromCSV() */ + CV_WRAP void getNames(std::vector<String>& names) const; + CV_WRAP static Mat getSubVector(const Mat& vec, const Mat& idx); /** @brief Reads the dataset from a .csv file and returns the ready-to-use training data. 
diff --git a/modules/ml/src/data.cpp b/modules/ml/src/data.cpp index 465163c927..5e1b6d2340 100644 --- a/modules/ml/src/data.cpp +++ b/modules/ml/src/data.cpp @@ -220,6 +220,7 @@ public: samples.release(); missing.release(); varType.release(); + varSymbolFlags.release(); responses.release(); sampleIdx.release(); trainSampleIdx.release(); @@ -522,6 +523,7 @@ public: std::vector<float> allresponses; std::vector<float> rowvals; std::vector<int> vtypes, rowtypes; + std::vector<uchar> vsymbolflags; bool haveMissed = false; char* buf = &_buf[0]; @@ -583,6 +585,9 @@ public: } else vtypes = rowtypes; + vsymbolflags.resize(nvars); + for( i = 0; i < nvars; i++ ) + vsymbolflags[i] = (uchar)(rowtypes[i] == VAR_CATEGORICAL); ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1; ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1; @@ -598,6 +603,11 @@ public: { CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) || (varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) ); + uchar sflag = (uchar)(rowtypes[i] == VAR_CATEGORICAL); + if( vsymbolflags[i] == VAR_MISSED ) + vsymbolflags[i] = sflag; + else + CV_Assert(vsymbolflags[i] == sflag || rowtypes[i] == VAR_MISSED); } if( ridx0 >= 0 ) @@ -657,7 +667,10 @@ public: } bool ok = !samples.empty(); if(ok) + { std::swap(tempNameMap, nameMap); + Mat(vsymbolflags).copyTo(varSymbolFlags); + } return ok; } @@ -976,13 +989,38 @@ public: FILE* file; int layout; - Mat samples, missing, varType, varIdx, responses, missingSubst; + Mat samples, missing, varType, varIdx, varSymbolFlags, responses, missingSubst; Mat sampleIdx, trainSampleIdx, testSampleIdx; Mat sampleWeights, catMap, catOfs; Mat normCatResponses, classLabels, classCounters; MapType nameMap; }; +void TrainData::getNames(std::vector<String>& names) const +{ + const TrainDataImpl* impl = dynamic_cast<const TrainDataImpl*>(this); + CV_Assert(impl != 0); + size_t n = impl->nameMap.size(); + TrainDataImpl::MapType::const_iterator it = impl->nameMap.begin(), + it_end = impl->nameMap.end(); + names.resize(n+1); 
names[0] = "?"; + for( ; it != it_end; ++it ) + { + String s = it->first; + int label = it->second; + CV_Assert( label > 0 && label <= (int)n ); + names[label] = s; + } +} + +Mat TrainData::getVarSymbolFlags() const +{ + const TrainDataImpl* impl = dynamic_cast<const TrainDataImpl*>(this); + CV_Assert(impl != 0); + return impl->varSymbolFlags; +} + Ptr<TrainData> TrainData::loadFromCSV(const String& filename, int headerLines, int responseStartIdx,