From d5ad4f32556dd4e2f12ca99aed3d51d256463b51 Mon Sep 17 00:00:00 2001 From: Rahul Kavi Date: Fri, 4 Oct 2013 08:34:01 -0400 Subject: [PATCH] added updated logistic regression prototype with newer C++ API --- modules/ml/include/opencv2/ml.hpp | 87 +++++++++++++------------------ 1 file changed, 36 insertions(+), 51 deletions(-) diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index ea90538a92..e424f2b499 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -571,81 +571,66 @@ public: /****************************************************************************************\ * Logistic Regression * \****************************************************************************************/ - -struct CV_EXPORTS_W_MAP CvLR_TrainParams +namespace cv +{ +struct CV_EXPORTS LogisticRegressionParams { - CV_PROP_RW double alpha; - CV_PROP_RW int num_iters; - CV_PROP_RW int norm; - /////////////////////////////////////////////////// - // CV_PROP_RW int debug; - /////////////////////////////////////////////////// - CV_PROP_RW int regularized; - CV_PROP_RW int train_method; - CV_PROP_RW int minibatchsize; - - CV_PROP_RW CvTermCriteria term_crit; - - CvLR_TrainParams(); - /////////////////////////////////////////////////// - // CvLR_TrainParams(double alpha, int num_iters, int norm, int debug, int regularized, int train_method, int minbatchsize); - /////////////////////////////////////////////////// - CvLR_TrainParams(double alpha, int num_iters, int norm, int regularized, int train_method, int minbatchsize); - ~CvLR_TrainParams(); + double alpha; + int num_iters; + int norm; + int regularized; + int train_method; + int mini_batch_size; + CvTermCriteria term_crit; + + LogisticRegressionParams(); + LogisticRegressionParams(double alpha, int num_iters, int norm, int regularized, int train_method, int minbatchsize); }; -class CV_EXPORTS_W CvLR : public CvStatModel +class CV_EXPORTS LogisticRegression { public: - CvLR(); - // CvLR(const CvLR_TrainParams& Params); - - CvLR(const cv::Mat& data, const cv::Mat& labels, const CvLR_TrainParams& params); - - virtual ~CvLR(); - enum { REG_L1=0, REG_L2 = 1}; - enum { BATCH, MINI_BATCH}; + LogisticRegression(); + LogisticRegression(cv::InputArray data_ip, cv::InputArray labels_ip, const LogisticRegressionParams& params); + virtual ~LogisticRegression(); + enum { REG_L1 = 0, REG_L2 = 1}; + enum { BATCH = 0, MINI_BATCH = 1}; - virtual bool train(const cv::Mat& data, const cv::Mat& labels);//, const CvLR_TrainParams& params); + virtual bool train(cv::InputArray data_ip, cv::InputArray label_ip); + virtual void predict( cv::InputArray data, cv::OutputArray predicted_labels ) const; - virtual float predict(const cv::Mat& data, cv::Mat& predicted_labels); - virtual float predict(const cv::Mat& data); + virtual void save(std::string filepath) const; + virtual void load(const std::string filepath); - virtual void write( CvFileStorage* storage, const char* name ) const; - virtual void read( CvFileStorage* storage, CvFileNode* node ); - - virtual void clear(); - - virtual cv::Mat get_learnt_mat(); + cv::Mat get_learnt_thetas() const; protected: - cv::Mat learnt_thetas; - CvLR_TrainParams params; - + LogisticRegressionParams params; + cv::Mat learnt_thetas; + std::string default_model_name; std::map forward_mapper; std::map reverse_mapper; - virtual bool set_default_params(); - virtual cv::Mat calc_sigmoid(const cv::Mat& data); + cv::Mat labels_o; + cv::Mat labels_n; + + static cv::Mat calc_sigmoid(const cv::Mat& data); virtual double compute_cost(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta); virtual cv::Mat compute_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta); virtual cv::Mat compute_mini_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta); - - virtual std::map get_label_map(const cv::Mat& labels); - virtual bool set_label_map(const cv::Mat& labels); - virtual cv::Mat remap_labels(const cv::Mat& labels, const std::map lmap); + static cv::Mat remap_labels(const cv::Mat& labels, const std::map& lmap); - //cv::Mat Mapper; - - cv::Mat labels_o; - cv::Mat labels_n; + virtual void write(FileStorage& fs) const; + virtual void read(const FileNode& fn); + virtual void clear(); }; +}// namespace cv /****************************************************************************************\ * Auxilary functions declarations *