Wrappers for load methods of EM, LR, SVMSGD and Normal Bayes Classifier

pull/8099/head
chrizandr 8 years ago
parent aa5caf83f6
commit 519fbdb8ab
  1. 44
      modules/ml/include/opencv2/ml.hpp
  2. 5
      modules/ml/src/em.cpp
  3. 6
      modules/ml/src/lr.cpp
  4. 5
      modules/ml/src/nbayes.cpp
  5. 6
      modules/ml/src/svmsgd.cpp

@ -393,6 +393,17 @@ public:
/** Creates empty model /** Creates empty model
Use StatModel::train to train the model after creation. */ Use StatModel::train to train the model after creation. */
CV_WRAP static Ptr<NormalBayesClassifier> create(); CV_WRAP static Ptr<NormalBayesClassifier> create();
/** @brief Loads and creates a serialized NormalBayesClassifier from a file
*
* Use NormalBayesClassifier::save to serialize and store an NormalBayesClassifier to disk.
* Load the NormalBayesClassifier from this file again, by calling this function with the path to the file.
* Optionally specify the node for the file containing the classifier
*
* @param filepath path to serialized NormalBayesClassifier
* @param nodeName name of node containing the classifier
*/
CV_WRAP static Ptr<NormalBayesClassifier> load(const String& filepath , const String& nodeName = String());
}; };
/****************************************************************************************\ /****************************************************************************************\
@ -927,6 +938,17 @@ public:
can use one of the EM::train\* methods or load it from file using Algorithm::load\<EM\>(filename). can use one of the EM::train\* methods or load it from file using Algorithm::load\<EM\>(filename).
*/ */
CV_WRAP static Ptr<EM> create(); CV_WRAP static Ptr<EM> create();
/** @brief Loads and creates a serialized EM from a file
*
* Use EM::save to serialize and store an EM to disk.
* Load the EM from this file again, by calling this function with the path to the file.
* Optionally specify the node for the file containing the classifier
*
* @param filepath path to serialized EM
* @param nodeName name of node containing the classifier
*/
CV_WRAP static Ptr<EM> load(const String& filepath , const String& nodeName = String());
}; };
/****************************************************************************************\ /****************************************************************************************\
@ -1512,6 +1534,17 @@ public:
Creates Logistic Regression model with parameters given. Creates Logistic Regression model with parameters given.
*/ */
CV_WRAP static Ptr<LogisticRegression> create(); CV_WRAP static Ptr<LogisticRegression> create();
/** @brief Loads and creates a serialized LogisticRegression from a file
*
* Use LogisticRegression::save to serialize and store an LogisticRegression to disk.
* Load the LogisticRegression from this file again, by calling this function with the path to the file.
* Optionally specify the node for the file containing the classifier
*
* @param filepath path to serialized LogisticRegression
* @param nodeName name of node containing the classifier
*/
CV_WRAP static Ptr<LogisticRegression> load(const String& filepath , const String& nodeName = String());
}; };
@ -1627,6 +1660,17 @@ public:
*/ */
CV_WRAP static Ptr<SVMSGD> create(); CV_WRAP static Ptr<SVMSGD> create();
/** @brief Loads and creates a serialized SVMSGD from a file
*
* Use SVMSGD::save to serialize and store an SVMSGD to disk.
* Load the SVMSGD from this file again, by calling this function with the path to the file.
* Optionally specify the node for the file containing the classifier
*
* @param filepath path to serialized SVMSGD
* @param nodeName name of node containing the classifier
*/
CV_WRAP static Ptr<SVMSGD> load(const String& filepath , const String& nodeName = String());
/** @brief Function sets optimal parameters values for chosen SVM SGD model. /** @brief Function sets optimal parameters values for chosen SVM SGD model.
* @param svmsgdType is the type of SVMSGD classifier. * @param svmsgdType is the type of SVMSGD classifier.
* @param marginType is the type of margin constraint. * @param marginType is the type of margin constraint.

@ -845,6 +845,11 @@ Ptr<EM> EM::create()
return makePtr<EMImpl>(); return makePtr<EMImpl>();
} }
Ptr<EM> EM::load(const String& filepath, const String& nodeName)
{
return Algorithm::load<EM>(filepath, nodeName);
}
} }
} // namespace cv } // namespace cv

@ -127,6 +127,12 @@ Ptr<LogisticRegression> LogisticRegression::create()
return makePtr<LogisticRegressionImpl>(); return makePtr<LogisticRegressionImpl>();
} }
Ptr<LogisticRegression> LogisticRegression::load(const String& filepath, const String& nodeName)
{
return Algorithm::load<LogisticRegression>(filepath, nodeName);
}
bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int) bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
{ {
// return value // return value

@ -458,6 +458,11 @@ Ptr<NormalBayesClassifier> NormalBayesClassifier::create()
return p; return p;
} }
Ptr<NormalBayesClassifier> NormalBayesClassifier::load(const String& filepath, const String& nodeName)
{
return Algorithm::load<NormalBayesClassifier>(filepath, nodeName);
}
} }
} }

@ -134,6 +134,12 @@ Ptr<SVMSGD> SVMSGD::create()
return makePtr<SVMSGDImpl>(); return makePtr<SVMSGDImpl>();
} }
Ptr<SVMSGD> SVMSGD::load(const String& filepath, const String& nodeName)
{
return Algorithm::load<SVMSGD>(filepath, nodeName);
}
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier) void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
{ {
int featuresCount = samples.cols; int featuresCount = samples.cols;

Loading…
Cancel
Save