|
|
|
@ -97,7 +97,7 @@ public: |
|
|
|
|
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit) |
|
|
|
|
|
|
|
|
|
private: |
|
|
|
|
void updateWeights(InputArray sample, bool isPositive, float stepSize, Mat &weights); |
|
|
|
|
void updateWeights(InputArray sample, bool positive, float stepSize, Mat &weights); |
|
|
|
|
|
|
|
|
|
void writeParams( FileStorage &fs ) const; |
|
|
|
|
|
|
|
|
@ -111,8 +111,6 @@ private: |
|
|
|
|
|
|
|
|
|
static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Vector with SVM weights
|
|
|
|
|
Mat weights_; |
|
|
|
|
float shift_; |
|
|
|
@ -263,11 +261,12 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int) |
|
|
|
|
|
|
|
|
|
RNG rng(0); |
|
|
|
|
|
|
|
|
|
CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS) && (trainResponses.type() == CV_32FC1)); |
|
|
|
|
CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS); |
|
|
|
|
int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX; |
|
|
|
|
double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0; |
|
|
|
|
|
|
|
|
|
double err = DBL_MAX; |
|
|
|
|
CV_Assert (trainResponses.type() == CV_32FC1); |
|
|
|
|
// Stochastic gradient descent SVM
|
|
|
|
|
for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++) |
|
|
|
|
{ |
|
|
|
@ -288,8 +287,8 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int) |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
err = norm(extendedWeights - previousWeights); |
|
|
|
|
extendedWeights.copyTo(previousWeights); |
|
|
|
|
err = norm(extendedWeights - previousWeights); |
|
|
|
|
extendedWeights.copyTo(previousWeights); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -316,7 +315,6 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int) |
|
|
|
|
return true; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const |
|
|
|
|
{ |
|
|
|
|
float result = 0; |
|
|
|
@ -417,17 +415,6 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const |
|
|
|
|
fs << "iterations" << params.termCrit.maxCount; |
|
|
|
|
fs << "}"; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void SVMSGDImpl::read(const FileNode& fn) |
|
|
|
|
{ |
|
|
|
|
clear(); |
|
|
|
|
|
|
|
|
|
readParams(fn); |
|
|
|
|
|
|
|
|
|
fn["weights"] >> weights_; |
|
|
|
|
fn["shift"] >> shift_; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void SVMSGDImpl::readParams( const FileNode& fn ) |
|
|
|
|
{ |
|
|
|
|
String svmsgdTypeStr = (String)fn["svmsgdType"]; |
|
|
|
@ -443,7 +430,7 @@ void SVMSGDImpl::readParams( const FileNode& fn ) |
|
|
|
|
String marginTypeStr = (String)fn["marginType"]; |
|
|
|
|
int marginType = |
|
|
|
|
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN : |
|
|
|
|
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1; |
|
|
|
|
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1; |
|
|
|
|
|
|
|
|
|
if( marginType < 0 ) |
|
|
|
|
CV_Error( CV_StsParseError, "Missing or invalid margin type" ); |
|
|
|
@ -460,16 +447,22 @@ void SVMSGDImpl::readParams( const FileNode& fn ) |
|
|
|
|
params.stepDecreasingPower = (float)fn["stepDecreasingPower"]; |
|
|
|
|
|
|
|
|
|
FileNode tcnode = fn["term_criteria"]; |
|
|
|
|
if( !tcnode.empty() ) |
|
|
|
|
{ |
|
|
|
|
params.termCrit.epsilon = (double)tcnode["epsilon"]; |
|
|
|
|
params.termCrit.maxCount = (int)tcnode["iterations"]; |
|
|
|
|
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) + |
|
|
|
|
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 100000, FLT_EPSILON ); |
|
|
|
|
CV_Assert(!tcnode.empty()); |
|
|
|
|
params.termCrit.epsilon = (double)tcnode["epsilon"]; |
|
|
|
|
params.termCrit.maxCount = (int)tcnode["iterations"]; |
|
|
|
|
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) + |
|
|
|
|
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0); |
|
|
|
|
CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void SVMSGDImpl::read(const FileNode& fn) |
|
|
|
|
{ |
|
|
|
|
clear(); |
|
|
|
|
|
|
|
|
|
readParams(fn); |
|
|
|
|
|
|
|
|
|
fn["weights"] >> weights_; |
|
|
|
|
fn["shift"] >> shift_; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void SVMSGDImpl::clear() |
|
|
|
@ -492,7 +485,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType) |
|
|
|
|
case SGD: |
|
|
|
|
params.svmsgdType = SGD; |
|
|
|
|
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN : |
|
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1; |
|
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1; |
|
|
|
|
params.marginRegularization = 0.0001f; |
|
|
|
|
params.initialStepSize = 0.05f; |
|
|
|
|
params.stepDecreasingPower = 1.f; |
|
|
|
@ -502,7 +495,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType) |
|
|
|
|
case ASGD: |
|
|
|
|
params.svmsgdType = ASGD; |
|
|
|
|
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN : |
|
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1; |
|
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1; |
|
|
|
|
params.marginRegularization = 0.00001f; |
|
|
|
|
params.initialStepSize = 0.05f; |
|
|
|
|
params.stepDecreasingPower = 0.75f; |
|
|
|
|