diff --git a/modules/ml/src/svm.cpp b/modules/ml/src/svm.cpp index f71730a81c..3905f57b50 100644 --- a/modules/ml/src/svm.cpp +++ b/modules/ml/src/svm.cpp @@ -1551,6 +1551,8 @@ void CvSVM::optimize_linear_svm() return; int var_count = get_var_count(); + cv::AutoBuffer vbuf; + double* v = vbuf; int sample_size = (int)(var_count*sizeof(sv[0][0])); float** new_sv = (float**)cvMemStorageAlloc(storage, df_count*sizeof(new_sv[0])); @@ -1558,15 +1560,17 @@ void CvSVM::optimize_linear_svm() { new_sv[i] = (float*)cvMemStorageAlloc(storage, sample_size); float* dst = new_sv[i]; - memset(dst, 0, sample_size); + memset(v, 0, var_count*sizeof(v[0])); int j, k, sv_count = df[i].sv_count; for( j = 0; j < sv_count; j++ ) { const float* src = class_count > 1 && df[i].sv_index ? sv[df[i].sv_index[j]] : sv[j]; double a = df[i].alpha[j]; for( k = 0; k < var_count; k++ ) - dst[k] = (float)(dst[k] + src[k]*a); + v[k] += src[k]*a; } + for( k = 0; k < var_count; k++ ) + dst[k] = (float)v[k]; df[i].sv_count = 1; df[i].alpha[0] = 1.; if( class_count > 1 && df[i].sv_index ) @@ -2570,7 +2574,8 @@ void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node ) CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader ); } - optimize_linear_svm(); + if( cvReadIntByName(fs, svm_node, "optimize_linear", 1) != 0 ) + optimize_linear_svm(); create_kernel(); __END__; diff --git a/modules/ml/test/test_save_load.cpp b/modules/ml/test/test_save_load.cpp index fde5410ca2..6ce54a9edc 100644 --- a/modules/ml/test/test_save_load.cpp +++ b/modules/ml/test/test_save_load.cpp @@ -133,4 +133,32 @@ TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); } TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); } TEST(ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); } + +TEST(DISABLED_ML_SVM, linear_save_load) +{ + CvSVM svm1, svm2, svm3; + svm1.load("SVM45_X_38-1.xml"); + svm2.load("SVM45_X_38-2.xml"); + string tname = tempfile("a.xml"); + svm2.save(tname.c_str()); + svm3.load(tname.c_str()); + + ASSERT_EQ(svm1.get_var_count(), svm2.get_var_count()); + ASSERT_EQ(svm1.get_var_count(), svm3.get_var_count()); + + int m = 10000, n = svm1.get_var_count(); + Mat samples(m, n, CV_32F), r1, r2, r3; + randu(samples, 0., 1.); + + svm1.predict(samples, r1); + svm2.predict(samples, r2); + svm3.predict(samples, r3); + + double eps = 1e-4; + EXPECT_LE(norm(r1, r2, NORM_INF), eps); + EXPECT_LE(norm(r1, r3, NORM_INF), eps); + + remove(tname.c_str()); +} + /* End of file. */