diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index 534820645b..eab88e0d64 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -505,6 +505,14 @@ public: The static method creates empty %KNearest classifier. It should be then trained using StatModel::train method. */ CV_WRAP static Ptr create(); + /** @brief Loads and creates a serialized knearest from a file + * + * Use KNearest::save to serialize and store an KNearest to disk. + * Load the KNearest from this file again, by calling this function with the path to the file. + * + * @param filepath path to serialized KNearest + */ + CV_WRAP static Ptr load(const String& filepath); }; /****************************************************************************************\ diff --git a/modules/ml/misc/python/test/test_knearest.py b/modules/ml/misc/python/test/test_knearest.py new file mode 100644 index 0000000000..8ae0be5f73 --- /dev/null +++ b/modules/ml/misc/python/test/test_knearest.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python +import cv2 as cv + +from tests_common import NewOpenCVTests + +class knearest_test(NewOpenCVTests): + def test_load(self): + k_nearest = cv.ml.KNearest_load(self.find_file("ml/opencv_ml_knn.xml")) + self.assertFalse(k_nearest.empty()) + self.assertTrue(k_nearest.isTrained()) + +if __name__ == '__main__': + NewOpenCVTests.bootstrap() diff --git a/modules/ml/src/knearest.cpp b/modules/ml/src/knearest.cpp index cee7bdfdb0..dcc201158d 100644 --- a/modules/ml/src/knearest.cpp +++ b/modules/ml/src/knearest.cpp @@ -515,6 +515,17 @@ Ptr KNearest::create() return makePtr(); } +Ptr KNearest::load(const String& filepath) +{ + FileStorage fs; + fs.open(filepath, FileStorage::READ); + + Ptr knearest = makePtr(); + + ((KNearestImpl*)knearest.get())->read(fs.getFirstTopLevelNode()); + return knearest; +} + } }