added updated logistic regression prototype with newer C++ API

pull/3119/head
Rahul Kavi 11 years ago committed by Maksim Shabunin
parent 0e13f33193
commit d5ad4f3255
  1. 85
      modules/ml/include/opencv2/ml.hpp

@ -571,81 +571,66 @@ public:
/****************************************************************************************\ /****************************************************************************************\
* Logistic Regression * * Logistic Regression *
\****************************************************************************************/ \****************************************************************************************/
namespace cv
struct CV_EXPORTS_W_MAP CvLR_TrainParams {
struct CV_EXPORTS LogisticRegressionParams
{ {
CV_PROP_RW double alpha; double alpha;
CV_PROP_RW int num_iters; int num_iters;
CV_PROP_RW int norm; int norm;
/////////////////////////////////////////////////// int regularized;
// CV_PROP_RW int debug; int train_method;
/////////////////////////////////////////////////// int mini_batch_size;
CV_PROP_RW int regularized; CvTermCriteria term_crit;
CV_PROP_RW int train_method;
CV_PROP_RW int minibatchsize; LogisticRegressionParams();
LogisticRegressionParams(double alpha, int num_iters, int norm, int regularized, int train_method, int minbatchsize);
CV_PROP_RW CvTermCriteria term_crit;
CvLR_TrainParams();
///////////////////////////////////////////////////
// CvLR_TrainParams(double alpha, int num_iters, int norm, int debug, int regularized, int train_method, int minbatchsize);
///////////////////////////////////////////////////
CvLR_TrainParams(double alpha, int num_iters, int norm, int regularized, int train_method, int minbatchsize);
~CvLR_TrainParams();
}; };
class CV_EXPORTS_W CvLR : public CvStatModel class CV_EXPORTS LogisticRegression
{ {
public: public:
CvLR();
// CvLR(const CvLR_TrainParams& Params);
CvLR(const cv::Mat& data, const cv::Mat& labels, const CvLR_TrainParams& params);
virtual ~CvLR();
enum { REG_L1=0, REG_L2 = 1}; LogisticRegression();
enum { BATCH, MINI_BATCH}; LogisticRegression(cv::InputArray data_ip, cv::InputArray labels_ip, const LogisticRegressionParams& params);
virtual ~LogisticRegression();
enum { REG_L1 = 0, REG_L2 = 1};
enum { BATCH = 0, MINI_BATCH = 1};
virtual bool train(const cv::Mat& data, const cv::Mat& labels);//, const CvLR_TrainParams& params); virtual bool train(cv::InputArray data_ip, cv::InputArray label_ip);
virtual void predict( cv::InputArray data, cv::OutputArray predicted_labels ) const;
virtual float predict(const cv::Mat& data, cv::Mat& predicted_labels); virtual void save(std::string filepath) const;
virtual float predict(const cv::Mat& data); virtual void load(const std::string filepath);
virtual void write( CvFileStorage* storage, const char* name ) const; cv::Mat get_learnt_thetas() const;
virtual void read( CvFileStorage* storage, CvFileNode* node );
virtual void clear();
virtual cv::Mat get_learnt_mat();
protected: protected:
LogisticRegressionParams params;
cv::Mat learnt_thetas; cv::Mat learnt_thetas;
CvLR_TrainParams params; std::string default_model_name;
std::map<int, int> forward_mapper; std::map<int, int> forward_mapper;
std::map<int, int> reverse_mapper; std::map<int, int> reverse_mapper;
virtual bool set_default_params(); cv::Mat labels_o;
virtual cv::Mat calc_sigmoid(const cv::Mat& data); cv::Mat labels_n;
static cv::Mat calc_sigmoid(const cv::Mat& data);
virtual double compute_cost(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta); virtual double compute_cost(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
virtual cv::Mat compute_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta); virtual cv::Mat compute_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
virtual cv::Mat compute_mini_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta); virtual cv::Mat compute_mini_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
virtual std::map<int, int> get_label_map(const cv::Mat& labels);
virtual bool set_label_map(const cv::Mat& labels); virtual bool set_label_map(const cv::Mat& labels);
virtual cv::Mat remap_labels(const cv::Mat& labels, const std::map<int, int> lmap); static cv::Mat remap_labels(const cv::Mat& labels, const std::map<int, int>& lmap);
//cv::Mat Mapper; virtual void write(FileStorage& fs) const;
virtual void read(const FileNode& fn);
cv::Mat labels_o; virtual void clear();
cv::Mat labels_n;
}; };
}// namespace cv
/****************************************************************************************\ /****************************************************************************************\
* Auxilary functions declarations * * Auxilary functions declarations *

Loading…
Cancel
Save