mirror of https://github.com/opencv/opencv.git
parent
a2f0963d66
commit
40bf97c6d1
11 changed files with 965 additions and 226 deletions
@ -0,0 +1,134 @@ |
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
||||
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
|
||||
// Copyright (C) 2014, Itseez Inc, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of the copyright holders may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#ifndef __OPENCV_ML_SVMSGD_HPP__ |
||||
#define __OPENCV_ML_SVMSGD_HPP__ |
||||
|
||||
#ifdef __cplusplus |
||||
|
||||
#include "opencv2/ml.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace ml |
||||
{ |
||||
|
||||
|
||||
/****************************************************************************************\
|
||||
* Stochastic Gradient Descent SVM Classifier * |
||||
\****************************************************************************************/ |
||||
|
||||
/*!
|
||||
@brief Stochastic Gradient Descent SVM classifier |
||||
|
||||
SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large. |
||||
The gradient descent show amazing performance for large-scale problems, reducing the computing time. This allows a fast and reliable online update of the classifier for each new feature which |
||||
is fundamental when dealing with variations of data over time (like weather and illumination changes in videosurveillance, for example). |
||||
|
||||
First, create the SVMSGD object. To enable the online update, a value for updateFrequency should be defined. |
||||
|
||||
Then the SVM model can be trained using the train features and the correspondent labels. |
||||
|
||||
After that, the label of a new feature vector can be predicted using the predict function. If the updateFrequency was defined in the constructor, the predict function will update the weights automatically. |
||||
|
||||
@code |
||||
// Initialize object
|
||||
SVMSGD SvmSgd; |
||||
|
||||
// Train the Stochastic Gradient Descent SVM
|
||||
SvmSgd.train(trainFeatures, labels); |
||||
|
||||
// Predict label for the new feature vector (1xM)
|
||||
predictedLabel = SvmSgd.predict(newFeatureVector); |
||||
@endcode |
||||
|
||||
*/ |
||||
|
||||
class CV_EXPORTS_W SVMSGD : public cv::ml::StatModel |
||||
{ |
||||
public: |
||||
|
||||
enum SvmsgdType |
||||
{ |
||||
ILLEGAL_VALUE, |
||||
SGD, //Stochastic Gradient Descent
|
||||
ASGD //Average Stochastic Gradient Descent
|
||||
}; |
||||
|
||||
/**
|
||||
* @return the weights of the trained model. |
||||
*/ |
||||
CV_WRAP virtual Mat getWeights() = 0; |
||||
|
||||
CV_WRAP virtual float getShift() = 0; |
||||
|
||||
CV_WRAP static Ptr<SVMSGD> create();
|
||||
|
||||
CV_WRAP virtual void setOptimalParameters(int type = ASGD) = 0; |
||||
|
||||
CV_WRAP virtual int getType() const = 0; |
||||
|
||||
CV_WRAP virtual void setType(int type) = 0; |
||||
|
||||
CV_WRAP virtual float getLambda() const = 0; |
||||
|
||||
CV_WRAP virtual void setLambda(float lambda) = 0; |
||||
|
||||
CV_WRAP virtual float getGamma0() const = 0; |
||||
|
||||
CV_WRAP virtual void setGamma0(float gamma0) = 0; |
||||
|
||||
CV_WRAP virtual float getC() const = 0; |
||||
|
||||
CV_WRAP virtual void setC(float c) = 0; |
||||
|
||||
CV_WRAP virtual cv::TermCriteria getTermCriteria() const = 0; |
||||
|
||||
CV_WRAP virtual void setTermCriteria(const cv::TermCriteria &val) = 0; |
||||
}; |
||||
|
||||
} //ml
|
||||
} //cv
|
||||
|
||||
#endif // __clpusplus
|
||||
#endif // __OPENCV_ML_SVMSGD_HPP
|
@ -0,0 +1,182 @@ |
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// Intel License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of Intel Corporation may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#include "test_precomp.hpp" |
||||
#include "opencv2/highgui.hpp" |
||||
|
||||
using namespace cv; |
||||
using namespace cv::ml; |
||||
using cv::ml::SVMSGD; |
||||
using cv::ml::TrainData; |
||||
|
||||
|
||||
|
||||
class CV_SVMSGDTrainTest : public cvtest::BaseTest |
||||
{ |
||||
public: |
||||
CV_SVMSGDTrainTest(Mat _weights, float _shift); |
||||
private: |
||||
virtual void run( int start_from ); |
||||
float decisionFunction(Mat sample, Mat weights, float shift); |
||||
|
||||
cv::Ptr<TrainData> data; |
||||
cv::Mat testSamples; |
||||
cv::Mat testResponses; |
||||
static const int TEST_VALUE_LIMIT = 50; |
||||
}; |
||||
|
||||
CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(Mat weights, float shift) |
||||
{ |
||||
int datasize = 100000; |
||||
int varCount = weights.cols; |
||||
cv::Mat samples = cv::Mat::zeros( datasize, varCount, CV_32FC1 ); |
||||
cv::Mat responses = cv::Mat::zeros( datasize, 1, CV_32FC1 ); |
||||
cv::RNG rng(0); |
||||
|
||||
float lowerLimit = -TEST_VALUE_LIMIT; |
||||
float upperLimit = TEST_VALUE_LIMIT; |
||||
|
||||
|
||||
rng.fill(samples, RNG::UNIFORM, lowerLimit, upperLimit); |
||||
for (int sampleIndex = 0; sampleIndex < datasize; sampleIndex++) |
||||
{ |
||||
responses.at<float>( sampleIndex ) = decisionFunction(samples.row(sampleIndex), weights, shift) > 0 ? 1 : -1; |
||||
} |
||||
|
||||
data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses ); |
||||
|
||||
int testSamplesCount = 100000; |
||||
|
||||
testSamples.create(testSamplesCount, varCount, CV_32FC1); |
||||
rng.fill(testSamples, RNG::UNIFORM, lowerLimit, upperLimit); |
||||
testResponses.create(testSamplesCount, 1, CV_32FC1); |
||||
|
||||
for (int i = 0 ; i < testSamplesCount; i++) |
||||
{ |
||||
testResponses.at<float>(i) = decisionFunction(testSamples.row(i), weights, shift) > 0 ? 1 : -1; |
||||
} |
||||
} |
||||
|
||||
void CV_SVMSGDTrainTest::run( int /*start_from*/ ) |
||||
{ |
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create(); |
||||
|
||||
svmsgd->setOptimalParameters(SVMSGD::ASGD); |
||||
|
||||
svmsgd->train( data ); |
||||
|
||||
Mat responses; |
||||
|
||||
svmsgd->predict(testSamples, responses); |
||||
|
||||
int errCount = 0; |
||||
int testSamplesCount = testSamples.rows; |
||||
|
||||
for (int i = 0; i < testSamplesCount; i++) |
||||
{ |
||||
if (responses.at<float>(i) * testResponses.at<float>(i) < 0 ) |
||||
errCount++; |
||||
} |
||||
|
||||
float err = (float)errCount / testSamplesCount; |
||||
std::cout << "err " << err << std::endl; |
||||
|
||||
if ( err > 0.01 ) |
||||
{ |
||||
ts->set_failed_test_info( cvtest::TS::FAIL_BAD_ACCURACY ); |
||||
} |
||||
} |
||||
|
||||
float CV_SVMSGDTrainTest::decisionFunction(Mat sample, Mat weights, float shift) |
||||
{ |
||||
return sample.dot(weights) + shift; |
||||
} |
||||
|
||||
TEST(ML_SVMSGD, train0) |
||||
{ |
||||
int varCount = 2; |
||||
|
||||
Mat weights; |
||||
weights.create(1, varCount, CV_32FC1); |
||||
weights.at<float>(0) = 1; |
||||
weights.at<float>(1) = 0; |
||||
|
||||
float shift = 5; |
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift); |
||||
test.safe_run(); |
||||
} |
||||
|
||||
TEST(ML_SVMSGD, train1) |
||||
{ |
||||
int varCount = 5; |
||||
|
||||
Mat weights; |
||||
weights.create(1, varCount, CV_32FC1); |
||||
|
||||
float lowerLimit = -1; |
||||
float upperLimit = 1; |
||||
cv::RNG rng(0); |
||||
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit); |
||||
|
||||
float shift = rng.uniform(-5.f, 5.f); |
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift); |
||||
test.safe_run(); |
||||
} |
||||
|
||||
TEST(ML_SVMSGD, train2) |
||||
{ |
||||
int varCount = 100; |
||||
|
||||
Mat weights; |
||||
weights.create(1, varCount, CV_32FC1); |
||||
|
||||
float lowerLimit = -1; |
||||
float upperLimit = 1; |
||||
cv::RNG rng(0); |
||||
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit); |
||||
|
||||
float shift = rng.uniform(-1000.f, 1000.f); |
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift); |
||||
test.safe_run(); |
||||
} |
@ -0,0 +1,226 @@ |
||||
#include <opencv2/opencv.hpp> |
||||
#include "opencv2/video/tracking.hpp" |
||||
#include "opencv2/imgproc/imgproc.hpp" |
||||
#include "opencv2/highgui/highgui.hpp" |
||||
|
||||
using namespace cv; |
||||
using namespace cv::ml; |
||||
|
||||
#define WIDTH 841 |
||||
#define HEIGHT 594 |
||||
|
||||
struct Data |
||||
{ |
||||
Mat img; |
||||
Mat samples; |
||||
Mat responses; |
||||
RNG rng; |
||||
//Point points[2];
|
||||
|
||||
Data() |
||||
{ |
||||
img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3); |
||||
imshow("Train svmsgd", img); |
||||
} |
||||
}; |
||||
|
||||
bool doTrain(const Mat samples,const Mat responses, Mat &weights, float &shift); |
||||
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2]); |
||||
bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint); |
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments); |
||||
void redraw(Data data, const Point points[2]); |
||||
void addPointsRetrainAndRedraw(Data &data, int x, int y); |
||||
|
||||
|
||||
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift) |
||||
{ |
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create(); |
||||
svmsgd->setOptimalParameters(SVMSGD::ASGD); |
||||
svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 50000, 0.0000001)); |
||||
svmsgd->setLambda(0.01); |
||||
svmsgd->setGamma0(1); |
||||
// svmsgd->setC(5);
|
||||
|
||||
cv::Ptr<TrainData> train_data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses ); |
||||
svmsgd->train( train_data ); |
||||
|
||||
if (svmsgd->isTrained()) |
||||
{ |
||||
weights = svmsgd->getWeights(); |
||||
shift = svmsgd->getShift(); |
||||
|
||||
std::cout << weights << std::endl; |
||||
std::cout << shift << std::endl; |
||||
|
||||
return true; |
||||
} |
||||
return false; |
||||
} |
||||
|
||||
|
||||
bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint) |
||||
{ |
||||
int x = 0; |
||||
int y = 0; |
||||
//с (0,0) всё плохо
|
||||
if (segment.first.x == segment.second.x && weights.at<float>(1) != 0) |
||||
{ |
||||
x = segment.first.x; |
||||
y = -(weights.at<float>(0) * x + shift) / weights.at<float>(1); |
||||
if (y >= 0 && y <= HEIGHT) |
||||
{ |
||||
crossPoint.x = x; |
||||
crossPoint.y = y; |
||||
return true; |
||||
} |
||||
} |
||||
else if (segment.first.y == segment.second.y && weights.at<float>(0) != 0) |
||||
{ |
||||
y = segment.first.y; |
||||
x = - (weights.at<float>(1) * y + shift) / weights.at<float>(0); |
||||
if (x >= 0 && x <= WIDTH) |
||||
{ |
||||
crossPoint.x = x; |
||||
crossPoint.y = y; |
||||
return true; |
||||
} |
||||
} |
||||
return false; |
||||
} |
||||
|
||||
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2]) |
||||
{ |
||||
if (weights.empty()) |
||||
{ |
||||
return false; |
||||
} |
||||
|
||||
int foundPointsCount = 0; |
||||
std::vector<std::pair<Point,Point> > segments; |
||||
fillSegments(segments); |
||||
|
||||
for (int i = 0; i < 4; i++) |
||||
{ |
||||
if (findCrossPoint(weights, shift, segments[i], points[foundPointsCount])) |
||||
foundPointsCount++; |
||||
if (foundPointsCount > 2) |
||||
break; |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments) |
||||
{ |
||||
std::pair<Point,Point> curSegment; |
||||
|
||||
curSegment.first = Point(0,0); |
||||
curSegment.second = Point(0,HEIGHT); |
||||
segments.push_back(curSegment); |
||||
|
||||
curSegment.first = Point(0,0); |
||||
curSegment.second = Point(WIDTH,0); |
||||
segments.push_back(curSegment); |
||||
|
||||
curSegment.first = Point(WIDTH,0); |
||||
curSegment.second = Point(WIDTH,HEIGHT); |
||||
segments.push_back(curSegment); |
||||
|
||||
curSegment.first = Point(0,HEIGHT); |
||||
curSegment.second = Point(WIDTH,HEIGHT); |
||||
segments.push_back(curSegment); |
||||
} |
||||
|
||||
void redraw(Data data, const Point points[2]) |
||||
{ |
||||
data.img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3); |
||||
Point center; |
||||
int radius = 3; |
||||
Scalar color; |
||||
for (int i = 0; i < data.samples.rows; i++) |
||||
{ |
||||
center.x = data.samples.at<float>(i,0); |
||||
center.y = data.samples.at<float>(i,1); |
||||
color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128); |
||||
circle(data.img, center, radius, color, 5); |
||||
} |
||||
line(data.img, points[0],points[1],cv::Scalar(1,255,1)); |
||||
|
||||
imshow("Train svmsgd", data.img); |
||||
} |
||||
|
||||
void addPointsRetrainAndRedraw(Data &data, int x, int y) |
||||
{ |
||||
|
||||
Mat currentSample(1, 2, CV_32F); |
||||
//start
|
||||
/*
|
||||
Mat _weights; |
||||
_weights.create(1, 2, CV_32FC1); |
||||
_weights.at<float>(0) = 1; |
||||
_weights.at<float>(1) = -1; |
||||
|
||||
int _x, _y; |
||||
|
||||
for (int i=0;i<199;i++) |
||||
{ |
||||
_x = data.rng.uniform(0,800); |
||||
_y = data.rng.uniform(0,500);*/ |
||||
currentSample.at<float>(0,0) = x; |
||||
currentSample.at<float>(0,1) = y; |
||||
//if (currentSample.dot(_weights) > 0)
|
||||
//data.responses.push_back(1);
|
||||
// else data.responses.push_back(-1);
|
||||
|
||||
//finish
|
||||
data.samples.push_back(currentSample); |
||||
|
||||
|
||||
|
||||
Mat weights(1, 2, CV_32F); |
||||
float shift = 0; |
||||
|
||||
if (doTrain(data.samples, data.responses, weights, shift)) |
||||
{ |
||||
Point points[2]; |
||||
shift = 0; |
||||
|
||||
findPointsForLine(weights, shift, points); |
||||
|
||||
redraw(data, points); |
||||
} |
||||
} |
||||
|
||||
|
||||
static void onMouse( int event, int x, int y, int, void* pData) |
||||
{ |
||||
Data &data = *(Data*)pData; |
||||
|
||||
switch( event ) |
||||
{ |
||||
case CV_EVENT_LBUTTONUP: |
||||
data.responses.push_back(1); |
||||
addPointsRetrainAndRedraw(data, x, y); |
||||
|
||||
break; |
||||
|
||||
case CV_EVENT_RBUTTONDOWN: |
||||
data.responses.push_back(-1); |
||||
addPointsRetrainAndRedraw(data, x, y); |
||||
break; |
||||
} |
||||
|
||||
} |
||||
|
||||
int main() |
||||
{ |
||||
|
||||
Data data; |
||||
|
||||
setMouseCallback( "Train svmsgd", onMouse, &data ); |
||||
waitKey(); |
||||
|
||||
|
||||
|
||||
|
||||
return 0; |
||||
} |
Loading…
Reference in new issue