From fcfeb2451b32b7ea65e55104297529fb83cfed10 Mon Sep 17 00:00:00 2001 From: Rahul Kavi Date: Mon, 5 Aug 2013 09:42:07 -0400 Subject: [PATCH] added logistic regression prototype --- modules/ml/include/opencv2/ml.hpp | 84 +++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index a5ce3010bf..ea90538a92 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -89,6 +89,8 @@ public: CV_PROP_RW double maxVal; CV_PROP_RW double logStep; }; +#define CV_TYPE_NAME_ML_LR "opencv-ml-lr" + class CV_EXPORTS TrainData @@ -566,6 +568,85 @@ public: static Ptr create(const Params& params=Params()); }; +/****************************************************************************************\ +* Logistic Regression * +\****************************************************************************************/ + +struct CV_EXPORTS_W_MAP CvLR_TrainParams +{ + CV_PROP_RW double alpha; + CV_PROP_RW int num_iters; + CV_PROP_RW int norm; + /////////////////////////////////////////////////// + // CV_PROP_RW int debug; + /////////////////////////////////////////////////// + CV_PROP_RW int regularized; + CV_PROP_RW int train_method; + CV_PROP_RW int minibatchsize; + + CV_PROP_RW CvTermCriteria term_crit; + + CvLR_TrainParams(); + /////////////////////////////////////////////////// + // CvLR_TrainParams(double alpha, int num_iters, int norm, int debug, int regularized, int train_method, int minbatchsize); + /////////////////////////////////////////////////// + CvLR_TrainParams(double alpha, int num_iters, int norm, int regularized, int train_method, int minbatchsize); + ~CvLR_TrainParams(); +}; + +class CV_EXPORTS_W CvLR : public CvStatModel +{ +public: + CvLR(); + // CvLR(const CvLR_TrainParams& Params); + + CvLR(const cv::Mat& data, const cv::Mat& labels, const CvLR_TrainParams& params); + + virtual ~CvLR(); + + enum { REG_L1=0, REG_L2 = 1}; + enum { BATCH, MINI_BATCH}; + + + virtual bool train(const cv::Mat& data, const cv::Mat& labels);//, const CvLR_TrainParams& params); + + virtual float predict(const cv::Mat& data, cv::Mat& predicted_labels); + virtual float predict(const cv::Mat& data); + + virtual void write( CvFileStorage* storage, const char* name ) const; + virtual void read( CvFileStorage* storage, CvFileNode* node ); + + virtual void clear(); + + virtual cv::Mat get_learnt_mat(); + +protected: + + cv::Mat learnt_thetas; + CvLR_TrainParams params; + + std::map forward_mapper; + std::map reverse_mapper; + + virtual bool set_default_params(); + virtual cv::Mat calc_sigmoid(const cv::Mat& data); + + virtual double compute_cost(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta); + virtual cv::Mat compute_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta); + virtual cv::Mat compute_mini_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta); + + virtual std::map get_label_map(const cv::Mat& labels); + + virtual bool set_label_map(const cv::Mat& labels); + virtual cv::Mat remap_labels(const cv::Mat& labels, const std::map lmap); + + //cv::Mat Mapper; + + cv::Mat labels_o; + cv::Mat labels_n; + +}; + /****************************************************************************************\ * Auxilary functions declarations * \****************************************************************************************/ @@ -581,6 +662,9 @@ CV_EXPORTS void randGaussMixture( InputArray means, InputArray covs, InputArray /* creates test set */ CV_EXPORTS void createConcentricSpheresTestSet( int nsamples, int nfeatures, int nclasses, OutputArray samples, OutputArray responses); +typedef CvLR_TrainParams LogisticRegression_TrainParams; +typedef CvLR LogisticRegression; + } }