add 2 extra methods to ml::TrainData (#7169)

* expose 2 extra methods from ml::TrainData: getNames() and getVarSymbolFlags(). The first one returns text labels from CSV (if the data has been loaded from CSV); the second one returns a matrix of boolean values; its n-th element is 1 iff the corresponding column in the CSV uses symbolic names, not numbers.

* check that the dynamic_cast succeeds
pull/7179/merge
Vadim Pisarevsky 9 years ago committed by GitHub
parent 5ddd25313f
commit 40b870704e
  1. 4
      modules/ml/include/opencv2/ml.hpp
  2. 40
      modules/ml/src/data.cpp

@ -190,6 +190,7 @@ public:
CV_WRAP virtual Mat getTestSampleWeights() const = 0;
CV_WRAP virtual Mat getVarIdx() const = 0;
CV_WRAP virtual Mat getVarType() const = 0;
CV_WRAP Mat getVarSymbolFlags() const;
CV_WRAP virtual int getResponseType() const = 0;
CV_WRAP virtual Mat getTrainSampleIdx() const = 0;
CV_WRAP virtual Mat getTestSampleIdx() const = 0;
@ -227,6 +228,9 @@ public:
/** @brief Returns matrix of test samples */
CV_WRAP Mat getTestSamples() const;
/** @brief Returns vector of symbolic names captured in loadFromCSV() */
CV_WRAP void getNames(std::vector<String>& names) const;
CV_WRAP static Mat getSubVector(const Mat& vec, const Mat& idx);
/** @brief Reads the dataset from a .csv file and returns the ready-to-use training data.

@ -220,6 +220,7 @@ public:
samples.release();
missing.release();
varType.release();
varSymbolFlags.release();
responses.release();
sampleIdx.release();
trainSampleIdx.release();
@ -522,6 +523,7 @@ public:
std::vector<float> allresponses;
std::vector<float> rowvals;
std::vector<uchar> vtypes, rowtypes;
std::vector<uchar> vsymbolflags;
bool haveMissed = false;
char* buf = &_buf[0];
@ -583,6 +585,9 @@ public:
}
else
vtypes = rowtypes;
vsymbolflags.resize(nvars);
for( i = 0; i < nvars; i++ )
vsymbolflags[i] = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
@ -598,6 +603,11 @@ public:
{
CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
(varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
uchar sflag = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
if( vsymbolflags[i] == VAR_MISSED )
vsymbolflags[i] = sflag;
else
CV_Assert(vsymbolflags[i] == sflag || rowtypes[i] == VAR_MISSED);
}
if( ridx0 >= 0 )
@ -657,7 +667,10 @@ public:
}
bool ok = !samples.empty();
if(ok)
{
std::swap(tempNameMap, nameMap);
Mat(vsymbolflags).copyTo(varSymbolFlags);
}
return ok;
}
@ -976,13 +989,38 @@ public:
FILE* file;
int layout;
Mat samples, missing, varType, varIdx, responses, missingSubst;
Mat samples, missing, varType, varIdx, varSymbolFlags, responses, missingSubst;
Mat sampleIdx, trainSampleIdx, testSampleIdx;
Mat sampleWeights, catMap, catOfs;
Mat normCatResponses, classLabels, classCounters;
MapType nameMap;
};
void TrainData::getNames(std::vector<String>& names) const
{
const TrainDataImpl* impl = dynamic_cast<const TrainDataImpl*>(this);
CV_Assert(impl != 0);
size_t n = impl->nameMap.size();
TrainDataImpl::MapType::const_iterator it = impl->nameMap.begin(),
it_end = impl->nameMap.end();
names.resize(n+1);
names[0] = "?";
for( ; it != it_end; ++it )
{
String s = it->first;
int label = it->second;
CV_Assert( label > 0 && label <= (int)n );
names[label] = s;
}
}
Mat TrainData::getVarSymbolFlags() const
{
const TrainDataImpl* impl = dynamic_cast<const TrainDataImpl*>(this);
CV_Assert(impl != 0);
return impl->varSymbolFlags;
}
Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
int headerLines,
int responseStartIdx,

Loading…
Cancel
Save