Merge pull request #6096 from mnoskova:mn/SVMSGD_to_opencv3_0
commit fbc221d334
8 changed files with 1254 additions and 1 deletion
@@ -0,0 +1,510 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                          License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2016, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include <limits>

#include <iostream>

using std::cout;
using std::endl;

/****************************************************************************************\
*                        Stochastic Gradient Descent SVM Classifier                     *
\****************************************************************************************/

namespace cv
{
namespace ml
{

class SVMSGDImpl : public SVMSGD
{
public:
    SVMSGDImpl();

    virtual ~SVMSGDImpl() {}

    virtual bool train(const Ptr<TrainData>& data, int);

    virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const;

    virtual bool isClassifier() const;

    virtual bool isTrained() const;

    virtual void clear();

    virtual void write(FileStorage &fs) const;

    virtual void read(const FileNode &fn);

    virtual Mat getWeights() { return weights_; }

    virtual float getShift() { return shift_; }

    virtual int getVarCount() const { return weights_.cols; }

    virtual String getDefaultName() const { return "opencv_ml_svmsgd"; }

    virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);

    CV_IMPL_PROPERTY(int, SvmsgdType, params.svmsgdType)
    CV_IMPL_PROPERTY(int, MarginType, params.marginType)
    CV_IMPL_PROPERTY(float, MarginRegularization, params.marginRegularization)
    CV_IMPL_PROPERTY(float, InitialStepSize, params.initialStepSize)
    CV_IMPL_PROPERTY(float, StepDecreasingPower, params.stepDecreasingPower)
    CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)

private:
    void updateWeights(InputArray sample, bool positive, float stepSize, Mat &weights);

    void writeParams( FileStorage &fs ) const;

    void readParams( const FileNode &fn );

    static inline bool isPositive(float val) { return val > 0; }

    static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);

    float calcShift(InputArray _samples, InputArray _responses) const;

    static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);

    // Vector with SVM weights
    Mat weights_;
    float shift_;

    // Parameters for learning
    struct SVMSGDParams
    {
        float marginRegularization;
        float initialStepSize;
        float stepDecreasingPower;
        TermCriteria termCrit;
        int svmsgdType;
        int marginType;
    };

    SVMSGDParams params;
};

Ptr<SVMSGD> SVMSGD::create()
{
    return makePtr<SVMSGDImpl>();
}
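The factory above is the only way user code obtains an instance; the accessors generated by the CV_IMPL_PROPERTY macros form the public tuning interface. A minimal usage sketch follows; the helper name, data preparation, and parameter values are illustrative assumptions, not part of this patch:

```cpp
#include <opencv2/ml.hpp>
using namespace cv;
using namespace cv::ml;

// Illustrative only: train an SVMSGD classifier on row-sample data.
// "samples" and "responses" are assumed to be CV_32F matrices prepared by the caller.
static Ptr<SVMSGD> trainToyModel(const Mat& samples, const Mat& responses)
{
    Ptr<SVMSGD> svmsgd = SVMSGD::create();
    svmsgd->setOptimalParameters(SVMSGD::ASGD, SVMSGD::SOFT_MARGIN); // same defaults the ctor picks
    svmsgd->setInitialStepSize(0.05f);          // gamma_0 (assumed value)
    svmsgd->setMarginRegularization(0.00001f);  // lambda  (assumed value)
    svmsgd->train(TrainData::create(samples, ROW_SAMPLE, responses));
    return svmsgd;
}
```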

void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
{
    int featuresCount = samples.cols;
    int samplesCount = samples.rows;

    average = Mat(1, featuresCount, samples.type());
    CV_Assert(average.type() == CV_32FC1);
    for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
    {
        average.at<float>(featureIndex) = static_cast<float>(mean(samples.col(featureIndex))[0]);
    }

    for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
    {
        samples.row(sampleIndex) -= average;
    }

    double normValue = norm(samples);

    multiplier = static_cast<float>(sqrt(samples.total()) / normValue);

    samples *= multiplier;
}
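For reference, normalizeSamples centers every feature by its column mean and then rescales the whole matrix so that the root-mean-square of its entries is 1. With X the n x d sample matrix this is just a restatement of the code above (norm(samples) is the Frobenius norm and samples.total() = n d):

$$
\tilde{X}_{ij} = X_{ij} - \mu_j,\qquad
\mu_j = \frac{1}{n}\sum_{i=1}^{n} X_{ij},\qquad
m = \frac{\sqrt{n\,d}}{\lVert \tilde{X} \rVert_F},\qquad
X \leftarrow m\,\tilde{X}
$$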

void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
{
    Mat normalizedTrainSamples = trainSamples.clone();
    int samplesCount = normalizedTrainSamples.rows;

    normalizeSamples(normalizedTrainSamples, average, multiplier);

    Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
    cv::hconcat(normalizedTrainSamples, onesCol, extendedTrainSamples);
}

void SVMSGDImpl::updateWeights(InputArray _sample, bool positive, float stepSize, Mat& weights)
{
    Mat sample = _sample.getMat();

    int response = positive ? 1 : -1; // ensure that trainResponses are -1 or 1

    if ( sample.dot(weights) * response > 1)
    {
        // Not a support vector, only apply weight decay
        weights *= (1.f - stepSize * params.marginRegularization);
    }
    else
    {
        // It's a support vector, add it to the weights
        weights -= (stepSize * params.marginRegularization) * weights - (stepSize * response) * sample;
    }
}
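The two branches above are the stochastic (sub)gradient step for the regularized hinge loss. With regularization lambda (marginRegularization), step size gamma (stepSize), label y in {-1, +1} and sample x, the update the code performs is:

$$
w \leftarrow
\begin{cases}
(1 - \gamma\lambda)\,w, & y\,\langle w, x\rangle > 1 \quad \text{(outside the margin)}\\
(1 - \gamma\lambda)\,w + \gamma\,y\,x, & \text{otherwise (margin violated)}
\end{cases}
$$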

float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
{
    float margin[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };

    Mat trainSamples = _samples.getMat();
    int trainSamplesCount = trainSamples.rows;

    Mat trainResponses = _responses.getMat();

    CV_Assert(trainResponses.type() == CV_32FC1);
    for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++)
    {
        Mat currentSample = trainSamples.row(samplesIndex);
        float dotProduct = static_cast<float>(currentSample.dot(weights_));

        bool positive = isPositive(trainResponses.at<float>(samplesIndex));
        int index = positive ? 0 : 1;
        float signToMul = positive ? 1.f : -1.f;
        float curMargin = dotProduct * signToMul;

        if (curMargin < margin[index])
        {
            margin[index] = curMargin;
        }
    }

    return -(margin[0] - margin[1]) / 2.f;
}
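calcShift implements the hard-margin rule for the bias: it finds the smallest signed projection on each side of the hyperplane and places the boundary midway between the two closest samples,

$$
b = -\frac{m_{+} - m_{-}}{2},\qquad
m_{+} = \min_{y_i = +1}\langle w, x_i\rangle,\qquad
m_{-} = \min_{y_i = -1}\bigl(-\langle w, x_i\rangle\bigr)
$$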

bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
{
    clear();
    CV_Assert( isClassifier() );   //toDo: consider

    Mat trainSamples = data->getTrainSamples();

    int featureCount = trainSamples.cols;
    Mat trainResponses = data->getTrainResponses();        // (trainSamplesCount x 1) matrix

    CV_Assert(trainResponses.rows == trainSamples.rows);

    if (trainResponses.empty())
    {
        return false;
    }

    int positiveCount = countNonZero(trainResponses >= 0);
    int negativeCount = countNonZero(trainResponses < 0);

    if ( positiveCount <= 0 || negativeCount <= 0 )
    {
        weights_ = Mat::zeros(1, featureCount, CV_32F);
        shift_ = (positiveCount > 0) ? 1.f : -1.f;
        return true;
    }

    Mat extendedTrainSamples;
    Mat average;
    float multiplier = 0;
    makeExtendedTrainSamples(trainSamples, extendedTrainSamples, average, multiplier);

    int extendedTrainSamplesCount = extendedTrainSamples.rows;
    int extendedFeatureCount = extendedTrainSamples.cols;

    Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    Mat averageExtendedWeights;
    if (params.svmsgdType == ASGD)
    {
        averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    }

    RNG rng(0);

    CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
    int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
    double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;

    double err = DBL_MAX;
    CV_Assert (trainResponses.type() == CV_32FC1);
    // Stochastic gradient descent SVM
    for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
    {
        int randomNumber = rng.uniform(0, extendedTrainSamplesCount);   //generate sample number

        Mat currentSample = extendedTrainSamples.row(randomNumber);

        float stepSize = params.initialStepSize * std::pow((1 + params.marginRegularization * params.initialStepSize * (float)iter), (-params.stepDecreasingPower));    //update stepSize

        updateWeights( currentSample, isPositive(trainResponses.at<float>(randomNumber)), stepSize, extendedWeights );

        //average weights (only for ASGD model)
        if (params.svmsgdType == ASGD)
        {
            averageExtendedWeights = ((float)iter / (1 + (float)iter)) * averageExtendedWeights + extendedWeights / (1 + (float)iter);
            err = norm(averageExtendedWeights - previousWeights);
            averageExtendedWeights.copyTo(previousWeights);
        }
        else
        {
            err = norm(extendedWeights - previousWeights);
            extendedWeights.copyTo(previousWeights);
        }
    }

    if (params.svmsgdType == ASGD)
    {
        extendedWeights = averageExtendedWeights;
    }

    Rect roi(0, 0, featureCount, 1);
    weights_ = extendedWeights(roi);
    weights_ *= multiplier;

    CV_Assert((params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN) && (extendedWeights.type() == CV_32FC1));

    if (params.marginType == SOFT_MARGIN)
    {
        shift_ = extendedWeights.at<float>(featureCount) - static_cast<float>(weights_.dot(average));
    }
    else
    {
        shift_ = calcShift(trainSamples, trainResponses);
    }

    return true;
}
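Two formulas govern the loop above: the decaying step size (initialStepSize gamma_0, marginRegularization lambda, stepDecreasingPower c) and, for the ASGD variant, the running average of the iterates that is used in place of the last weight vector:

$$
\gamma_t = \gamma_0\,\bigl(1 + \lambda\,\gamma_0\,t\bigr)^{-c},\qquad
\bar{w}_t = \frac{t}{t+1}\,\bar{w}_{t-1} + \frac{1}{t+1}\,w_t
$$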

float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
{
    float result = 0;
    cv::Mat samples = _samples.getMat();
    int nSamples = samples.rows;
    cv::Mat results;

    CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32FC1);

    if( _results.needed() )
    {
        _results.create( nSamples, 1, samples.type() );
        results = _results.getMat();
    }
    else
    {
        CV_Assert( nSamples == 1 );
        results = Mat(1, 1, CV_32FC1, &result);
    }

    for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++)
    {
        Mat currentSample = samples.row(sampleIndex);
        float criterion = static_cast<float>(currentSample.dot(weights_)) + shift_;
        results.at<float>(sampleIndex) = (criterion >= 0) ? 1.f : -1.f;
    }

    return result;
}
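A short sketch of how a caller would use predict once a model is trained; the helper name and the query values are made up for illustration:

```cpp
#include <opencv2/ml.hpp>
#include <iostream>

// Illustrative only: classify two 2-D points with an already trained model.
static void classifyToyPoints(const cv::Ptr<cv::ml::SVMSGD>& svmsgd)
{
    cv::Mat query = (cv::Mat_<float>(2, 2) << 0.5f, -1.0f,
                                              3.0f,  2.0f);
    cv::Mat labels;                    // receives +1 / -1 per row (CV_32F)
    svmsgd->predict(query, labels);
    std::cout << labels << std::endl;
}
```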

bool SVMSGDImpl::isClassifier() const
{
    return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
            &&
            (params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN)
            &&
            (params.marginRegularization > 0) && (params.initialStepSize > 0) && (params.stepDecreasingPower >= 0);
}

bool SVMSGDImpl::isTrained() const
{
    return !weights_.empty();
}

void SVMSGDImpl::write(FileStorage& fs) const
{
    if( !isTrained() )
        CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );

    writeParams( fs );

    fs << "weights" << weights_;
    fs << "shift" << shift_;
}

void SVMSGDImpl::writeParams( FileStorage& fs ) const
{
    String SvmsgdTypeStr;

    switch (params.svmsgdType)
    {
    case SGD:
        SvmsgdTypeStr = "SGD";
        break;
    case ASGD:
        SvmsgdTypeStr = "ASGD";
        break;
    default:
        SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
    }

    fs << "svmsgdType" << SvmsgdTypeStr;

    String marginTypeStr;

    switch (params.marginType)
    {
    case SOFT_MARGIN:
        marginTypeStr = "SOFT_MARGIN";
        break;
    case HARD_MARGIN:
        marginTypeStr = "HARD_MARGIN";
        break;
    default:
        marginTypeStr = format("Unknown_%d", params.marginType);
    }

    fs << "marginType" << marginTypeStr;

    fs << "marginRegularization" << params.marginRegularization;
    fs << "initialStepSize" << params.initialStepSize;
    fs << "stepDecreasingPower" << params.stepDecreasingPower;

    fs << "term_criteria" << "{:";
    if( params.termCrit.type & TermCriteria::EPS )
        fs << "epsilon" << params.termCrit.epsilon;
    if( params.termCrit.type & TermCriteria::COUNT )
        fs << "iterations" << params.termCrit.maxCount;
    fs << "}";
}

void SVMSGDImpl::readParams( const FileNode& fn )
{
    String svmsgdTypeStr = (String)fn["svmsgdType"];
    int svmsgdType =
            svmsgdTypeStr == "SGD" ? SGD :
            svmsgdTypeStr == "ASGD" ? ASGD : -1;

    if( svmsgdType < 0 )
        CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );

    params.svmsgdType = svmsgdType;

    String marginTypeStr = (String)fn["marginType"];
    int marginType =
            marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
            marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;

    if( marginType < 0 )
        CV_Error( CV_StsParseError, "Missing or invalid margin type" );

    params.marginType = marginType;

    CV_Assert ( fn["marginRegularization"].isReal() );
    params.marginRegularization = (float)fn["marginRegularization"];

    CV_Assert ( fn["initialStepSize"].isReal() );
    params.initialStepSize = (float)fn["initialStepSize"];

    CV_Assert ( fn["stepDecreasingPower"].isReal() );
    params.stepDecreasingPower = (float)fn["stepDecreasingPower"];

    FileNode tcnode = fn["term_criteria"];
    CV_Assert(!tcnode.empty());
    params.termCrit.epsilon = (double)tcnode["epsilon"];
    params.termCrit.maxCount = (int)tcnode["iterations"];
    params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
            (params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
    CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS));
}

void SVMSGDImpl::read(const FileNode& fn)
{
    clear();

    readParams(fn);

    fn["weights"] >> weights_;
    fn["shift"] >> shift_;
}
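write/read above are what FileStorage-based persistence calls into. A hedged sketch of round-tripping a trained model through a file using the generic Algorithm save/load machinery; the helper name and file name are assumptions:

```cpp
#include <opencv2/ml.hpp>

// Illustrative only: persist and restore an SVMSGD model.
static cv::Ptr<cv::ml::SVMSGD> roundTrip(const cv::Ptr<cv::ml::SVMSGD>& trained)
{
    trained->save("svmsgd_model.yml");                               // invokes write()/writeParams()
    return cv::Algorithm::load<cv::ml::SVMSGD>("svmsgd_model.yml"); // invokes read()/readParams()
}
```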

void SVMSGDImpl::clear()
{
    weights_.release();
    shift_ = 0;
}


SVMSGDImpl::SVMSGDImpl()
{
    clear();
    setOptimalParameters();
}

void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
{
    switch (svmsgdType)
    {
    case SGD:
        params.svmsgdType = SGD;
        params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
                            (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
        params.marginRegularization = 0.0001f;
        params.initialStepSize = 0.05f;
        params.stepDecreasingPower = 1.f;
        params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
        break;

    case ASGD:
        params.svmsgdType = ASGD;
        params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
                            (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
        params.marginRegularization = 0.00001f;
        params.initialStepSize = 0.05f;
        params.stepDecreasingPower = 0.75f;
        params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
        break;

    default:
        CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
    }
}
} //ml
} //cv
@@ -0,0 +1,318 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "test_precomp.hpp"
#include "opencv2/highgui.hpp"

using namespace cv;
using namespace cv::ml;
using cv::ml::SVMSGD;
using cv::ml::TrainData;


class CV_SVMSGDTrainTest : public cvtest::BaseTest
{
public:
    enum TrainDataType
    {
        UNIFORM_SAME_SCALE,
        UNIFORM_DIFFERENT_SCALES
    };

    CV_SVMSGDTrainTest(const Mat &_weights, float shift, TrainDataType type, double precision = 0.01);
private:
    virtual void run( int start_from );
    static float decisionFunction(const Mat &sample, const Mat &weights, float shift);
    void makeData(int samplesCount, const Mat &weights, float shift, RNG &rng, Mat &samples, Mat & responses);
    void generateSameBorders(int featureCount);
    void generateDifferentBorders(int featureCount);

    TrainDataType type;
    double precision;
    std::vector<std::pair<float,float> > borders;
    cv::Ptr<TrainData> data;
    cv::Mat testSamples;
    cv::Mat testResponses;
    static const int TEST_VALUE_LIMIT = 500;
};

void CV_SVMSGDTrainTest::generateSameBorders(int featureCount)
{
    float lowerLimit = -TEST_VALUE_LIMIT;
    float upperLimit = TEST_VALUE_LIMIT;

    for (int featureIndex = 0; featureIndex < featureCount; featureIndex++)
    {
        borders.push_back(std::pair<float,float>(lowerLimit, upperLimit));
    }
}

void CV_SVMSGDTrainTest::generateDifferentBorders(int featureCount)
{
    float lowerLimit = -TEST_VALUE_LIMIT;
    float upperLimit = TEST_VALUE_LIMIT;
    cv::RNG rng(0);

    for (int featureIndex = 0; featureIndex < featureCount; featureIndex++)
    {
        int crit = rng.uniform(0, 2);

        if (crit > 0)
        {
            borders.push_back(std::pair<float,float>(lowerLimit, upperLimit));
        }
        else
        {
            borders.push_back(std::pair<float,float>(lowerLimit/1000, upperLimit/1000));
        }
    }
}

float CV_SVMSGDTrainTest::decisionFunction(const Mat &sample, const Mat &weights, float shift)
{
    return static_cast<float>(sample.dot(weights)) + shift;
}

void CV_SVMSGDTrainTest::makeData(int samplesCount, const Mat &weights, float shift, RNG &rng, Mat &samples, Mat & responses)
{
    int featureCount = weights.cols;

    samples.create(samplesCount, featureCount, CV_32FC1);
    for (int featureIndex = 0; featureIndex < featureCount; featureIndex++)
    {
        rng.fill(samples.col(featureIndex), RNG::UNIFORM, borders[featureIndex].first, borders[featureIndex].second);
    }

    responses.create(samplesCount, 1, CV_32FC1);

    for (int i = 0 ; i < samplesCount; i++)
    {
        responses.at<float>(i) = decisionFunction(samples.row(i), weights, shift) > 0 ? 1.f : -1.f;
    }
}
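makeData draws each feature uniformly from its border interval and labels every point by the sign of the generating hyperplane (ties at exactly zero map to -1, matching the `> 0` test above):

$$
y_i = \operatorname{sign}\bigl(\langle w, x_i\rangle + b\bigr) \in \{-1, +1\}
$$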

CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(const Mat &weights, float shift, TrainDataType _type, double _precision)
{
    type = _type;
    precision = _precision;

    int featureCount = weights.cols;

    switch(type)
    {
    case UNIFORM_SAME_SCALE:
        generateSameBorders(featureCount);
        break;
    case UNIFORM_DIFFERENT_SCALES:
        generateDifferentBorders(featureCount);
        break;
    default:
        CV_Error(CV_StsBadArg, "Unknown train data type");
    }

    RNG rng(0);

    Mat trainSamples;
    Mat trainResponses;
    int trainSamplesCount = 10000;
    makeData(trainSamplesCount, weights, shift, rng, trainSamples, trainResponses);
    data = TrainData::create(trainSamples, cv::ml::ROW_SAMPLE, trainResponses);

    int testSamplesCount = 100000;
    makeData(testSamplesCount, weights, shift, rng, testSamples, testResponses);
}

void CV_SVMSGDTrainTest::run( int /*start_from*/ )
{
    cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();

    svmsgd->train(data);

    Mat responses;

    svmsgd->predict(testSamples, responses);

    int errCount = 0;
    int testSamplesCount = testSamples.rows;

    CV_Assert((responses.type() == CV_32FC1) && (testResponses.type() == CV_32FC1));
    for (int i = 0; i < testSamplesCount; i++)
    {
        if (responses.at<float>(i) * testResponses.at<float>(i) < 0)
            errCount++;
    }

    float err = (float)errCount / testSamplesCount;

    if ( err > precision )
    {
        ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ACCURACY);
    }
}

void makeWeightsAndShift(int featureCount, Mat &weights, float &shift)
{
    weights.create(1, featureCount, CV_32FC1);
    cv::RNG rng(0);
    double lowerLimit = -1;
    double upperLimit = 1;

    rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
    shift = static_cast<float>(rng.uniform(-featureCount, featureCount));
}


TEST(ML_SVMSGD, trainSameScale2)
{
    int featureCount = 2;

    Mat weights;

    float shift = 0;
    makeWeightsAndShift(featureCount, weights, shift);

    CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_SAME_SCALE);
    test.safe_run();
}

TEST(ML_SVMSGD, trainSameScale5)
{
    int featureCount = 5;

    Mat weights;

    float shift = 0;
    makeWeightsAndShift(featureCount, weights, shift);

    CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_SAME_SCALE);
    test.safe_run();
}

TEST(ML_SVMSGD, trainSameScale100)
{
    int featureCount = 100;

    Mat weights;

    float shift = 0;
    makeWeightsAndShift(featureCount, weights, shift);

    CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_SAME_SCALE, 0.02);
    test.safe_run();
}

TEST(ML_SVMSGD, trainDifferentScales2)
{
    int featureCount = 2;

    Mat weights;

    float shift = 0;
    makeWeightsAndShift(featureCount, weights, shift);

    CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_DIFFERENT_SCALES, 0.01);
    test.safe_run();
}

TEST(ML_SVMSGD, trainDifferentScales5)
{
    int featureCount = 5;

    Mat weights;

    float shift = 0;
    makeWeightsAndShift(featureCount, weights, shift);

    CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_DIFFERENT_SCALES, 0.01);
    test.safe_run();
}

TEST(ML_SVMSGD, trainDifferentScales100)
{
    int featureCount = 100;

    Mat weights;

    float shift = 0;
    makeWeightsAndShift(featureCount, weights, shift);

    CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_DIFFERENT_SCALES, 0.01);
    test.safe_run();
}

TEST(ML_SVMSGD, twoPoints)
{
    Mat samples(2, 2, CV_32FC1);
    samples.at<float>(0,0) = 0;
    samples.at<float>(0,1) = 0;
    samples.at<float>(1,0) = 1000;
    samples.at<float>(1,1) = 1;

    Mat responses(2, 1, CV_32FC1);
    responses.at<float>(0) = -1;
    responses.at<float>(1) = 1;

    cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);

    Mat realWeights(1, 2, CV_32FC1);
    realWeights.at<float>(0) = 1000;
    realWeights.at<float>(1) = 1;

    float realShift = -500000.5;

    float normRealWeights = static_cast<float>(norm(realWeights));
    realWeights /= normRealWeights;
    realShift /= normRealWeights;

    cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
    svmsgd->setOptimalParameters();
    svmsgd->train( trainData );

    Mat foundWeights = svmsgd->getWeights();
    float foundShift = svmsgd->getShift();

    float normFoundWeights = static_cast<float>(norm(foundWeights));
    foundWeights /= normFoundWeights;
    foundShift /= normFoundWeights;
    CV_Assert((norm(foundWeights - realWeights) < 0.001) && (abs((foundShift - realShift) / realShift) < 0.05));
}
@@ -0,0 +1,210 @@
#include <opencv2/opencv.hpp>
#include "opencv2/video/tracking.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include "opencv2/highgui/highgui.hpp"

using namespace cv;
using namespace cv::ml;


struct Data
{
    Mat img;
    Mat samples;          //Set of train samples. Contains points on image
    Mat responses;        //Set of responses for train samples

    Data()
    {
        const int WIDTH = 841;
        const int HEIGHT = 594;
        img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
        imshow("Train svmsgd", img);
    }
};

//Train with SVMSGD algorithm
//(samples, responses) is a train set
//weights is a required vector for decision function of SVMSGD algorithm
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);

//function finds two points for drawing line (wx = 0)
bool findPointsForLine(const Mat &weights, float shift, Point points[], int width, int height);

// function finds cross point of line (wx = 0) and segment ( (y = HEIGHT, 0 <= x <= WIDTH) or (x = WIDTH, 0 <= y <= HEIGHT) )
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);

//segments' initialization ( (y = HEIGHT, 0 <= x <= WIDTH) and (x = WIDTH, 0 <= y <= HEIGHT) )
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);

//redraw points' set and line (wx = 0)
void redraw(Data data, const Point points[2]);

//add point in train set, train SVMSGD algorithm and draw results on image
void addPointRetrainAndRedraw(Data &data, int x, int y, int response);
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift) |
||||
{ |
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create(); |
||||
|
||||
cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses); |
||||
svmsgd->train( trainData ); |
||||
|
||||
if (svmsgd->isTrained()) |
||||
{ |
||||
weights = svmsgd->getWeights(); |
||||
shift = svmsgd->getShift(); |
||||
|
||||
return true; |
||||
} |
||||
return false; |
||||
} |
||||
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height) |
||||
{ |
||||
std::pair<Point,Point> currentSegment; |
||||
|
||||
currentSegment.first = Point(width, 0); |
||||
currentSegment.second = Point(width, height); |
||||
segments.push_back(currentSegment); |
||||
|
||||
currentSegment.first = Point(0, height); |
||||
currentSegment.second = Point(width, height); |
||||
segments.push_back(currentSegment); |
||||
|
||||
currentSegment.first = Point(0, 0); |
||||
currentSegment.second = Point(width, 0); |
||||
segments.push_back(currentSegment); |
||||
|
||||
currentSegment.first = Point(0, 0); |
||||
currentSegment.second = Point(0, height); |
||||
segments.push_back(currentSegment); |
||||
} |
||||
|
||||
|
||||
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint) |
||||
{ |
||||
int x = 0; |
||||
int y = 0; |
||||
int xMin = std::min(segment.first.x, segment.second.x); |
||||
int xMax = std::max(segment.first.x, segment.second.x); |
||||
int yMin = std::min(segment.first.y, segment.second.y); |
||||
int yMax = std::max(segment.first.y, segment.second.y); |
||||
|
||||
CV_Assert(weights.type() == CV_32FC1); |
||||
CV_Assert(xMin == xMax || yMin == yMax); |
||||
|
||||
if (xMin == xMax && weights.at<float>(1) != 0) |
||||
{ |
||||
x = xMin; |
||||
y = static_cast<int>(std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1))); |
||||
if (y >= yMin && y <= yMax) |
||||
{ |
||||
crossPoint.x = x; |
||||
crossPoint.y = y; |
||||
return true; |
||||
} |
||||
} |
||||
else if (yMin == yMax && weights.at<float>(0) != 0) |
||||
{ |
||||
y = yMin; |
||||
x = static_cast<int>(std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0))); |
||||
if (x >= xMin && x <= xMax) |
||||
{ |
||||
crossPoint.x = x; |
||||
crossPoint.y = y; |
||||
return true; |
||||
} |
||||
} |
||||
return false; |
||||
} |
||||
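findCrossPointWithBorders intersects the decision line with one axis-aligned border segment. With weights (w0, w1) and shift s the line is w0 x + w1 y + s = 0, so the two branches above solve:

$$
y = -\frac{w_0\,x + s}{w_1}\quad(\text{vertical border, } w_1 \neq 0),\qquad
x = -\frac{w_1\,y + s}{w_0}\quad(\text{horizontal border, } w_0 \neq 0)
$$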

bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height)
{
    if (weights.empty())
    {
        return false;
    }

    int foundPointsCount = 0;
    std::vector<std::pair<Point,Point> > segments;
    fillSegments(segments, width, height);

    for (uint i = 0; i < segments.size(); i++)
    {
        if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))
            foundPointsCount++;
        if (foundPointsCount >= 2)
            break;
    }

    return true;
}

void redraw(Data data, const Point points[2])
{
    data.img.setTo(0);
    Point center;
    int radius = 3;
    Scalar color;
    CV_Assert((data.samples.type() == CV_32FC1) && (data.responses.type() == CV_32FC1));
    for (int i = 0; i < data.samples.rows; i++)
    {
        center.x = static_cast<int>(data.samples.at<float>(i,0));
        center.y = static_cast<int>(data.samples.at<float>(i,1));
        color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128);
        circle(data.img, center, radius, color, 5);
    }
    line(data.img, points[0], points[1], cv::Scalar(1,255,1));

    imshow("Train svmsgd", data.img);
}

void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
{
    Mat currentSample(1, 2, CV_32FC1);

    currentSample.at<float>(0,0) = (float)x;
    currentSample.at<float>(0,1) = (float)y;
    data.samples.push_back(currentSample);
    data.responses.push_back(static_cast<float>(response));   // keep responses CV_32FC1, as redraw() asserts

    Mat weights(1, 2, CV_32FC1);
    float shift = 0;

    if (doTrain(data.samples, data.responses, weights, shift))
    {
        Point points[2];
        findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);

        redraw(data, points);
    }
}


static void onMouse( int event, int x, int y, int, void* pData)
{
    Data &data = *(Data*)pData;

    switch( event )
    {
    case CV_EVENT_LBUTTONUP:
        addPointRetrainAndRedraw(data, x, y, 1);
        break;

    case CV_EVENT_RBUTTONDOWN:
        addPointRetrainAndRedraw(data, x, y, -1);
        break;
    }
}

int main()
{
    Data data;

    setMouseCallback( "Train svmsgd", onMouse, &data );
    waitKey();

    return 0;
}