|
|
|
@ -1675,6 +1675,7 @@ public: |
|
|
|
|
Mat samples = data->getTrainSamples(); |
|
|
|
|
Mat responses; |
|
|
|
|
bool is_classification = false; |
|
|
|
|
Mat class_labels0; |
|
|
|
|
int class_count = (int)class_labels.total(); |
|
|
|
|
|
|
|
|
|
if( svmType == C_SVC || svmType == NU_SVC ) |
|
|
|
@ -1688,7 +1689,8 @@ public: |
|
|
|
|
setRangeVector(temp_class_labels, class_count); |
|
|
|
|
|
|
|
|
|
// temporarily replace class labels with 0, 1, ..., NCLASSES-1
|
|
|
|
|
Mat(temp_class_labels).copyTo(class_labels); |
|
|
|
|
class_labels0 = class_labels; |
|
|
|
|
class_labels = Mat(temp_class_labels).clone(); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
responses = data->getTrainResponses(); |
|
|
|
@ -1821,6 +1823,7 @@ public: |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
params = best_params; |
|
|
|
|
class_labels = class_labels0; |
|
|
|
|
return do_train( samples, responses ); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|