|
|
|
@ -89,14 +89,8 @@ public: |
|
|
|
|
|
|
|
|
|
virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN); |
|
|
|
|
|
|
|
|
|
virtual int getSvmsgdType() const; |
|
|
|
|
|
|
|
|
|
virtual void setSvmsgdType(int svmsgdType); |
|
|
|
|
|
|
|
|
|
virtual int getMarginType() const; |
|
|
|
|
|
|
|
|
|
virtual void setMarginType(int marginType); |
|
|
|
|
|
|
|
|
|
CV_IMPL_PROPERTY(int, SvmsgdType, params.svmsgdType) |
|
|
|
|
CV_IMPL_PROPERTY(int, MarginType, params.marginType) |
|
|
|
|
CV_IMPL_PROPERTY(float, Lambda, params.lambda) |
|
|
|
|
CV_IMPL_PROPERTY(float, Gamma0, params.gamma0) |
|
|
|
|
CV_IMPL_PROPERTY(float, C, params.c) |
|
|
|
@ -132,8 +126,8 @@ private: |
|
|
|
|
float gamma0; //learning rate
|
|
|
|
|
float c; |
|
|
|
|
TermCriteria termCrit; |
|
|
|
|
SvmsgdType svmsgdType; |
|
|
|
|
MarginType marginType; |
|
|
|
|
int svmsgdType; |
|
|
|
|
int marginType; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
SVMSGDParams params; |
|
|
|
@ -148,9 +142,9 @@ std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses) |
|
|
|
|
{ |
|
|
|
|
CV_Assert(responses.cols == 1 || responses.rows == 1); |
|
|
|
|
std::pair<bool,bool> emptyInClasses(true, true); |
|
|
|
|
int limit_index = responses.rows; |
|
|
|
|
int limitIndex = responses.rows; |
|
|
|
|
|
|
|
|
|
for(int index = 0; index < limit_index; index++) |
|
|
|
|
for(int index = 0; index < limitIndex; index++) |
|
|
|
|
{ |
|
|
|
|
if (isPositive(responses.at<float>(index))) |
|
|
|
|
emptyInClasses.first = false; |
|
|
|
@ -276,9 +270,9 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int) |
|
|
|
|
int extendedTrainSamplesCount = extendedTrainSamples.rows; |
|
|
|
|
int extendedFeatureCount = extendedTrainSamples.cols; |
|
|
|
|
|
|
|
|
|
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); // Initialize extendedWeights vector with zeros
|
|
|
|
|
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); //extendedWeights vector for calculating terminal criterion
|
|
|
|
|
Mat averageExtendedWeights; //average extendedWeights vector for ASGD model
|
|
|
|
|
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); |
|
|
|
|
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); |
|
|
|
|
Mat averageExtendedWeights; |
|
|
|
|
if (params.svmsgdType == ASGD) |
|
|
|
|
{ |
|
|
|
|
averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); |
|
|
|
@ -407,10 +401,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const |
|
|
|
|
case ASGD: |
|
|
|
|
SvmsgdTypeStr = "ASGD"; |
|
|
|
|
break; |
|
|
|
|
case ILLEGAL_SVMSGD_TYPE: |
|
|
|
|
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType); |
|
|
|
|
default: |
|
|
|
|
std::cout << "params.svmsgdType isn't initialized" << std::endl; |
|
|
|
|
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fs << "svmsgdType" << SvmsgdTypeStr; |
|
|
|
@ -425,10 +417,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const |
|
|
|
|
case HARD_MARGIN: |
|
|
|
|
marginTypeStr = "HARD_MARGIN"; |
|
|
|
|
break; |
|
|
|
|
case ILLEGAL_MARGIN_TYPE: |
|
|
|
|
marginTypeStr = format("Unknown_%d", params.marginType); |
|
|
|
|
default: |
|
|
|
|
std::cout << "params.marginType isn't initialized" << std::endl; |
|
|
|
|
marginTypeStr = format("Unknown_%d", params.marginType); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fs << "marginType" << marginTypeStr; |
|
|
|
@ -458,21 +448,21 @@ void SVMSGDImpl::read(const FileNode& fn) |
|
|
|
|
void SVMSGDImpl::readParams( const FileNode& fn ) |
|
|
|
|
{ |
|
|
|
|
String svmsgdTypeStr = (String)fn["svmsgdType"]; |
|
|
|
|
SvmsgdType svmsgdType = |
|
|
|
|
int svmsgdType = |
|
|
|
|
svmsgdTypeStr == "SGD" ? SGD : |
|
|
|
|
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_SVMSGD_TYPE; |
|
|
|
|
svmsgdTypeStr == "ASGD" ? ASGD : -1; |
|
|
|
|
|
|
|
|
|
if( svmsgdType == ILLEGAL_SVMSGD_TYPE ) |
|
|
|
|
if( svmsgdType < 0 ) |
|
|
|
|
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" ); |
|
|
|
|
|
|
|
|
|
params.svmsgdType = svmsgdType; |
|
|
|
|
|
|
|
|
|
String marginTypeStr = (String)fn["marginType"]; |
|
|
|
|
MarginType marginType = |
|
|
|
|
int marginType = |
|
|
|
|
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN : |
|
|
|
|
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE; |
|
|
|
|
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1; |
|
|
|
|
|
|
|
|
|
if( marginType == ILLEGAL_MARGIN_TYPE ) |
|
|
|
|
if( marginType < 0 ) |
|
|
|
|
CV_Error( CV_StsParseError, "Missing or invalid margin type" ); |
|
|
|
|
|
|
|
|
|
params.marginType = marginType; |
|
|
|
@ -510,8 +500,8 @@ SVMSGDImpl::SVMSGDImpl() |
|
|
|
|
{ |
|
|
|
|
clear(); |
|
|
|
|
|
|
|
|
|
params.svmsgdType = ILLEGAL_SVMSGD_TYPE; |
|
|
|
|
params.marginType = ILLEGAL_MARGIN_TYPE; |
|
|
|
|
params.svmsgdType = -1; |
|
|
|
|
params.marginType = -1; |
|
|
|
|
|
|
|
|
|
// Parameters for learning
|
|
|
|
|
params.lambda = 0; // regularization
|
|
|
|
@ -529,7 +519,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType) |
|
|
|
|
case SGD: |
|
|
|
|
params.svmsgdType = SGD; |
|
|
|
|
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN : |
|
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE; |
|
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1; |
|
|
|
|
params.lambda = 0.0001f; |
|
|
|
|
params.gamma0 = 0.05f; |
|
|
|
|
params.c = 1.f; |
|
|
|
@ -539,7 +529,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType) |
|
|
|
|
case ASGD: |
|
|
|
|
params.svmsgdType = ASGD; |
|
|
|
|
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN : |
|
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE; |
|
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1; |
|
|
|
|
params.lambda = 0.00001f; |
|
|
|
|
params.gamma0 = 0.05f; |
|
|
|
|
params.c = 0.75f; |
|
|
|
@ -550,45 +540,5 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType) |
|
|
|
|
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" ); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void SVMSGDImpl::setSvmsgdType(int type) |
|
|
|
|
{ |
|
|
|
|
switch (type) |
|
|
|
|
{ |
|
|
|
|
case SGD: |
|
|
|
|
params.svmsgdType = SGD; |
|
|
|
|
break; |
|
|
|
|
case ASGD: |
|
|
|
|
params.svmsgdType = ASGD; |
|
|
|
|
break; |
|
|
|
|
default: |
|
|
|
|
params.svmsgdType = ILLEGAL_SVMSGD_TYPE; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
int SVMSGDImpl::getSvmsgdType() const |
|
|
|
|
{ |
|
|
|
|
return params.svmsgdType; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void SVMSGDImpl::setMarginType(int type) |
|
|
|
|
{ |
|
|
|
|
switch (type) |
|
|
|
|
{ |
|
|
|
|
case HARD_MARGIN: |
|
|
|
|
params.marginType = HARD_MARGIN; |
|
|
|
|
break; |
|
|
|
|
case SOFT_MARGIN: |
|
|
|
|
params.marginType = SOFT_MARGIN; |
|
|
|
|
break; |
|
|
|
|
default: |
|
|
|
|
params.marginType = ILLEGAL_MARGIN_TYPE; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
int SVMSGDImpl::getMarginType() const |
|
|
|
|
{ |
|
|
|
|
return params.marginType; |
|
|
|
|
} |
|
|
|
|
} //ml
|
|
|
|
|
} //cv
|
|
|
|
|