diff --git a/include/opencv2/opencv.hpp b/include/opencv2/opencv.hpp index 49b6a6691f..e411621d0c 100644 --- a/include/opencv2/opencv.hpp +++ b/include/opencv2/opencv.hpp @@ -75,6 +75,7 @@ #endif #ifdef HAVE_OPENCV_ML #include "opencv2/ml.hpp" +#include "opencv2/ml/svmsgd.hpp" #endif #endif diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index d5debdbf18..791f580093 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -1513,126 +1513,6 @@ CV_EXPORTS void randMVNormal( InputArray mean, InputArray cov, int nsamples, Out CV_EXPORTS void createConcentricSpheresTestSet( int nsamples, int nfeatures, int nclasses, OutputArray samples, OutputArray responses); -/****************************************************************************************\ -* Stochastic Gradient Descent SVM Classifier * -\****************************************************************************************/ - -/*! -@brief Stochastic Gradient Descent SVM classifier - -SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large. -The gradient descent show amazing performance for large-scale problems, reducing the computing time. This allows a fast and reliable online update of the classifier for each new feature which -is fundamental when dealing with variations of data over time (like weather and illumination changes in videosurveillance, for example). - -First, create the SVMSGD object. To enable the online update, a value for updateFrequency should be defined. - -Then the SVM model can be trained using the train features and the correspondent labels. - -After that, the label of a new feature vector can be predicted using the predict function. If the updateFrequency was defined in the constructor, the predict function will update the weights automatically. - -@code -// Initialize object -SVMSGD SvmSgd; - -// Train the Stochastic Gradient Descent SVM -SvmSgd.train(trainFeatures, labels); - -// Predict label for the new feature vector (1xM) -predictedLabel = SvmSgd.predict(newFeatureVector); -@endcode - -*/ -class CV_EXPORTS_W SVMSGD { - - public: - /** @brief SGDSVM constructor. - - @param lambda regularization - @param learnRate learning rate - @param nIterations number of training iterations - - */ - SVMSGD(float lambda = 0.000001, float learnRate = 2, uint nIterations = 100000); - - /** @brief SGDSVM constructor. - - @param updateFrequency online update frequency - @param learnRateDecay learn rate decay over time: learnRate = learnRate * learnDecay - @param lambda regularization - @param learnRate learning rate - @param nIterations number of training iterations - - */ - SVMSGD(uint updateFrequency, float learnRateDecay = 1, float lambda = 0.000001, float learnRate = 2, uint nIterations = 100000); - virtual ~SVMSGD(); - virtual SVMSGD* clone() const; - - /** @brief Train the SGDSVM classifier. - - The function trains the SGDSVM classifier using the train features and the correspondent labels (-1 or 1). - - @param trainFeatures features used for training. Each row is a new sample. - @param labels mat (size Nx1 with N = number of features) with the label of each training feature. - - */ - virtual void train(cv::Mat trainFeatures, cv::Mat labels); - - /** @brief Predict the label of a new feature vector. - - The function predicts and returns the label of a new feature vector, using the previously trained SVM model. - - @param newFeature new feature vector used for prediction - - */ - virtual float predict(cv::Mat newFeature); - - /** @brief Returns the weights of the trained model. - - */ - virtual std::vector getWeights(){ return _weights; }; - - /** @brief Sets the weights of the trained model. - - @param weights weights used to predict the label of a new feature vector. - - */ - virtual void setWeights(std::vector weights){ _weights = weights; }; - - private: - void updateWeights(); - void generateRandomIndex(); - float calcInnerProduct(float *rowDataPointer); - void updateWeights(float innerProduct, float *rowDataPointer, int label); - - // Vector with SVM weights - std::vector _weights; - - // Random index generation - long long int _randomNumber; - unsigned int _randomIndex; - - // Number of features and samples - unsigned int _nFeatures; - unsigned int _nTrainSamples; - - // Parameters for learning - float _lambda; //regularization - float _learnRate; //learning rate - unsigned int _nIterations; //number of training iterations - - // Vars to control the features slider matrix - bool _onlineUpdate; - bool _initPredict; - uint _slidingWindowSize; - uint _predictSlidingWindowSize; - float* _labelSlider; - float _learnRateDecay; - - // Mat with features slider and correspondent counter - unsigned int _sliderCounter; - cv::Mat _featuresSlider; - -}; //! @} ml diff --git a/modules/ml/include/opencv2/ml/svmsgd.hpp b/modules/ml/include/opencv2/ml/svmsgd.hpp new file mode 100644 index 0000000000..f61a905963 --- /dev/null +++ b/modules/ml/include/opencv2/ml/svmsgd.hpp @@ -0,0 +1,134 @@ +/*M/////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// License Agreement +// For Open Source Computer Vision Library +// +// Copyright (C) 2000, Intel Corporation, all rights reserved. +// Copyright (C) 2013, OpenCV Foundation, all rights reserved. +// Copyright (C) 2014, Itseez Inc, all rights reserved. +// Third party copyrights are property of their respective owners. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistribution's of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistribution's in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * The name of the copyright holders may not be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// This software is provided by the copyright holders and contributors "as is" and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are disclaimed. +// In no event shall the Intel Corporation or contributors be liable for any direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. +// +//M*/ + +#ifndef __OPENCV_ML_SVMSGD_HPP__ +#define __OPENCV_ML_SVMSGD_HPP__ + +#ifdef __cplusplus + +#include "opencv2/ml.hpp" + +namespace cv +{ +namespace ml +{ + + +/****************************************************************************************\ +* Stochastic Gradient Descent SVM Classifier * +\****************************************************************************************/ + +/*! +@brief Stochastic Gradient Descent SVM classifier + +SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large. +The gradient descent show amazing performance for large-scale problems, reducing the computing time. This allows a fast and reliable online update of the classifier for each new feature which +is fundamental when dealing with variations of data over time (like weather and illumination changes in videosurveillance, for example). + +First, create the SVMSGD object. To enable the online update, a value for updateFrequency should be defined. + +Then the SVM model can be trained using the train features and the correspondent labels. + +After that, the label of a new feature vector can be predicted using the predict function. If the updateFrequency was defined in the constructor, the predict function will update the weights automatically. + +@code +// Initialize object +SVMSGD SvmSgd; + +// Train the Stochastic Gradient Descent SVM +SvmSgd.train(trainFeatures, labels); + +// Predict label for the new feature vector (1xM) +predictedLabel = SvmSgd.predict(newFeatureVector); +@endcode + +*/ + +class CV_EXPORTS_W SVMSGD : public cv::ml::StatModel +{ +public: + + enum SvmsgdType + { + ILLEGAL_VALUE, + SGD, //Stochastic Gradient Descent + ASGD //Average Stochastic Gradient Descent + }; + + /** + * @return the weights of the trained model. + */ + CV_WRAP virtual Mat getWeights() = 0; + + CV_WRAP virtual float getShift() = 0; + + CV_WRAP static Ptr create(); + + CV_WRAP virtual void setOptimalParameters(int type = ASGD) = 0; + + CV_WRAP virtual int getType() const = 0; + + CV_WRAP virtual void setType(int type) = 0; + + CV_WRAP virtual float getLambda() const = 0; + + CV_WRAP virtual void setLambda(float lambda) = 0; + + CV_WRAP virtual float getGamma0() const = 0; + + CV_WRAP virtual void setGamma0(float gamma0) = 0; + + CV_WRAP virtual float getC() const = 0; + + CV_WRAP virtual void setC(float c) = 0; + + CV_WRAP virtual cv::TermCriteria getTermCriteria() const = 0; + + CV_WRAP virtual void setTermCriteria(const cv::TermCriteria &val) = 0; +}; + +} //ml +} //cv + +#endif // __clpusplus +#endif // __OPENCV_ML_SVMSGD_HPP diff --git a/modules/ml/src/precomp.hpp b/modules/ml/src/precomp.hpp index 84821988b6..9318e4a78c 100644 --- a/modules/ml/src/precomp.hpp +++ b/modules/ml/src/precomp.hpp @@ -45,7 +45,7 @@ #include "opencv2/ml.hpp" #include "opencv2/core/core_c.h" #include "opencv2/core/utility.hpp" - +#include "opencv2/ml/svmsgd.hpp" #include "opencv2/core/private.hpp" #include diff --git a/modules/ml/src/svmsgd.cpp b/modules/ml/src/svmsgd.cpp index 3114e43d9f..91377cfc4f 100644 --- a/modules/ml/src/svmsgd.cpp +++ b/modules/ml/src/svmsgd.cpp @@ -41,161 +41,430 @@ //M*/ #include "precomp.hpp" +#include "limits" /****************************************************************************************\ * Stochastic Gradient Descent SVM Classifier * \****************************************************************************************/ -namespace cv { -namespace ml { +namespace cv +{ +namespace ml +{ -SVMSGD::SVMSGD(float lambda, float learnRate, uint nIterations){ +class SVMSGDImpl : public SVMSGD +{ - // Initialize with random seed - _randomNumber = 1; +public: + SVMSGDImpl(); - // Initialize constants - _slidingWindowSize = 0; - _nFeatures = 0; - _predictSlidingWindowSize = 1; + virtual ~SVMSGDImpl() {} - // Initialize sliderCounter at index 0 - _sliderCounter = 0; + virtual bool train(const Ptr& data, int); - // Parameters for learning - _lambda = lambda; // regularization - _learnRate = learnRate; // learning rate (ideally should be large at beginning and decay each iteration) - _nIterations = nIterations; // number of training iterations + virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const; - // True only in the first predict iteration - _initPredict = true; + virtual bool isClassifier() const { return params.svmsgdType == SGD || params.svmsgdType == ASGD; } - // Online update flag - _onlineUpdate = false; -} + virtual bool isTrained() const; -SVMSGD::SVMSGD(uint updateFrequency, float learnRateDecay, float lambda, float learnRate, uint nIterations){ + virtual void clear(); - // Initialize with random seed - _randomNumber = 1; + virtual void write(FileStorage& fs) const; - // Initialize constants - _slidingWindowSize = 0; - _nFeatures = 0; - _predictSlidingWindowSize = updateFrequency; + virtual void read(const FileNode& fn); - // Initialize sliderCounter at index 0 - _sliderCounter = 0; + virtual Mat getWeights(){ return weights_; } - // Parameters for learning - _lambda = lambda; // regularization - _learnRate = learnRate; // learning rate (ideally should be large at beginning and decay each iteration) - _nIterations = nIterations; // number of training iterations + virtual float getShift(){ return shift_; } - // True only in the first predict iteration - _initPredict = true; + virtual int getVarCount() const { return weights_.cols; } - // Online update flag - _onlineUpdate = true; + virtual String getDefaultName() const {return "opencv_ml_svmsgd";} - // Learn rate decay: _learnRate = _learnRate * _learnDecay - _learnRateDecay = learnRateDecay; -} + virtual void setOptimalParameters(int type = ASGD); -SVMSGD::~SVMSGD(){ + virtual int getType() const; -} + virtual void setType(int type); + + CV_IMPL_PROPERTY(float, Lambda, params.lambda) + CV_IMPL_PROPERTY(float, Gamma0, params.gamma0) + CV_IMPL_PROPERTY(float, C, params.c) + CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit) -SVMSGD* SVMSGD::clone() const{ - return new SVMSGD(*this); + private: + void updateWeights(InputArray sample, bool is_first_class, float gamma); + float calcShift(InputArray trainSamples, InputArray trainResponses) const; + std::pair areClassesEmpty(Mat responses); + void writeParams( FileStorage& fs ) const; + void readParams( const FileNode& fn ); + static inline bool isFirstClass(float val) { return val > 0; } + + + // Vector with SVM weights + Mat weights_; + float shift_; + + // Random index generation + RNG rng_; + + // Parameters for learning + struct SVMSGDParams + { + float lambda; //regularization + float gamma0; //learning rate + float c; + TermCriteria termCrit; + SvmsgdType svmsgdType; + }; + + SVMSGDParams params; +}; + +Ptr SVMSGD::create() +{ + return makePtr(); } -void SVMSGD::train(cv::Mat trainFeatures, cv::Mat labels){ - // Initialize _nFeatures - _slidingWindowSize = trainFeatures.rows; - _nFeatures = trainFeatures.cols; +bool SVMSGDImpl::train(const Ptr& data, int) +{ + clear(); + + Mat trainSamples = data->getTrainSamples(); + + // Initialize varCount + int trainSamplesCount_ = trainSamples.rows; + int varCount = trainSamples.cols; - float innerProduct; // Initialize weights vector with zeros - if (_weights.size()==0){ - _weights.reserve(_nFeatures); - for (uint feat = 0; feat < _nFeatures; ++feat){ - _weights.push_back(0.0); - } + weights_ = Mat::zeros(1, varCount, CV_32F); + + Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix + + std::pair are_empty = areClassesEmpty(trainResponses); + + if ( are_empty.first && are_empty.second ) + { + weights_.release(); + return false; + } + if ( are_empty.first || are_empty.second ) + { + shift_ = are_empty.first ? -1 : 1; + return true; + } + + + Mat currentSample; + float gamma = 0; + Mat lastWeights = Mat::zeros(1, varCount, CV_32F); //weights vector for calculating terminal criterion + Mat averageWeights; //average weights vector for ASGD model + double err = DBL_MAX; + if (params.svmsgdType == ASGD) + { + averageWeights = Mat::zeros(1, varCount, CV_32F); } // Stochastic gradient descent SVM - for (uint iter = 0; iter < _nIterations; ++iter){ - generateRandomIndex(); - innerProduct = calcInnerProduct(trainFeatures.ptr(_randomIndex)); - int label = (labels.at(_randomIndex,0) > 0) ? 1 : -1; // ensure that labels are -1 or 1 - updateWeights(innerProduct, trainFeatures.ptr(_randomIndex), label ); + for (int iter = 0; (iter < params.termCrit.maxCount)&&(err > params.termCrit.epsilon); iter++) + { + //generate sample number + int randomNumber = rng_.uniform(0, trainSamplesCount_); + + currentSample = trainSamples.row(randomNumber); + + //update gamma + gamma = params.gamma0 * std::pow((1 + params.lambda * params.gamma0 * (float)iter), (-params.c)); + + bool is_first_class = isFirstClass(trainResponses.at(randomNumber)); + updateWeights( currentSample, is_first_class, gamma ); + + //average weights (only for ASGD model) + if (params.svmsgdType == ASGD) + { + averageWeights = ((float)iter/ (1 + (float)iter)) * averageWeights + weights_ / (1 + (float) iter); + } + + err = norm(weights_ - lastWeights); + weights_.copyTo(lastWeights); + } + + if (params.svmsgdType == ASGD) + { + weights_ = averageWeights; } + + shift_ = calcShift(trainSamples, trainResponses); + + return true; } -float SVMSGD::predict(cv::Mat newFeature){ - float innerProduct; +std::pair SVMSGDImpl::areClassesEmpty(Mat responses) +{ + std::pair are_classes_empty(true, true); + int limit_index = responses.rows; + + for(int index = 0; index < limit_index; index++) + { + if (isFirstClass(responses.at(index,0))) + are_classes_empty.first = false; + else + are_classes_empty.second = false; - if (_initPredict){ - _nFeatures = newFeature.cols; - _slidingWindowSize = _predictSlidingWindowSize; - _featuresSlider = cv::Mat::zeros(_slidingWindowSize, _nFeatures, CV_32F); - _initPredict = false; - _labelSlider = new float[_predictSlidingWindowSize](); - _learnRate = _learnRate * _learnRateDecay; + if (!are_classes_empty.first && ! are_classes_empty.second) + break; } - innerProduct = calcInnerProduct(newFeature.ptr(0)); + return are_classes_empty; +} - // Resultant label (-1 or 1) - int label = (innerProduct>=0) ? 1 : -1; +float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const +{ + float distance_to_classes[2] = { std::numeric_limits::max(), std::numeric_limits::max() }; - if (_onlineUpdate){ - // Update the featuresSlider with newFeature and _labelSlider with label - newFeature.row(0).copyTo(_featuresSlider.row(_sliderCounter)); - _labelSlider[_sliderCounter] = float(label); + Mat trainSamples = _samples.getMat(); + int trainSamplesCount = trainSamples.rows; - // Update weights with a random index - if (_sliderCounter == _slidingWindowSize-1){ - generateRandomIndex(); - updateWeights(innerProduct, _featuresSlider.ptr(_randomIndex), int(_labelSlider[_randomIndex]) ); - } + Mat trainResponses = _responses.getMat(); + + for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++) + { + Mat currentSample = trainSamples.row(samplesIndex); + float scalar_product = currentSample.dot(weights_); - // _sliderCounter++ if < _slidingWindowSize - _sliderCounter = (_sliderCounter == _slidingWindowSize-1) ? 0 : (_sliderCounter+1); + bool is_first_class = isFirstClass(trainResponses.at(samplesIndex)); + int index = is_first_class ? 0:1; + float sign_to_mul = is_first_class ? 1 : -1; + float cur_distance = scalar_product * sign_to_mul ; + + if (cur_distance < distance_to_classes[index]) + { + distance_to_classes[index] = cur_distance; + } } - return float(label); + //todo: areClassesEmpty(); make const; + return -(distance_to_classes[0] - distance_to_classes[1]) / 2.f; } -void SVMSGD::generateRandomIndex(){ - // Choose random sample, using Mikolov's fast almost-uniform random number - _randomNumber = _randomNumber * (unsigned long long) 25214903917 + 11; - _randomIndex = uint(_randomNumber % (unsigned long long) _slidingWindowSize); -} +float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const +{ + float result = 0; + cv::Mat samples = _samples.getMat(); + int nSamples = samples.rows; + cv::Mat results; -float SVMSGD::calcInnerProduct(float *rowDataPointer){ - float innerProduct = 0; - for (uint feat = 0; feat < _nFeatures; ++feat){ - innerProduct += _weights[feat] * rowDataPointer[feat]; + CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32F ); + + if( _results.needed() ) + { + _results.create( nSamples, 1, samples.type() ); + results = _results.getMat(); + } + else + { + CV_Assert( nSamples == 1 ); + results = Mat(1, 1, CV_32F, &result); } - return innerProduct; + + Mat currentSample; + float criterion = 0; + + for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++) + { + currentSample = samples.row(sampleIndex); + criterion = currentSample.dot(weights_) + shift_; + results.at(sampleIndex) = (criterion >= 0) ? 1 : -1; + } + + return result; } -void SVMSGD::updateWeights(float innerProduct, float *rowDataPointer, int label){ - if (label * innerProduct > 1) { +void SVMSGDImpl::updateWeights(InputArray _sample, bool is_first_class, float gamma) +{ + Mat sample = _sample.getMat(); + + int responce = is_first_class ? 1 : -1; // ensure that trainResponses are -1 or 1 + + if ( sample.dot(weights_) * responce > 1) + { // Not a support vector, only apply weight decay - for (uint feat = 0; feat < _nFeatures; feat++) { - _weights[feat] -= _learnRate * _lambda * _weights[feat]; - } - } else { + weights_ *= (1.f - gamma * params.lambda); + } + else + { // It's a support vector, add it to the weights - for (uint feat = 0; feat < _nFeatures; feat++) { - _weights[feat] -= _learnRate * (_lambda * _weights[feat] - label * rowDataPointer[feat]); - } + weights_ -= (gamma * params.lambda) * weights_ - gamma * responce * sample; + //std::cout << "sample " << sample << std::endl; + //std::cout << "weights_ " << weights_ << std::endl; + } +} + +bool SVMSGDImpl::isTrained() const +{ + return !weights_.empty(); +} + +void SVMSGDImpl::write(FileStorage& fs) const +{ + if( !isTrained() ) + CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" ); + + writeParams( fs ); + + fs << "shift" << shift_; + fs << "weights" << weights_; +} + +void SVMSGDImpl::writeParams( FileStorage& fs ) const +{ + String SvmsgdTypeStr; + + switch (params.svmsgdType) + { + case SGD: + SvmsgdTypeStr = "SGD"; + break; + case ASGD: + SvmsgdTypeStr = "ASGD"; + break; + case ILLEGAL_VALUE: + SvmsgdTypeStr = format("Uknown_%d", params.svmsgdType); + default: + std::cout << "params.svmsgdType isn't initialized" << std::endl; + } + + + fs << "svmsgdType" << SvmsgdTypeStr; + + fs << "lambda" << params.lambda; + fs << "gamma0" << params.gamma0; + fs << "c" << params.c; + + fs << "term_criteria" << "{:"; + if( params.termCrit.type & TermCriteria::EPS ) + fs << "epsilon" << params.termCrit.epsilon; + if( params.termCrit.type & TermCriteria::COUNT ) + fs << "iterations" << params.termCrit.maxCount; + fs << "}"; +} + + + +void SVMSGDImpl::read(const FileNode& fn) +{ + clear(); + + readParams(fn); + + shift_ = (float) fn["shift"]; + fn["weights"] >> weights_; +} + +void SVMSGDImpl::readParams( const FileNode& fn ) +{ + String svmsgdTypeStr = (String)fn["svmsgdType"]; + SvmsgdType svmsgdType = + svmsgdTypeStr == "SGD" ? SGD : + svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_VALUE; + + if( svmsgdType == ILLEGAL_VALUE ) + CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" ); + + params.svmsgdType = svmsgdType; + + CV_Assert ( fn["lambda"].isReal() ); + params.lambda = (float)fn["lambda"]; + + CV_Assert ( fn["gamma0"].isReal() ); + params.gamma0 = (float)fn["gamma0"]; + + CV_Assert ( fn["c"].isReal() ); + params.c = (float)fn["c"]; + + FileNode tcnode = fn["term_criteria"]; + if( !tcnode.empty() ) + { + params.termCrit.epsilon = (double)tcnode["epsilon"]; + params.termCrit.maxCount = (int)tcnode["iterations"]; + params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) + + (params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0); } + else + params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 1000, FLT_EPSILON ); + } +void SVMSGDImpl::clear() +{ + weights_.release(); + shift_ = 0; } + + +SVMSGDImpl::SVMSGDImpl() +{ + clear(); + rng_(0); + + params.svmsgdType = ILLEGAL_VALUE; + + // Parameters for learning + params.lambda = 0; // regularization + params.gamma0 = 0; // learning rate (ideally should be large at beginning and decay each iteration) + params.c = 0; + + TermCriteria _termCrit(TermCriteria::COUNT + TermCriteria::EPS, 0, 0); + params.termCrit = _termCrit; +} + +void SVMSGDImpl::setOptimalParameters(int type) +{ + switch (type) + { + case SGD: + params.svmsgdType = SGD; + params.lambda = 0.00001; + params.gamma0 = 0.05; + params.c = 1; + params.termCrit.maxCount = 50000; + params.termCrit.epsilon = 0.00000001; + break; + + case ASGD: + params.svmsgdType = ASGD; + params.lambda = 0.00001; + params.gamma0 = 0.5; + params.c = 0.75; + params.termCrit.maxCount = 100000; + params.termCrit.epsilon = 0.000001; + break; + + default: + CV_Error( CV_StsParseError, "SVMSGD model data is invalid" ); + } +} + +void SVMSGDImpl::setType(int type) +{ + switch (type) + { + case SGD: + params.svmsgdType = SGD; + break; + case ASGD: + params.svmsgdType = ASGD; + break; + default: + params.svmsgdType = ILLEGAL_VALUE; + } +} + +int SVMSGDImpl::getType() const +{ + return params.svmsgdType; } +} //ml +} //cv diff --git a/modules/ml/test/test_mltests2.cpp b/modules/ml/test/test_mltests2.cpp index 919fae6ce4..6603a35c5b 100644 --- a/modules/ml/test/test_mltests2.cpp +++ b/modules/ml/test/test_mltests2.cpp @@ -193,6 +193,16 @@ int str_to_boost_type( String& str ) // 8. rtrees // 9. ertrees +int str_to_svmsgd_type( String& str ) +{ + if ( !str.compare("SGD") ) + return SVMSGD::SGD; + if ( !str.compare("ASGD") ) + return SVMSGD::ASGD; + CV_Error( CV_StsBadArg, "incorrect boost type string" ); + return -1; +} + // ---------------------------------- MLBaseTest --------------------------------------------------- CV_MLBaseTest::CV_MLBaseTest(const char* _modelName) @@ -248,7 +258,9 @@ void CV_MLBaseTest::run( int ) { string filename = ts->get_data_path(); filename += get_validation_filename(); + validationFS.open( filename, FileStorage::READ ); + read_params( *validationFS ); int code = cvtest::TS::OK; @@ -436,6 +448,21 @@ int CV_MLBaseTest::train( int testCaseIdx ) model = m; } + else if( modelName == CV_SVMSGD ) + { + String svmsgdTypeStr; + modelParamsNode["svmsgdType"] >> svmsgdTypeStr; + Ptr m = SVMSGD::create(); + int type = str_to_svmsgd_type( svmsgdTypeStr ); + m->setType(type); + //m->setType(str_to_svmsgd_type( svmsgdTypeStr )); + m->setLambda(modelParamsNode["lambda"]); + m->setGamma0(modelParamsNode["gamma0"]); + m->setC(modelParamsNode["c"]); + m->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.00001)); + model = m; + } + if( !model.empty() ) is_trained = model->train(data, 0); @@ -457,7 +484,7 @@ float CV_MLBaseTest::get_test_error( int /*testCaseIdx*/, vector *resp ) else if( modelName == CV_ANN ) err = ann_calc_error( model, data, cls_map, type, resp ); else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES || - modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST ) + modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST || modelName == CV_SVMSGD ) err = model->calcError( data, true, _resp ); if( !_resp.empty() && resp ) _resp.convertTo(*resp, CV_32F); @@ -485,6 +512,8 @@ void CV_MLBaseTest::load( const char* filename ) model = Algorithm::load( filename ); else if( modelName == CV_RTREES ) model = Algorithm::load( filename ); + else if( modelName == CV_SVMSGD ) + model = Algorithm::load( filename ); else CV_Error( CV_StsNotImplemented, "invalid stat model name"); } diff --git a/modules/ml/test/test_precomp.hpp b/modules/ml/test/test_precomp.hpp index 329b9bd6c0..18cee968fb 100644 --- a/modules/ml/test/test_precomp.hpp +++ b/modules/ml/test/test_precomp.hpp @@ -13,6 +13,7 @@ #include #include "opencv2/ts.hpp" #include "opencv2/ml.hpp" +#include "opencv2/ml/svmsgd.hpp" #include "opencv2/core/core_c.h" #define CV_NBAYES "nbayes" @@ -24,6 +25,7 @@ #define CV_BOOST "boost" #define CV_RTREES "rtrees" #define CV_ERTREES "ertrees" +#define CV_SVMSGD "svmsgd" enum { CV_TRAIN_ERROR=0, CV_TEST_ERROR=1 }; @@ -38,6 +40,7 @@ using cv::ml::ANN_MLP; using cv::ml::DTrees; using cv::ml::Boost; using cv::ml::RTrees; +using cv::ml::SVMSGD; class CV_MLBaseTest : public cvtest::BaseTest { diff --git a/modules/ml/test/test_save_load.cpp b/modules/ml/test/test_save_load.cpp index 2d6f144bb9..354c6e0307 100644 --- a/modules/ml/test/test_save_load.cpp +++ b/modules/ml/test/test_save_load.cpp @@ -150,12 +150,20 @@ int CV_SLMLTest::validate_test_results( int testCaseIdx ) TEST(ML_NaiveBayes, save_load) { CV_SLMLTest test( CV_NBAYES ); test.safe_run(); } TEST(ML_KNearest, save_load) { CV_SLMLTest test( CV_KNEAREST ); test.safe_run(); } -TEST(ML_SVM, save_load) { CV_SLMLTest test( CV_SVM ); test.safe_run(); } +TEST(ML_SVM, save_load) +{ + CV_SLMLTest test( CV_SVM ); + test.safe_run(); +} TEST(ML_ANN, save_load) { CV_SLMLTest test( CV_ANN ); test.safe_run(); } TEST(ML_DTree, save_load) { CV_SLMLTest test( CV_DTREE ); test.safe_run(); } TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); } TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); } TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); } +TEST(MV_SVMSGD, save_load){ + CV_SLMLTest test( CV_SVMSGD ); + test.safe_run(); +} class CV_LegacyTest : public cvtest::BaseTest { @@ -201,6 +209,8 @@ protected: model = Algorithm::load(filename); else if (modelName == CV_RTREES) model = Algorithm::load(filename); + else if (modelName == CV_SVMSGD) + model = Algorithm::load(filename); if (!model) { code = cvtest::TS::FAIL_INVALID_TEST_DATA; @@ -260,6 +270,11 @@ TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushro TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); } TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); } TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); } +TEST(ML_SVMSGD, legacy_load) +{ + CV_LegacyTest test(CV_SVMSGD, "_waveform.xml"); + test.safe_run(); +} /*TEST(ML_SVM, throw_exception_when_save_untrained_model) { diff --git a/modules/ml/test/test_svmsgd.cpp b/modules/ml/test/test_svmsgd.cpp new file mode 100644 index 0000000000..9f4aafc08b --- /dev/null +++ b/modules/ml/test/test_svmsgd.cpp @@ -0,0 +1,182 @@ +/*M/////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// Intel License Agreement +// For Open Source Computer Vision Library +// +// Copyright (C) 2000, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistribution's of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistribution's in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * The name of Intel Corporation may not be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// This software is provided by the copyright holders and contributors "as is" and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are disclaimed. +// In no event shall the Intel Corporation or contributors be liable for any direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. +// +//M*/ + +#include "test_precomp.hpp" +#include "opencv2/highgui.hpp" + +using namespace cv; +using namespace cv::ml; +using cv::ml::SVMSGD; +using cv::ml::TrainData; + + + +class CV_SVMSGDTrainTest : public cvtest::BaseTest +{ +public: + CV_SVMSGDTrainTest(Mat _weights, float _shift); +private: + virtual void run( int start_from ); + float decisionFunction(Mat sample, Mat weights, float shift); + + cv::Ptr data; + cv::Mat testSamples; + cv::Mat testResponses; + static const int TEST_VALUE_LIMIT = 50; +}; + +CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(Mat weights, float shift) +{ + int datasize = 100000; + int varCount = weights.cols; + cv::Mat samples = cv::Mat::zeros( datasize, varCount, CV_32FC1 ); + cv::Mat responses = cv::Mat::zeros( datasize, 1, CV_32FC1 ); + cv::RNG rng(0); + + float lowerLimit = -TEST_VALUE_LIMIT; + float upperLimit = TEST_VALUE_LIMIT; + + + rng.fill(samples, RNG::UNIFORM, lowerLimit, upperLimit); + for (int sampleIndex = 0; sampleIndex < datasize; sampleIndex++) + { + responses.at( sampleIndex ) = decisionFunction(samples.row(sampleIndex), weights, shift) > 0 ? 1 : -1; + } + + data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses ); + + int testSamplesCount = 100000; + + testSamples.create(testSamplesCount, varCount, CV_32FC1); + rng.fill(testSamples, RNG::UNIFORM, lowerLimit, upperLimit); + testResponses.create(testSamplesCount, 1, CV_32FC1); + + for (int i = 0 ; i < testSamplesCount; i++) + { + testResponses.at(i) = decisionFunction(testSamples.row(i), weights, shift) > 0 ? 1 : -1; + } +} + +void CV_SVMSGDTrainTest::run( int /*start_from*/ ) +{ + cv::Ptr svmsgd = SVMSGD::create(); + + svmsgd->setOptimalParameters(SVMSGD::ASGD); + + svmsgd->train( data ); + + Mat responses; + + svmsgd->predict(testSamples, responses); + + int errCount = 0; + int testSamplesCount = testSamples.rows; + + for (int i = 0; i < testSamplesCount; i++) + { + if (responses.at(i) * testResponses.at(i) < 0 ) + errCount++; + } + + float err = (float)errCount / testSamplesCount; + std::cout << "err " << err << std::endl; + + if ( err > 0.01 ) + { + ts->set_failed_test_info( cvtest::TS::FAIL_BAD_ACCURACY ); + } +} + +float CV_SVMSGDTrainTest::decisionFunction(Mat sample, Mat weights, float shift) +{ + return sample.dot(weights) + shift; +} + +TEST(ML_SVMSGD, train0) +{ + int varCount = 2; + + Mat weights; + weights.create(1, varCount, CV_32FC1); + weights.at(0) = 1; + weights.at(1) = 0; + + float shift = 5; + + CV_SVMSGDTrainTest test(weights, shift); + test.safe_run(); +} + +TEST(ML_SVMSGD, train1) +{ + int varCount = 5; + + Mat weights; + weights.create(1, varCount, CV_32FC1); + + float lowerLimit = -1; + float upperLimit = 1; + cv::RNG rng(0); + rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit); + + float shift = rng.uniform(-5.f, 5.f); + + CV_SVMSGDTrainTest test(weights, shift); + test.safe_run(); +} + +TEST(ML_SVMSGD, train2) +{ + int varCount = 100; + + Mat weights; + weights.create(1, varCount, CV_32FC1); + + float lowerLimit = -1; + float upperLimit = 1; + cv::RNG rng(0); + rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit); + + float shift = rng.uniform(-1000.f, 1000.f); + + CV_SVMSGDTrainTest test(weights, shift); + test.safe_run(); +} diff --git a/modules/ts/src/ts_gtest.cpp b/modules/ts/src/ts_gtest.cpp index 29a3996be8..5604eb7e62 100644 --- a/modules/ts/src/ts_gtest.cpp +++ b/modules/ts/src/ts_gtest.cpp @@ -5659,7 +5659,7 @@ class TestCaseNameIs { // Returns true iff the name of test_case matches name_. bool operator()(const TestCase* test_case) const { - return test_case != NULL && strcmp(test_case->name(), name_.c_str()) == 0; + return test_case != NULL && strcmp(test_case->name(), name_.c_str()) == 0; } private: diff --git a/samples/cpp/train_svmsgd.cpp b/samples/cpp/train_svmsgd.cpp new file mode 100644 index 0000000000..cee821739f --- /dev/null +++ b/samples/cpp/train_svmsgd.cpp @@ -0,0 +1,226 @@ +#include +#include "opencv2/video/tracking.hpp" +#include "opencv2/imgproc/imgproc.hpp" +#include "opencv2/highgui/highgui.hpp" + +using namespace cv; +using namespace cv::ml; + +#define WIDTH 841 +#define HEIGHT 594 + +struct Data +{ + Mat img; + Mat samples; + Mat responses; + RNG rng; + //Point points[2]; + + Data() + { + img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3); + imshow("Train svmsgd", img); + } +}; + +bool doTrain(const Mat samples,const Mat responses, Mat &weights, float &shift); +bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2]); +bool findCrossPoint(const Mat &weights, float shift, const std::pair &segment, Point &crossPoint); +void fillSegments(std::vector > &segments); +void redraw(Data data, const Point points[2]); +void addPointsRetrainAndRedraw(Data &data, int x, int y); + + +bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift) +{ + cv::Ptr svmsgd = SVMSGD::create(); + svmsgd->setOptimalParameters(SVMSGD::ASGD); + svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 50000, 0.0000001)); + svmsgd->setLambda(0.01); + svmsgd->setGamma0(1); + // svmsgd->setC(5); + + cv::Ptr train_data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses ); + svmsgd->train( train_data ); + + if (svmsgd->isTrained()) + { + weights = svmsgd->getWeights(); + shift = svmsgd->getShift(); + + std::cout << weights << std::endl; + std::cout << shift << std::endl; + + return true; + } + return false; +} + + +bool findCrossPoint(const Mat &weights, float shift, const std::pair &segment, Point &crossPoint) +{ + int x = 0; + int y = 0; + //с (0,0) всё плохо + if (segment.first.x == segment.second.x && weights.at(1) != 0) + { + x = segment.first.x; + y = -(weights.at(0) * x + shift) / weights.at(1); + if (y >= 0 && y <= HEIGHT) + { + crossPoint.x = x; + crossPoint.y = y; + return true; + } + } + else if (segment.first.y == segment.second.y && weights.at(0) != 0) + { + y = segment.first.y; + x = - (weights.at(1) * y + shift) / weights.at(0); + if (x >= 0 && x <= WIDTH) + { + crossPoint.x = x; + crossPoint.y = y; + return true; + } + } + return false; +} + +bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2]) +{ + if (weights.empty()) + { + return false; + } + + int foundPointsCount = 0; + std::vector > segments; + fillSegments(segments); + + for (int i = 0; i < 4; i++) + { + if (findCrossPoint(weights, shift, segments[i], points[foundPointsCount])) + foundPointsCount++; + if (foundPointsCount > 2) + break; + } + return true; +} + +void fillSegments(std::vector > &segments) +{ + std::pair curSegment; + + curSegment.first = Point(0,0); + curSegment.second = Point(0,HEIGHT); + segments.push_back(curSegment); + + curSegment.first = Point(0,0); + curSegment.second = Point(WIDTH,0); + segments.push_back(curSegment); + + curSegment.first = Point(WIDTH,0); + curSegment.second = Point(WIDTH,HEIGHT); + segments.push_back(curSegment); + + curSegment.first = Point(0,HEIGHT); + curSegment.second = Point(WIDTH,HEIGHT); + segments.push_back(curSegment); +} + +void redraw(Data data, const Point points[2]) +{ + data.img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3); + Point center; + int radius = 3; + Scalar color; + for (int i = 0; i < data.samples.rows; i++) + { + center.x = data.samples.at(i,0); + center.y = data.samples.at(i,1); + color = (data.responses.at(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128); + circle(data.img, center, radius, color, 5); + } + line(data.img, points[0],points[1],cv::Scalar(1,255,1)); + + imshow("Train svmsgd", data.img); +} + +void addPointsRetrainAndRedraw(Data &data, int x, int y) +{ + + Mat currentSample(1, 2, CV_32F); + //start +/* + Mat _weights; + _weights.create(1, 2, CV_32FC1); + _weights.at(0) = 1; + _weights.at(1) = -1; + + int _x, _y; + + for (int i=0;i<199;i++) + { + _x = data.rng.uniform(0,800); + _y = data.rng.uniform(0,500);*/ + currentSample.at(0,0) = x; + currentSample.at(0,1) = y; + //if (currentSample.dot(_weights) > 0) + //data.responses.push_back(1); + // else data.responses.push_back(-1); + + //finish + data.samples.push_back(currentSample); + + + + Mat weights(1, 2, CV_32F); + float shift = 0; + + if (doTrain(data.samples, data.responses, weights, shift)) + { + Point points[2]; + shift = 0; + + findPointsForLine(weights, shift, points); + + redraw(data, points); + } +} + + +static void onMouse( int event, int x, int y, int, void* pData) +{ + Data &data = *(Data*)pData; + + switch( event ) + { + case CV_EVENT_LBUTTONUP: + data.responses.push_back(1); + addPointsRetrainAndRedraw(data, x, y); + + break; + + case CV_EVENT_RBUTTONDOWN: + data.responses.push_back(-1); + addPointsRetrainAndRedraw(data, x, y); + break; + } + +} + +int main() +{ + + Data data; + + setMouseCallback( "Train svmsgd", onMouse, &data ); + waitKey(); + + + + + return 0; +}