Deleted illegal type values.

pull/6096/head
Marina Noskova 9 years ago
parent ff54952769
commit 02cd8cf039
  1. 2
      modules/ml/include/opencv2/ml.hpp
  2. 92
      modules/ml/src/svmsgd.cpp

@ -1588,7 +1588,6 @@ public:
ASGD is often the preferable choice. */
enum SvmsgdType
{
ILLEGAL_SVMSGD_TYPE,
SGD, //!< Stochastic Gradient Descent
ASGD //!< Average Stochastic Gradient Descent
};
@ -1596,7 +1595,6 @@ public:
/** Margin type.*/
enum MarginType
{
ILLEGAL_MARGIN_TYPE,
SOFT_MARGIN, //!< General case, suits to the case of non-linearly separable sets, allows outliers.
HARD_MARGIN //!< More accurate for the case of linearly separable sets.
};

@ -89,14 +89,8 @@ public:
virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);
virtual int getSvmsgdType() const;
virtual void setSvmsgdType(int svmsgdType);
virtual int getMarginType() const;
virtual void setMarginType(int marginType);
CV_IMPL_PROPERTY(int, SvmsgdType, params.svmsgdType)
CV_IMPL_PROPERTY(int, MarginType, params.marginType)
CV_IMPL_PROPERTY(float, Lambda, params.lambda)
CV_IMPL_PROPERTY(float, Gamma0, params.gamma0)
CV_IMPL_PROPERTY(float, C, params.c)
@ -132,8 +126,8 @@ private:
float gamma0; //learning rate
float c;
TermCriteria termCrit;
SvmsgdType svmsgdType;
MarginType marginType;
int svmsgdType;
int marginType;
};
SVMSGDParams params;
@ -148,9 +142,9 @@ std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
{
CV_Assert(responses.cols == 1 || responses.rows == 1);
std::pair<bool,bool> emptyInClasses(true, true);
int limit_index = responses.rows;
int limitIndex = responses.rows;
for(int index = 0; index < limit_index; index++)
for(int index = 0; index < limitIndex; index++)
{
if (isPositive(responses.at<float>(index)))
emptyInClasses.first = false;
@ -276,9 +270,9 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
int extendedTrainSamplesCount = extendedTrainSamples.rows;
int extendedFeatureCount = extendedTrainSamples.cols;
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); // Initialize extendedWeights vector with zeros
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); //extendedWeights vector for calculating terminal criterion
Mat averageExtendedWeights; //average extendedWeights vector for ASGD model
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
Mat averageExtendedWeights;
if (params.svmsgdType == ASGD)
{
averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
@ -407,10 +401,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
case ASGD:
SvmsgdTypeStr = "ASGD";
break;
case ILLEGAL_SVMSGD_TYPE:
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
default:
std::cout << "params.svmsgdType isn't initialized" << std::endl;
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
}
fs << "svmsgdType" << SvmsgdTypeStr;
@ -425,10 +417,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
case HARD_MARGIN:
marginTypeStr = "HARD_MARGIN";
break;
case ILLEGAL_MARGIN_TYPE:
marginTypeStr = format("Unknown_%d", params.marginType);
default:
std::cout << "params.marginType isn't initialized" << std::endl;
marginTypeStr = format("Unknown_%d", params.marginType);
}
fs << "marginType" << marginTypeStr;
@ -458,21 +448,21 @@ void SVMSGDImpl::read(const FileNode& fn)
void SVMSGDImpl::readParams( const FileNode& fn )
{
String svmsgdTypeStr = (String)fn["svmsgdType"];
SvmsgdType svmsgdType =
int svmsgdType =
svmsgdTypeStr == "SGD" ? SGD :
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_SVMSGD_TYPE;
svmsgdTypeStr == "ASGD" ? ASGD : -1;
if( svmsgdType == ILLEGAL_SVMSGD_TYPE )
if( svmsgdType < 0 )
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
params.svmsgdType = svmsgdType;
String marginTypeStr = (String)fn["marginType"];
MarginType marginType =
int marginType =
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
if( marginType == ILLEGAL_MARGIN_TYPE )
if( marginType < 0 )
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
params.marginType = marginType;
@ -510,8 +500,8 @@ SVMSGDImpl::SVMSGDImpl()
{
clear();
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
params.marginType = ILLEGAL_MARGIN_TYPE;
params.svmsgdType = -1;
params.marginType = -1;
// Parameters for learning
params.lambda = 0; // regularization
@ -529,7 +519,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
case SGD:
params.svmsgdType = SGD;
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
params.lambda = 0.0001f;
params.gamma0 = 0.05f;
params.c = 1.f;
@ -539,7 +529,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
case ASGD:
params.svmsgdType = ASGD;
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
params.lambda = 0.00001f;
params.gamma0 = 0.05f;
params.c = 0.75f;
@ -550,45 +540,5 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
}
}
void SVMSGDImpl::setSvmsgdType(int type)
{
switch (type)
{
case SGD:
params.svmsgdType = SGD;
break;
case ASGD:
params.svmsgdType = ASGD;
break;
default:
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
}
}
int SVMSGDImpl::getSvmsgdType() const
{
return params.svmsgdType;
}
void SVMSGDImpl::setMarginType(int type)
{
switch (type)
{
case HARD_MARGIN:
params.marginType = HARD_MARGIN;
break;
case SOFT_MARGIN:
params.marginType = SOFT_MARGIN;
break;
default:
params.marginType = ILLEGAL_MARGIN_TYPE;
}
}
int SVMSGDImpl::getMarginType() const
{
return params.marginType;
}
} //ml
} //cv

Loading…
Cancel
Save