|
|
|
@ -172,7 +172,8 @@ void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier) |
|
|
|
|
average = Mat(1, featuresCount, samples.type()); |
|
|
|
|
for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++) |
|
|
|
|
{ |
|
|
|
|
average.at<float>(featureIndex) = mean(samples.col(featureIndex))[0]; |
|
|
|
|
Scalar scalAverage = mean(samples.col(featureIndex))[0]; |
|
|
|
|
average.at<float>(featureIndex) = static_cast<float>(scalAverage[0]); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++) |
|
|
|
@ -182,7 +183,7 @@ void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier) |
|
|
|
|
|
|
|
|
|
double normValue = norm(samples); |
|
|
|
|
|
|
|
|
|
multiplier = sqrt(samples.total()) / normValue; |
|
|
|
|
multiplier = static_cast<float>(sqrt(samples.total()) / normValue); |
|
|
|
|
|
|
|
|
|
samples *= multiplier; |
|
|
|
|
} |
|
|
|
@ -228,11 +229,11 @@ float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const |
|
|
|
|
for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++) |
|
|
|
|
{ |
|
|
|
|
Mat currentSample = trainSamples.row(samplesIndex); |
|
|
|
|
float dotProduct = currentSample.dot(weights_); |
|
|
|
|
float dotProduct = static_cast<float>(currentSample.dot(weights_)); |
|
|
|
|
|
|
|
|
|
bool firstClass = isFirstClass(trainResponses.at<float>(samplesIndex)); |
|
|
|
|
int index = firstClass ? 0:1; |
|
|
|
|
float signToMul = firstClass ? 1 : -1; |
|
|
|
|
int index = firstClass ? 0 : 1; |
|
|
|
|
float signToMul = firstClass ? 1.f : -1.f; |
|
|
|
|
float curDistance = dotProduct * signToMul; |
|
|
|
|
|
|
|
|
|
if (curDistance < distanceToClasses[index]) |
|
|
|
@ -263,7 +264,7 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int) |
|
|
|
|
if ( areEmpty.first || areEmpty.second ) |
|
|
|
|
{ |
|
|
|
|
weights_ = Mat::zeros(1, featureCount, CV_32F); |
|
|
|
|
shift_ = areEmpty.first ? -1 : 1; |
|
|
|
|
shift_ = areEmpty.first ? -1.f : 1.f; |
|
|
|
|
return true; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -329,7 +330,7 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int) |
|
|
|
|
|
|
|
|
|
if (params.marginType == SOFT_MARGIN) |
|
|
|
|
{ |
|
|
|
|
shift_ = extendedWeights.at<float>(featureCount) - weights_.dot(average); |
|
|
|
|
shift_ = extendedWeights.at<float>(featureCount) - static_cast<float>(weights_.dot(average)); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
@ -363,8 +364,8 @@ float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) cons |
|
|
|
|
for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++) |
|
|
|
|
{ |
|
|
|
|
Mat currentSample = samples.row(sampleIndex); |
|
|
|
|
float criterion = currentSample.dot(weights_) + shift_; |
|
|
|
|
results.at<float>(sampleIndex) = (criterion >= 0) ? 1 : -1; |
|
|
|
|
float criterion = static_cast<float>(currentSample.dot(weights_)) + shift_; |
|
|
|
|
results.at<float>(sampleIndex) = (criterion >= 0) ? 1.f : -1.f; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return result; |
|
|
|
@ -530,9 +531,9 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType) |
|
|
|
|
params.svmsgdType = SGD; |
|
|
|
|
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN : |
|
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE; |
|
|
|
|
params.lambda = 0.0001; |
|
|
|
|
params.gamma0 = 0.05; |
|
|
|
|
params.c = 1; |
|
|
|
|
params.lambda = 0.0001f; |
|
|
|
|
params.gamma0 = 0.05f; |
|
|
|
|
params.c = 1.f; |
|
|
|
|
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001); |
|
|
|
|
break; |
|
|
|
|
|
|
|
|
@ -540,9 +541,9 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType) |
|
|
|
|
params.svmsgdType = ASGD; |
|
|
|
|
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN : |
|
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE; |
|
|
|
|
params.lambda = 0.00001; |
|
|
|
|
params.gamma0 = 0.05; |
|
|
|
|
params.c = 0.75; |
|
|
|
|
params.lambda = 0.00001f; |
|
|
|
|
params.gamma0 = 0.05f; |
|
|
|
|
params.c = 0.75f; |
|
|
|
|
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001); |
|
|
|
|
break; |
|
|
|
|
|
|
|
|
|