@@ -11,7 +11,7 @@
 // For Open Source Computer Vision Library
 //
 // Copyright (C) 2000, Intel Corporation, all rights reserved.
-// Copyright (C) 2014, Itseez Inc, all rights reserved.
+// Copyright (C) 2016, Itseez Inc, all rights reserved.
 // Third party copyrights are property of their respective owners.
 //
 // Redistribution and use in source and binary forms, with or without modification,
@@ -103,7 +103,7 @@ public:
     CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
 
 private:
-    void updateWeights(InputArray sample, bool isFirstClass, float gamma, Mat &weights);
+    void updateWeights(InputArray sample, bool isPositive, float gamma, Mat &weights);
 
     std::pair<bool,bool> areClassesEmpty(Mat responses);
 
@@ -111,7 +111,7 @@ private:
 
     void readParams( const FileNode &fn );
 
-    static inline bool isFirstClass(float val) { return val > 0; }
+    static inline bool isPositive(float val) { return val > 0; }
 
     static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);
 
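
The renamed predicate makes the labeling convention explicit: any strictly positive response belongs to the positive class, everything else (including zero) to the negative one. A minimal standalone sketch of that convention, with illustrative response values that are not taken from the OpenCV sources:

    #include <opencv2/core.hpp>

    // Mirrors the renamed predicate from the diff: strictly positive
    // responses are the positive class, everything else is negative.
    static inline bool isPositive(float val) { return val > 0; }

    int main()
    {
        cv::Mat responses = (cv::Mat_<float>(4, 1) << 1.f, -1.f, 1.f, -1.f);
        int positives = 0;
        for (int i = 0; i < responses.rows; i++)
            if (isPositive(responses.at<float>(i)))
                positives++;
        CV_Assert(positives == 2);
        return 0;
    }
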
@@ -152,7 +152,7 @@ std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
 
     for(int index = 0; index < limit_index; index++)
     {
-        if (isFirstClass(responses.at<float>(index)))
+        if (isPositive(responses.at<float>(index)))
             emptyInClasses.first = false;
         else
             emptyInClasses.second = false;
@@ -172,7 +172,7 @@ void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
     average = Mat(1, featuresCount, samples.type());
     for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
     {
-        Scalar scalAverage = mean(samples.col(featureIndex))[0];
+        Scalar scalAverage = mean(samples.col(featureIndex));
         average.at<float>(featureIndex) = static_cast<float>(scalAverage[0]);
     }
 
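
The change in this hunk drops a redundant `[0]`: `cv::mean` already returns a `cv::Scalar`, and indexing it only to rebuild another `Scalar` from the resulting `double` obscured the intent without changing the value. A rough standalone sketch of the per-column mean computation, assuming single-channel `CV_32F` samples; `columnMeans` is a hypothetical helper, not part of the OpenCV sources:

    #include <opencv2/core.hpp>

    // Hypothetical helper illustrating the per-column mean used by
    // SVMSGDImpl::normalizeSamples(); for illustration only.
    static cv::Mat columnMeans(const cv::Mat &samples)
    {
        CV_Assert(samples.type() == CV_32F);
        cv::Mat average(1, samples.cols, CV_32F);
        for (int c = 0; c < samples.cols; c++)
        {
            // cv::mean returns a cv::Scalar; channel 0 holds the mean of a
            // single-channel column, so no extra indexing is needed.
            cv::Scalar m = cv::mean(samples.col(c));
            average.at<float>(c) = static_cast<float>(m[0]);
        }
        return average;
    }
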
@@ -190,13 +190,13 @@ void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
 
 void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
 {
-    Mat normalisedTrainSamples = trainSamples.clone();
-    int samplesCount = normalisedTrainSamples.rows;
+    Mat normalizedTrainSamples = trainSamples.clone();
+    int samplesCount = normalizedTrainSamples.rows;
 
-    normalizeSamples(normalisedTrainSamples, average, multiplier);
+    normalizeSamples(normalizedTrainSamples, average, multiplier);
 
     Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
-    cv::hconcat(normalisedTrainSamples, onesCol, extendedTrainSamples);
+    cv::hconcat(normalizedTrainSamples, onesCol, extendedTrainSamples);
 }
 
 void SVMSGDImpl::updateWeights(InputArray _sample, bool firstClass, float gamma, Mat& weights)
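
Besides the spelling fix, this hunk shows how the training samples are "extended": `cv::hconcat` appends a column of ones, so the bias term can be learned as just another weight on a constant feature. A minimal sketch with illustrative data values:

    #include <opencv2/core.hpp>

    int main()
    {
        // Three 2-feature samples; the values are arbitrary illustration data.
        cv::Mat samples = (cv::Mat_<float>(3, 2) << 1.f, 2.f,
                                                    3.f, 4.f,
                                                    5.f, 6.f);
        cv::Mat onesCol = cv::Mat::ones(samples.rows, 1, CV_32F);

        cv::Mat extended;
        cv::hconcat(samples, onesCol, extended);  // 3 x 3; last column is all ones

        CV_Assert(extended.cols == samples.cols + 1);
        return 0;
    }
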
@@ -231,7 +231,7 @@ float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
         Mat currentSample = trainSamples.row(samplesIndex);
         float dotProduct = static_cast<float>(currentSample.dot(weights_));
 
-        bool firstClass = isFirstClass(trainResponses.at<float>(samplesIndex));
+        bool firstClass = isPositive(trainResponses.at<float>(samplesIndex));
         int index = firstClass ? 0 : 1;
         float signToMul = firstClass ? 1.f : -1.f;
         float curDistance = dotProduct * signToMul;
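
In `calcShift` the label enters only through its sign: multiplying the dot product by +1 for positive samples and -1 for negative ones turns both cases into a signed distance that grows as the sample moves further onto its own side of the hyperplane. A minimal sketch with hypothetical weights and a single sample:

    #include <opencv2/core.hpp>

    int main()
    {
        // Hypothetical weights and one sample, for illustration only.
        cv::Mat weights = (cv::Mat_<float>(1, 3) << 0.5f, -0.25f, 0.1f);
        cv::Mat sample  = (cv::Mat_<float>(1, 3) << 1.f,   2.f,   1.f);

        float dotProduct = static_cast<float>(sample.dot(weights));
        bool positive = true;                     // pretend the response was +1
        float signToMul = positive ? 1.f : -1.f;
        float distance = dotProduct * signToMul;  // signed, label-aware distance

        CV_Assert(distance == dotProduct);        // sign is +1 for positive labels
        return 0;
    }
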
@@ -297,11 +297,10 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
         int randomNumber = rng.uniform(0, extendedTrainSamplesCount); //generate sample number
 
         Mat currentSample = extendedTrainSamples.row(randomNumber);
-        bool firstClass = isFirstClass(trainResponses.at<float>(randomNumber));
 
         float gamma = params.gamma0 * std::pow((1 + params.lambda * params.gamma0 * (float)iter), (-params.c)); //update gamma
 
-        updateWeights( currentSample, firstClass, gamma, extendedWeights );
+        updateWeights( currentSample, isPositive(trainResponses.at<float>(randomNumber)), gamma, extendedWeights );
 
         //average weights (only for ASGD model)
         if (params.svmsgdType == ASGD)
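
The step size used in this training loop follows the decay schedule gamma = gamma0 * (1 + lambda * gamma0 * iter)^(-c), so later iterations take smaller steps. A standalone sketch of that schedule with hypothetical parameter values (the real ones come from the model's parameters):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        // Hypothetical values, for illustration only: gamma0 is the initial
        // step size, lambda the regularization strength, c the decay power.
        const float gamma0 = 0.05f;
        const float lambda = 0.0001f;
        const float c      = 0.75f;

        for (int iter = 0; iter < 5; iter++)
        {
            float gamma = gamma0 * std::pow(1 + lambda * gamma0 * (float)iter, -c);
            std::printf("iter %d: step size %f\n", iter, gamma);
        }
        return 0;
    }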