From fc5bba66afca45cff01816576183fcec8f554559 Mon Sep 17 00:00:00 2001 From: berak Date: Tue, 24 Apr 2018 12:11:59 +0200 Subject: [PATCH] ml: refactor non-virtual methods --- modules/ml/include/opencv2/ml.hpp | 14 ++++---- modules/ml/src/data.cpp | 59 +++++++++++++++---------------- modules/ml/src/rtrees.cpp | 16 +-------- modules/ml/src/svm.cpp | 26 ++------------ 4 files changed, 39 insertions(+), 76 deletions(-) diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index 2b694c8d47..357aac146c 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -198,7 +198,7 @@ public: CV_WRAP virtual Mat getTestSampleWeights() const = 0; CV_WRAP virtual Mat getVarIdx() const = 0; CV_WRAP virtual Mat getVarType() const = 0; - CV_WRAP Mat getVarSymbolFlags() const; + CV_WRAP virtual Mat getVarSymbolFlags() const = 0; CV_WRAP virtual int getResponseType() const = 0; CV_WRAP virtual Mat getTrainSampleIdx() const = 0; CV_WRAP virtual Mat getTestSampleIdx() const = 0; @@ -234,10 +234,10 @@ public: CV_WRAP virtual void shuffleTrainTest() = 0; /** @brief Returns matrix of test samples */ - CV_WRAP Mat getTestSamples() const; + CV_WRAP virtual Mat getTestSamples() const = 0; /** @brief Returns vector of symbolic names captured in loadFromCSV() */ - CV_WRAP void getNames(std::vector& names) const; + CV_WRAP virtual void getNames(std::vector& names) const = 0; CV_WRAP static Mat getSubVector(const Mat& vec, const Mat& idx); @@ -727,7 +727,7 @@ public: regression (SVM::EPS_SVR or SVM::NU_SVR). If it is SVM::ONE_CLASS, no optimization is made and the usual %SVM with parameters specified in params is executed. */ - CV_WRAP bool trainAuto(InputArray samples, + CV_WRAP virtual bool trainAuto(InputArray samples, int layout, InputArray responses, int kFold = 10, @@ -737,7 +737,7 @@ public: Ptr nuGrid = SVM::getDefaultGridPtr(SVM::NU), Ptr coeffGrid = SVM::getDefaultGridPtr(SVM::COEF), Ptr degreeGrid = SVM::getDefaultGridPtr(SVM::DEGREE), - bool balanced=false); + bool balanced=false) = 0; /** @brief Retrieves all the support vectors @@ -752,7 +752,7 @@ public: support vector, used for prediction, was derived from. They are returned in a floating-point matrix, where the support vectors are stored as matrix rows. */ - CV_WRAP Mat getUncompressedSupportVectors() const; + CV_WRAP virtual Mat getUncompressedSupportVectors() const = 0; /** @brief Retrieves the decision function @@ -1273,7 +1273,7 @@ public: @param results Array where the result of the calculation will be written. @param flags Flags for defining the type of RTrees. */ - CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const; + CV_WRAP virtual void getVotes(InputArray samples, OutputArray results, int flags) const = 0; /** Creates the empty model. Use StatModel::train to train the model, StatModel::train to create and train the model, diff --git a/modules/ml/src/data.cpp b/modules/ml/src/data.cpp index cbd1c3fde5..1067c31246 100644 --- a/modules/ml/src/data.cpp +++ b/modules/ml/src/data.cpp @@ -50,13 +50,6 @@ static const int VAR_MISSED = VAR_ORDERED; TrainData::~TrainData() {} -Mat TrainData::getTestSamples() const -{ - Mat idx = getTestSampleIdx(); - Mat samples = getSamples(); - return idx.empty() ? Mat() : getSubVector(samples, idx); -} - Mat TrainData::getSubVector(const Mat& vec, const Mat& idx) { if( idx.empty() ) @@ -119,6 +112,7 @@ Mat TrainData::getSubVector(const Mat& vec, const Mat& idx) return subvec; } + class TrainDataImpl CV_FINAL : public TrainData { public: @@ -155,6 +149,12 @@ public: return layout == ROW_SAMPLE ? samples.cols : samples.rows; } + Mat getTestSamples() const CV_OVERRIDE + { + Mat idx = getTestSampleIdx(); + return idx.empty() ? Mat() : getSubVector(samples, idx); + } + Mat getSamples() const CV_OVERRIDE { return samples; } Mat getResponses() const CV_OVERRIDE { return responses; } Mat getMissing() const CV_OVERRIDE { return missing; } @@ -987,6 +987,27 @@ public: } } + void getNames(std::vector& names) const CV_OVERRIDE + { + size_t n = nameMap.size(); + TrainDataImpl::MapType::const_iterator it = nameMap.begin(), + it_end = nameMap.end(); + names.resize(n+1); + names[0] = "?"; + for( ; it != it_end; ++it ) + { + String s = it->first; + int label = it->second; + CV_Assert( label > 0 && label <= (int)n ); + names[label] = s; + } + } + + Mat getVarSymbolFlags() const CV_OVERRIDE + { + return varSymbolFlags; + } + FILE* file; int layout; Mat samples, missing, varType, varIdx, varSymbolFlags, responses, missingSubst; @@ -996,30 +1017,6 @@ public: MapType nameMap; }; -void TrainData::getNames(std::vector& names) const -{ - const TrainDataImpl* impl = dynamic_cast(this); - CV_Assert(impl != 0); - size_t n = impl->nameMap.size(); - TrainDataImpl::MapType::const_iterator it = impl->nameMap.begin(), - it_end = impl->nameMap.end(); - names.resize(n+1); - names[0] = "?"; - for( ; it != it_end; ++it ) - { - String s = it->first; - int label = it->second; - CV_Assert( label > 0 && label <= (int)n ); - names[label] = s; - } -} - -Mat TrainData::getVarSymbolFlags() const -{ - const TrainDataImpl* impl = dynamic_cast(this); - CV_Assert(impl != 0); - return impl->varSymbolFlags; -} Ptr TrainData::loadFromCSV(const String& filename, int headerLines, diff --git a/modules/ml/src/rtrees.cpp b/modules/ml/src/rtrees.cpp index 0751e37b91..cc5253e57c 100644 --- a/modules/ml/src/rtrees.cpp +++ b/modules/ml/src/rtrees.cpp @@ -453,6 +453,7 @@ public: inline void setRegressionAccuracy(float val) CV_OVERRIDE { impl.params.setRegressionAccuracy(val); } inline cv::Mat getPriors() const CV_OVERRIDE { return impl.params.getPriors(); } inline void setPriors(const cv::Mat& val) CV_OVERRIDE { impl.params.setPriors(val); } + inline void getVotes(InputArray input, OutputArray output, int flags) const CV_OVERRIDE {return impl.getVotes(input,output,flags);} RTreesImpl() {} virtual ~RTreesImpl() CV_OVERRIDE {} @@ -485,12 +486,6 @@ public: impl.read(fn); } - void getVotes_( InputArray samples, OutputArray results, int flags ) const - { - CV_TRACE_FUNCTION(); - impl.getVotes(samples, results, flags); - } - Mat getVarImportance() const CV_OVERRIDE { return Mat_(impl.varImportance, true); } int getVarCount() const CV_OVERRIDE { return impl.getVarCount(); } @@ -519,15 +514,6 @@ Ptr RTrees::load(const String& filepath, const String& nodeName) return Algorithm::load(filepath, nodeName); } -void RTrees::getVotes(InputArray input, OutputArray output, int flags) const -{ - CV_TRACE_FUNCTION(); - const RTreesImpl* this_ = dynamic_cast(this); - if(!this_) - CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl"); - return this_->getVotes_(input, output, flags); -} - }} // End of file. diff --git a/modules/ml/src/svm.cpp b/modules/ml/src/svm.cpp index d6518ef013..b1a073b509 100644 --- a/modules/ml/src/svm.cpp +++ b/modules/ml/src/svm.cpp @@ -1250,7 +1250,7 @@ public: uncompressed_sv.release(); } - Mat getUncompressedSupportVectors_() const + Mat getUncompressedSupportVectors() const CV_OVERRIDE { return uncompressed_sv; } @@ -1982,10 +1982,10 @@ public: bool returnDFVal; }; - bool trainAuto_(InputArray samples, int layout, + bool trainAuto(InputArray samples, int layout, InputArray responses, int kfold, Ptr Cgrid, Ptr gammaGrid, Ptr pGrid, Ptr nuGrid, - Ptr coeffGrid, Ptr degreeGrid, bool balanced) + Ptr coeffGrid, Ptr degreeGrid, bool balanced) CV_OVERRIDE { Ptr data = TrainData::create(samples, layout, responses); return this->trainAuto( @@ -2353,26 +2353,6 @@ Ptr SVM::load(const String& filepath) return svm; } -Mat SVM::getUncompressedSupportVectors() const -{ - const SVMImpl* this_ = dynamic_cast(this); - if(!this_) - CV_Error(Error::StsNotImplemented, "the class is not SVMImpl"); - return this_->getUncompressedSupportVectors_(); -} - -bool SVM::trainAuto(InputArray samples, int layout, - InputArray responses, int kfold, Ptr Cgrid, - Ptr gammaGrid, Ptr pGrid, Ptr nuGrid, - Ptr coeffGrid, Ptr degreeGrid, bool balanced) -{ - SVMImpl* this_ = dynamic_cast(this); - if (!this_) { - CV_Error(Error::StsNotImplemented, "the class is not SVMImpl"); - } - return this_->trainAuto_(samples, layout, responses, - kfold, Cgrid, gammaGrid, pGrid, nuGrid, coeffGrid, degreeGrid, balanced); -} } }