From 0d706f679647ee27c1f3e9b8ef60fb0d4c5204ba Mon Sep 17 00:00:00 2001 From: Deanna Hood Date: Sat, 18 Apr 2015 21:32:29 -0400 Subject: [PATCH] Return uncompressed support vectors for getSupportVectors on linear SVM (Bug #4096) --- modules/ml/include/opencv2/ml.hpp | 10 +++- modules/ml/src/svm.cpp | 54 +++++++++++++++++-- modules/ml/test/test_svmtrainauto.cpp | 48 +++++++++++++++++ .../introduction_to_svm.cpp | 2 +- 4 files changed, 109 insertions(+), 5 deletions(-) diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index 715cbd998a..862f3f950c 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -675,11 +675,19 @@ public: /** @brief Retrieves all the support vectors - The method returns all the support vector as floating-point matrix, where support vectors are + The method returns all the support vectors as a floating-point matrix, where support vectors are stored as matrix rows. */ CV_WRAP virtual Mat getSupportVectors() const = 0; + /** @brief Retrieves all the uncompressed support vectors of a linear %SVM + + The method returns all the uncompressed support vectors of a linear %SVM that the compressed + support vector, used for prediction, was derived from. They are returned in a floating-point + matrix, where the support vectors are stored as matrix rows. + */ + CV_WRAP Mat getUncompressedSupportVectors() const; + /** @brief Retrieves the decision function @param i the index of the decision function. If the problem solved is regression, 1-class or diff --git a/modules/ml/src/svm.cpp b/modules/ml/src/svm.cpp index 0fd73a3891..757bb7a171 100644 --- a/modules/ml/src/svm.cpp +++ b/modules/ml/src/svm.cpp @@ -1241,6 +1241,12 @@ public: df_alpha.clear(); df_index.clear(); sv.release(); + uncompressed_sv.release(); + } + + Mat getUncompressedSupportVectors_() const + { + return uncompressed_sv; } Mat getSupportVectors() const @@ -1538,6 +1544,7 @@ public: } optimize_linear_svm(); + return true; } @@ -1588,6 +1595,7 @@ public: setRangeVector(df_index, df_count); df_alpha.assign(df_count, 1.); + sv.copyTo(uncompressed_sv); std::swap(sv, new_sv); std::swap(decision_func, new_df); } @@ -2056,6 +2064,21 @@ public: } fs << "]"; + if ( !uncompressed_sv.empty() ) + { + // write the joint collection of uncompressed support vectors + int uncompressed_sv_total = uncompressed_sv.rows; + fs << "uncompressed_sv_total" << uncompressed_sv_total; + fs << "uncompressed_support_vectors" << "["; + for( i = 0; i < uncompressed_sv_total; i++ ) + { + fs << "[:"; + fs.writeRaw("f", uncompressed_sv.ptr(i), uncompressed_sv.cols*uncompressed_sv.elemSize()); + fs << "]"; + } + fs << "]"; + } + // write decision functions int df_count = (int)decision_func.size(); @@ -2096,7 +2119,7 @@ public: svm_type_str == "NU_SVR" ? NU_SVR : -1; if( svmType < 0 ) - CV_Error( CV_StsParseError, "Missing of invalid SVM type" ); + CV_Error( CV_StsParseError, "Missing or invalid SVM type" ); FileNode kernel_node = fn["kernel"]; if( kernel_node.empty() ) @@ -2168,14 +2191,31 @@ public: FileNode sv_node = fn["support_vectors"]; CV_Assert((int)sv_node.size() == sv_total); - sv.create(sv_total, var_count, CV_32F); + sv.create(sv_total, var_count, CV_32F); FileNodeIterator sv_it = sv_node.begin(); for( i = 0; i < sv_total; i++, ++sv_it ) { (*sv_it).readRaw("f", sv.ptr(i), var_count*sv.elemSize()); } + int uncompressed_sv_total = (int)fn["uncompressed_sv_total"]; + + if( uncompressed_sv_total > 0 ) + { + // read uncompressed support vectors + FileNode uncompressed_sv_node = fn["uncompressed_support_vectors"]; + + CV_Assert((int)uncompressed_sv_node.size() == uncompressed_sv_total); + uncompressed_sv.create(uncompressed_sv_total, var_count, CV_32F); + + FileNodeIterator uncompressed_sv_it = uncompressed_sv_node.begin(); + for( i = 0; i < uncompressed_sv_total; i++, ++uncompressed_sv_it ) + { + (*uncompressed_sv_it).readRaw("f", uncompressed_sv.ptr(i), var_count*uncompressed_sv.elemSize()); + } + } + // read decision functions int df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1; FileNode df_node = fn["decision_functions"]; @@ -2207,7 +2247,7 @@ public: SvmParams params; Mat class_labels; int var_count; - Mat sv; + Mat sv, uncompressed_sv; vector decision_func; vector df_alpha; vector df_index; @@ -2221,6 +2261,14 @@ Ptr SVM::create() return makePtr(); } +Mat SVM::getUncompressedSupportVectors() const +{ + const SVMImpl* this_ = dynamic_cast(this); + if(!this_) + CV_Error(Error::StsNotImplemented, "the class is not SVMImpl"); + return this_->getUncompressedSupportVectors_(); +} + } } diff --git a/modules/ml/test/test_svmtrainauto.cpp b/modules/ml/test/test_svmtrainauto.cpp index 3c4b729245..13cbe98f42 100644 --- a/modules/ml/test/test_svmtrainauto.cpp +++ b/modules/ml/test/test_svmtrainauto.cpp @@ -118,3 +118,51 @@ TEST(ML_SVM, trainAuto_regression_5369) EXPECT_EQ(0., result0); EXPECT_EQ(1., result1); } + +class CV_SVMGetSupportVectorsTest : public cvtest::BaseTest { +public: + CV_SVMGetSupportVectorsTest() {} +protected: + virtual void run( int startFrom ); +}; +void CV_SVMGetSupportVectorsTest::run(int /*startFrom*/ ) +{ + int code = cvtest::TS::OK; + + // Set up training data + int labels[4] = {1, -1, -1, -1}; + float trainingData[4][2] = { {501, 10}, {255, 10}, {501, 255}, {10, 501} }; + Mat trainingDataMat(4, 2, CV_32FC1, trainingData); + Mat labelsMat(4, 1, CV_32SC1, labels); + + Ptr svm = SVM::create(); + svm->setType(SVM::C_SVC); + svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6)); + + + // Test retrieval of SVs and compressed SVs on linear SVM + svm->setKernel(SVM::LINEAR); + svm->train(trainingDataMat, cv::ml::ROW_SAMPLE, labelsMat); + + Mat sv = svm->getSupportVectors(); + CV_Assert(sv.rows == 1); // by default compressed SV returned + sv = svm->getUncompressedSupportVectors(); + CV_Assert(sv.rows == 3); + + + // Test retrieval of SVs and compressed SVs on non-linear SVM + svm->setKernel(SVM::POLY); + svm->setDegree(2); + svm->train(trainingDataMat, cv::ml::ROW_SAMPLE, labelsMat); + + sv = svm->getSupportVectors(); + CV_Assert(sv.rows == 3); + sv = svm->getUncompressedSupportVectors(); + CV_Assert(sv.rows == 0); // inapplicable for non-linear SVMs + + + ts->set_failed_test_info(code); +} + + +TEST(ML_SVM, getSupportVectors) { CV_SVMGetSupportVectorsTest test; test.safe_run(); } diff --git a/samples/cpp/tutorial_code/ml/introduction_to_svm/introduction_to_svm.cpp b/samples/cpp/tutorial_code/ml/introduction_to_svm/introduction_to_svm.cpp index 0513e367d6..9b0d569c65 100644 --- a/samples/cpp/tutorial_code/ml/introduction_to_svm/introduction_to_svm.cpp +++ b/samples/cpp/tutorial_code/ml/introduction_to_svm/introduction_to_svm.cpp @@ -65,7 +65,7 @@ int main(int, char**) //! [show_vectors] thickness = 2; lineType = 8; - Mat sv = svm->getSupportVectors(); + Mat sv = svm->getUncompressedSupportVectors(); for (int i = 0; i < sv.rows; ++i) {