diff --git a/modules/gapi/CMakeLists.txt b/modules/gapi/CMakeLists.txt index 82f04f8611..692a1c1423 100644 --- a/modules/gapi/CMakeLists.txt +++ b/modules/gapi/CMakeLists.txt @@ -71,6 +71,7 @@ set(gapi_srcs src/api/kernels_core.cpp src/api/kernels_imgproc.cpp src/api/kernels_video.cpp + src/api/kernels_nnparsers.cpp src/api/render.cpp src/api/render_ocv.cpp src/api/ginfer.cpp @@ -105,6 +106,7 @@ set(gapi_srcs src/backends/cpu/gcpuimgproc.cpp src/backends/cpu/gcpuvideo.cpp src/backends/cpu/gcpucore.cpp + src/backends/cpu/gnnparsers.cpp # Fluid Backend (also built-in, FIXME:move away) src/backends/fluid/gfluidbuffer.cpp diff --git a/modules/gapi/include/opencv2/gapi/core.hpp b/modules/gapi/include/opencv2/gapi/core.hpp index 5b9fb5adda..e3c7b8cd3d 100644 --- a/modules/gapi/include/opencv2/gapi/core.hpp +++ b/modules/gapi/include/opencv2/gapi/core.hpp @@ -31,7 +31,7 @@ namespace core { using GMat2 = std::tuple; using GMat3 = std::tuple; // FIXME: how to avoid this? using GMat4 = std::tuple; - using GMatScalar = std::tuple; + using GMatScalar = std::tuple; G_TYPED_KERNEL(GAdd, , "org.opencv.core.math.add") { static GMatDesc outMeta(GMatDesc a, GMatDesc b, int ddepth) { @@ -501,6 +501,18 @@ namespace core { return in.withType(in.depth, in.chan).withSize(dsize); } }; + + G_TYPED_KERNEL(GSize, (GMat)>, "org.opencv.core.size") { + static GOpaqueDesc outMeta(const GMatDesc&) { + return empty_gopaque_desc(); + } + }; + + G_TYPED_KERNEL(GSizeR, (GOpaque)>, "org.opencv.core.sizeR") { + static GOpaqueDesc outMeta(const GOpaqueDesc&) { + return empty_gopaque_desc(); + } + }; } //! @addtogroup gapi_math @@ -1720,6 +1732,24 @@ GAPI_EXPORTS GMat warpAffine(const GMat& src, const Mat& M, const Size& dsize, i int borderMode = cv::BORDER_CONSTANT, const Scalar& borderValue = Scalar()); //! @} gapi_transform +/** @brief Gets dimensions from Mat. + +@note Function textual ID is "org.opencv.core.size" + +@param src Input tensor +@return Size (tensor dimensions). 
+*/ +GAPI_EXPORTS GOpaque size(const GMat& src); + +/** @overload +Gets dimensions from rectangle. + +@note Function textual ID is "org.opencv.core.sizeR" + +@param r Input rectangle. +@return Size (rectangle dimensions). +*/ +GAPI_EXPORTS GOpaque size(const GOpaque& r); } //namespace gapi } //namespace cv diff --git a/modules/gapi/include/opencv2/gapi/infer/parsers.hpp b/modules/gapi/include/opencv2/gapi/infer/parsers.hpp new file mode 100644 index 0000000000..c3488f5799 --- /dev/null +++ b/modules/gapi/include/opencv2/gapi/infer/parsers.hpp @@ -0,0 +1,125 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. +// +// Copyright (C) 2020 Intel Corporation + + +#ifndef OPENCV_GAPI_PARSERS_HPP +#define OPENCV_GAPI_PARSERS_HPP + +#include // std::tuple + +#include +#include + +namespace cv { namespace gapi { +namespace nn { +namespace parsers { + using GRects = GArray; + using GDetections = std::tuple, GArray>; + + G_TYPED_KERNEL(GParseSSDBL, , float, int)>, + "org.opencv.nn.parsers.parseSSD_BL") { + static std::tuple outMeta(const GMatDesc&, const GOpaqueDesc&, float, int) { + return std::make_tuple(empty_array_desc(), empty_array_desc()); + } + }; + + G_TYPED_KERNEL(GParseSSD, , float, bool, bool)>, + "org.opencv.nn.parsers.parseSSD") { + static GArrayDesc outMeta(const GMatDesc&, const GOpaqueDesc&, float, bool, bool) { + return empty_array_desc(); + } + }; + + G_TYPED_KERNEL(GParseYolo, , float, float, std::vector)>, + "org.opencv.nn.parsers.parseYolo") { + static std::tuple outMeta(const GMatDesc&, const GOpaqueDesc&, + float, float, const std::vector&) { + return std::make_tuple(empty_array_desc(), empty_array_desc()); + } + static const std::vector& defaultAnchors() { + static std::vector anchors { + 0.57273f, 0.677385f, 1.87446f, 2.06253f, 3.33843f, 5.47434f, 7.88282f, 3.52778f, 9.77052f, 9.16828f + }; + return 
anchors; + } + }; +} // namespace parsers +} // namespace nn + +/** @brief Parses output of SSD network. + +Extracts detection information (box, confidence, label) from SSD output and +filters it by given confidence and label. + +@note Function textual ID is "org.opencv.nn.parsers.parseSSD_BL" + +@param in Input CV_32F tensor with {1,1,N,7} dimensions. +@param inSz Size to project detected boxes to (size of the input image). +@param confidenceThreshold If confidence of the +detection is smaller than confidence threshold, detection is rejected. +@param filterLabel If provided (!= -1), only detections with +given label will get to the output. +@return a tuple with a vector of detected boxes and a vector of appropriate labels. +*/ +GAPI_EXPORTS std::tuple, GArray> parseSSD(const GMat& in, + const GOpaque& inSz, + const float confidenceThreshold = 0.5f, + const int filterLabel = -1); + +/** @overload +Extracts detection information (box, confidence) from SSD output and +filters it by given confidence and by going out of bounds. + +@note Function textual ID is "org.opencv.nn.parsers.parseSSD" + +@param in Input CV_32F tensor with {1,1,N,7} dimensions. +@param inSz Size to project detected boxes to (size of the input image). +@param confidenceThreshold If confidence of the +detection is smaller than confidence threshold, detection is rejected. +@param alignmentToSquare If provided true, bounding boxes are extended to squares. +The center of the rectangle remains unchanged, the side of the square is +the larger side of the rectangle. +@param filterOutOfBounds If provided true, out-of-frame boxes are filtered. +@return a vector of detected bounding boxes. +*/ +GAPI_EXPORTS GArray parseSSD(const GMat& in, + const GOpaque& inSz, + const float confidenceThreshold = 0.5f, + const bool alignmentToSquare = false, + const bool filterOutOfBounds = false); + +/** @brief Parses output of Yolo network. 
+ +Extracts detection information (box, confidence, label) from Yolo output, +filters it by given confidence and performs non-maximum suppression for overlapping boxes. + +@note Function textual ID is "org.opencv.nn.parsers.parseYolo" + +@param in Input CV_32F tensor with {1,13,13,N} dimensions, N should satisfy: +\f[\texttt{N} = (\texttt{num_classes} + \texttt{5}) * \texttt{5},\f] +where num_classes - a number of classes Yolo network was trained with. +@param inSz Size to project detected boxes to (size of the input image). +@param confidenceThreshold If confidence of the +detection is smaller than confidence threshold, detection is rejected. +@param nmsThreshold Non-maximum suppression threshold which controls minimum +relative box intersection area required for rejecting the box with a smaller confidence. +If 1.f, nms is not performed and no boxes are rejected. +@param anchors Anchors Yolo network was trained with. +@note The default anchor values are taken from openvinotoolkit docs: +https://docs.openvinotoolkit.org/latest/omz_models_intel_yolo_v2_tiny_vehicle_detection_0001_description_yolo_v2_tiny_vehicle_detection_0001.html#output. +@return a tuple with a vector of detected boxes and a vector of appropriate labels. 
+*/ +GAPI_EXPORTS std::tuple, GArray> parseYolo(const GMat& in, + const GOpaque& inSz, + const float confidenceThreshold = 0.5f, + const float nmsThreshold = 0.5f, + const std::vector& anchors + = nn::parsers::GParseYolo::defaultAnchors()); + +} // namespace gapi +} // namespace cv + +#endif // OPENCV_GAPI_PARSERS_HPP diff --git a/modules/gapi/perf/common/gapi_core_perf_tests.hpp b/modules/gapi/perf/common/gapi_core_perf_tests.hpp index ab6555aedd..ed954aded3 100644 --- a/modules/gapi/perf/common/gapi_core_perf_tests.hpp +++ b/modules/gapi/perf/common/gapi_core_perf_tests.hpp @@ -10,6 +10,7 @@ #include "../../test/common/gapi_tests_common.hpp" +#include "../../test/common/gapi_parsers_tests_common.hpp" #include namespace opencv_test @@ -73,5 +74,10 @@ namespace opencv_test class ConvertToPerfTest : public TestPerfParams> {}; class ResizePerfTest : public TestPerfParams> {}; class ResizeFxFyPerfTest : public TestPerfParams> {}; + class ParseSSDBLPerfTest : public TestPerfParams>, public ParserSSDTest {}; + class ParseSSDPerfTest : public TestPerfParams>, public ParserSSDTest {}; + class ParseYoloPerfTest : public TestPerfParams>, public ParserYoloTest {}; + class SizePerfTest : public TestPerfParams> {}; + class SizeRPerfTest : public TestPerfParams> {}; } #endif // OPENCV_GAPI_CORE_PERF_TESTS_HPP diff --git a/modules/gapi/perf/common/gapi_core_perf_tests_inl.hpp b/modules/gapi/perf/common/gapi_core_perf_tests_inl.hpp index 6ec27edf50..6b049c2425 100644 --- a/modules/gapi/perf/common/gapi_core_perf_tests_inl.hpp +++ b/modules/gapi/perf/common/gapi_core_perf_tests_inl.hpp @@ -1930,5 +1930,187 @@ PERF_TEST_P_(ResizeFxFyPerfTest, TestPerformance) //------------------------------------------------------------------------------ +PERF_TEST_P_(ParseSSDBLPerfTest, TestPerformance) +{ + cv::Size sz; + float confidence_threshold = 0.0f; + int filter_label = 0; + cv::GCompileArgs compile_args; + std::tie(sz, confidence_threshold, filter_label, compile_args) = GetParam(); + 
cv::Mat in_mat = generateSSDoutput(sz); + std::vector boxes_gapi, boxes_ref; + std::vector labels_gapi, labels_ref; + + // Reference code ////////////////////////////////////////////////////////// + parseSSDBLref(in_mat, sz, confidence_threshold, filter_label, boxes_ref, labels_ref); + + // G-API code ////////////////////////////////////////////////////////////// + cv::GMat in; + cv::GOpaque op_sz; + auto out = cv::gapi::parseSSD(in, op_sz, confidence_threshold, filter_label); + cv::GComputation c(cv::GIn(in, op_sz), cv::GOut(std::get<0>(out), std::get<1>(out))); + + // Warm-up graph engine: + auto cc = c.compile(descr_of(in_mat), descr_of(sz), std::move(compile_args)); + cc(cv::gin(in_mat, sz), cv::gout(boxes_gapi, labels_gapi)); + + TEST_CYCLE() + { + cc(cv::gin(in_mat, sz), cv::gout(boxes_gapi, labels_gapi)); + } + + // Comparison //////////////////////////////////////////////////////////// + { + EXPECT_TRUE(boxes_gapi == boxes_ref); + EXPECT_TRUE(labels_gapi == labels_ref); + } + + SANITY_CHECK_NOTHING(); +} + +//------------------------------------------------------------------------------ + +PERF_TEST_P_(ParseSSDPerfTest, TestPerformance) +{ + cv::Size sz; + float confidence_threshold = 0; + bool alignment_to_square = false, filter_out_of_bounds = false; + cv::GCompileArgs compile_args; + std::tie(sz, confidence_threshold, alignment_to_square, filter_out_of_bounds, compile_args) = GetParam(); + cv::Mat in_mat = generateSSDoutput(sz); + std::vector boxes_gapi, boxes_ref; + + // Reference code ////////////////////////////////////////////////////////// + parseSSDref(in_mat, sz, confidence_threshold, alignment_to_square, filter_out_of_bounds, boxes_ref); + + // G-API code ////////////////////////////////////////////////////////////// + cv::GMat in; + cv::GOpaque op_sz; + auto out = cv::gapi::parseSSD(in, op_sz, confidence_threshold, alignment_to_square, filter_out_of_bounds); + cv::GComputation c(cv::GIn(in, op_sz), cv::GOut(out)); + + // Warm-up graph engine: + 
auto cc = c.compile(descr_of(in_mat), descr_of(sz), std::move(compile_args)); + cc(cv::gin(in_mat, sz), cv::gout(boxes_gapi)); + + TEST_CYCLE() + { + cc(cv::gin(in_mat, sz), cv::gout(boxes_gapi)); + } + + // Comparison //////////////////////////////////////////////////////////// + { + EXPECT_TRUE(boxes_gapi == boxes_ref); + } + + SANITY_CHECK_NOTHING(); +} + +//------------------------------------------------------------------------------ + +PERF_TEST_P_(ParseYoloPerfTest, TestPerformance) +{ + cv::Size sz; + float confidence_threshold = 0.0f, nms_threshold = 0.0f; + int num_classes = 0; + cv::GCompileArgs compile_args; + std::tie(sz, confidence_threshold, nms_threshold, num_classes, compile_args) = GetParam(); + cv::Mat in_mat = generateYoloOutput(num_classes); + auto anchors = cv::gapi::nn::parsers::GParseYolo::defaultAnchors(); + std::vector boxes_gapi, boxes_ref; + std::vector labels_gapi, labels_ref; + + // Reference code ////////////////////////////////////////////////////////// + parseYoloRef(in_mat, sz, confidence_threshold, nms_threshold, num_classes, anchors, boxes_ref, labels_ref); + + // G-API code ////////////////////////////////////////////////////////////// + cv::GMat in; + cv::GOpaque op_sz; + auto out = cv::gapi::parseYolo(in, op_sz, confidence_threshold, nms_threshold, anchors); + cv::GComputation c(cv::GIn(in, op_sz), cv::GOut(std::get<0>(out), std::get<1>(out))); + + // Warm-up graph engine: + auto cc = c.compile(descr_of(in_mat), descr_of(sz), std::move(compile_args)); + cc(cv::gin(in_mat, sz), cv::gout(boxes_gapi, labels_gapi)); + + TEST_CYCLE() + { + cc(cv::gin(in_mat, sz), cv::gout(boxes_gapi, labels_gapi)); + } + + // Comparison //////////////////////////////////////////////////////////// + { + EXPECT_TRUE(boxes_gapi == boxes_ref); + EXPECT_TRUE(labels_gapi == labels_ref); + } + + SANITY_CHECK_NOTHING(); +} + +//------------------------------------------------------------------------------ + +PERF_TEST_P_(SizePerfTest, TestPerformance) +{ + 
MatType type; + cv::Size sz; + cv::GCompileArgs compile_args; + std::tie(type, sz, compile_args) = GetParam(); + in_mat1 = cv::Mat(sz, type); + + // G-API code ////////////////////////////////////////////////////////////// + cv::GMat in; + auto out = cv::gapi::size(in); + cv::GComputation c(cv::GIn(in), cv::GOut(out)); + cv::Size out_sz; + + // Warm-up graph engine: + auto cc = c.compile(descr_of(in_mat1), std::move(compile_args)); + cc(cv::gin(in_mat1), cv::gout(out_sz)); + + TEST_CYCLE() + { + cc(cv::gin(in_mat1), cv::gout(out_sz)); + } + + // Comparison //////////////////////////////////////////////////////////// + { + EXPECT_EQ(out_sz, sz); + } + + SANITY_CHECK_NOTHING(); +} + +//------------------------------------------------------------------------------ + +PERF_TEST_P_(SizeRPerfTest, TestPerformance) +{ + cv::Size sz; + cv::GCompileArgs compile_args; + std::tie(sz, compile_args) = GetParam(); + cv::Rect rect(cv::Point(0,0), sz); + + // G-API code ////////////////////////////////////////////////////////////// + cv::GOpaque op_rect; + auto out = cv::gapi::size(op_rect); + cv::GComputation c(cv::GIn(op_rect), cv::GOut(out)); + cv::Size out_sz; + + // Warm-up graph engine: + auto cc = c.compile(descr_of(rect), std::move(compile_args)); + cc(cv::gin(rect), cv::gout(out_sz)); + + TEST_CYCLE() + { + cc(cv::gin(rect), cv::gout(out_sz)); + } + + // Comparison //////////////////////////////////////////////////////////// + { + EXPECT_EQ(out_sz, sz); + } + + SANITY_CHECK_NOTHING(); +} + } #endif // OPENCV_GAPI_CORE_PERF_TESTS_INL_HPP diff --git a/modules/gapi/perf/cpu/gapi_core_perf_tests_cpu.cpp b/modules/gapi/perf/cpu/gapi_core_perf_tests_cpu.cpp index 9f90d73171..8369ed193c 100644 --- a/modules/gapi/perf/cpu/gapi_core_perf_tests_cpu.cpp +++ b/modules/gapi/perf/cpu/gapi_core_perf_tests_cpu.cpp @@ -288,4 +288,33 @@ INSTANTIATE_TEST_CASE_P(ResizeFxFyPerfTestCPU, ResizeFxFyPerfTest, Values(0.5, 0.1), Values(0.5, 0.1), Values(cv::compile_args(CORE_CPU)))); + 
+INSTANTIATE_TEST_CASE_P(ParseSSDBLPerfTestCPU, ParseSSDBLPerfTest, + Combine(Values(sz720p, sz1080p), + Values(0.3f, 0.7f), + Values(0, 1), + Values(cv::compile_args(CORE_CPU)))); + +INSTANTIATE_TEST_CASE_P(ParseSSDPerfTestCPU, ParseSSDPerfTest, + Combine(Values(sz720p, sz1080p), + Values(0.3f, 0.7f), + testing::Bool(), + testing::Bool(), + Values(cv::compile_args(CORE_CPU)))); + +INSTANTIATE_TEST_CASE_P(ParseYoloPerfTestCPU, ParseYoloPerfTest, + Combine(Values(sz720p, sz1080p), + Values(0.3f, 0.7f), + Values(0.5), + Values(7, 80), + Values(cv::compile_args(CORE_CPU)))); + +INSTANTIATE_TEST_CASE_P(SizePerfTestCPU, SizePerfTest, + Combine(Values(CV_8UC1, CV_8UC3, CV_32FC1), + Values(szSmall128, szVGA, sz720p, sz1080p), + Values(cv::compile_args(CORE_CPU)))); + +INSTANTIATE_TEST_CASE_P(SizeRPerfTestCPU, SizeRPerfTest, + Combine(Values(szSmall128, szVGA, sz720p, sz1080p), + Values(cv::compile_args(CORE_CPU)))); } // opencv_test diff --git a/modules/gapi/src/api/kernels_core.cpp b/modules/gapi/src/api/kernels_core.cpp index 13a04595ca..961d19cdaa 100644 --- a/modules/gapi/src/api/kernels_core.cpp +++ b/modules/gapi/src/api/kernels_core.cpp @@ -383,5 +383,15 @@ GMat warpAffine(const GMat& src, const Mat& M, const Size& dsize, int flags, return core::GWarpAffine::on(src, M, dsize, flags, borderMode, borderValue); } +GOpaque size(const GMat& src) +{ + return core::GSize::on(src); +} + +GOpaque size(const GOpaque& r) +{ + return core::GSizeR::on(r); +} + } //namespace gapi } //namespace cv diff --git a/modules/gapi/src/api/kernels_nnparsers.cpp b/modules/gapi/src/api/kernels_nnparsers.cpp new file mode 100644 index 0000000000..bd6c70f59c --- /dev/null +++ b/modules/gapi/src/api/kernels_nnparsers.cpp @@ -0,0 +1,44 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
+// +// Copyright (C) 2020 Intel Corporation + + +#include "precomp.hpp" + +#include + +#include +#include + +namespace cv { namespace gapi { + +nn::parsers::GDetections parseSSD(const GMat& in, + const GOpaque& inSz, + const float confidenceThreshold, + const int filterLabel) +{ + return nn::parsers::GParseSSDBL::on(in, inSz, confidenceThreshold, filterLabel); +} + +nn::parsers::GRects parseSSD(const GMat& in, + const GOpaque& inSz, + const float confidenceThreshold, + const bool alignmentToSquare, + const bool filterOutOfBounds) +{ + return nn::parsers::GParseSSD::on(in, inSz, confidenceThreshold, alignmentToSquare, filterOutOfBounds); +} + +nn::parsers::GDetections parseYolo(const GMat& in, + const GOpaque& inSz, + const float confidenceThreshold, + const float nmsThreshold, + const std::vector& anchors) +{ + return nn::parsers::GParseYolo::on(in, inSz, confidenceThreshold, nmsThreshold, anchors); +} + +} //namespace gapi +} //namespace cv diff --git a/modules/gapi/src/backends/cpu/gcpucore.cpp b/modules/gapi/src/backends/cpu/gcpucore.cpp index bf2d034db9..d9c3c3ae2a 100644 --- a/modules/gapi/src/backends/cpu/gcpucore.cpp +++ b/modules/gapi/src/backends/cpu/gcpucore.cpp @@ -6,6 +6,7 @@ #include "precomp.hpp" +#include "gnnparsers.hpp" #include #include @@ -576,6 +577,63 @@ GAPI_OCV_KERNEL(GCPUWarpAffine, cv::gapi::core::GWarpAffine) } }; +GAPI_OCV_KERNEL(GCPUParseSSDBL, cv::gapi::nn::parsers::GParseSSDBL) +{ + static void run(const cv::Mat& in_ssd_result, + const cv::Size& in_size, + const float confidence_threshold, + const int filter_label, + std::vector& out_boxes, + std::vector& out_labels) + { + cv::parseSSDBL(in_ssd_result, in_size, confidence_threshold, filter_label, out_boxes, out_labels); + } +}; + +GAPI_OCV_KERNEL(GOCVParseSSD, cv::gapi::nn::parsers::GParseSSD) +{ + static void run(const cv::Mat& in_ssd_result, + const cv::Size& in_size, + const float confidence_threshold, + const bool alignment_to_square, + const bool filter_out_of_bounds, + 
std::vector& out_boxes) + { + cv::parseSSD(in_ssd_result, in_size, confidence_threshold, alignment_to_square, filter_out_of_bounds, out_boxes); + } +}; + +GAPI_OCV_KERNEL(GCPUParseYolo, cv::gapi::nn::parsers::GParseYolo) +{ + static void run(const cv::Mat& in_yolo_result, + const cv::Size& in_size, + const float confidence_threshold, + const float nms_threshold, + const std::vector& anchors, + std::vector& out_boxes, + std::vector& out_labels) + { + cv::parseYolo(in_yolo_result, in_size, confidence_threshold, nms_threshold, anchors, out_boxes, out_labels); + } +}; + +GAPI_OCV_KERNEL(GCPUSize, cv::gapi::core::GSize) +{ + static void run(const cv::Mat& in, cv::Size& out) + { + out.width = in.cols; + out.height = in.rows; + } +}; + +GAPI_OCV_KERNEL(GCPUSizeR, cv::gapi::core::GSizeR) +{ + static void run(const cv::Rect& in, cv::Size& out) + { + out.width = in.width; + out.height = in.height; + } +}; cv::gapi::GKernelPackage cv::gapi::core::cpu::kernels() { @@ -647,6 +705,11 @@ cv::gapi::GKernelPackage cv::gapi::core::cpu::kernels() , GCPUNormalize , GCPUWarpPerspective , GCPUWarpAffine + , GCPUParseSSDBL + , GOCVParseSSD + , GCPUParseYolo + , GCPUSize + , GCPUSizeR >(); return pkg; } diff --git a/modules/gapi/src/backends/cpu/gnnparsers.cpp b/modules/gapi/src/backends/cpu/gnnparsers.cpp new file mode 100644 index 0000000000..234382d530 --- /dev/null +++ b/modules/gapi/src/backends/cpu/gnnparsers.cpp @@ -0,0 +1,338 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
+// +// Copyright (C) 2020 Intel Corporation + +#include "gnnparsers.hpp" + +namespace cv +{ +namespace gapi +{ +namespace nn +{ +class YoloParser +{ +public: + YoloParser(const float* out, const int side, const int lcoords, const int lclasses) + : m_out(out), m_side(side), m_lcoords(lcoords), m_lclasses(lclasses) + {} + + float scale(const int i, const int b) + { + int obj_index = index(i, b, m_lcoords); + return m_out[obj_index]; + } + + double x(const int i, const int b) + { + int box_index = index(i, b, 0); + int col = i % m_side; + return (col + m_out[box_index]) / m_side; + } + + double y(const int i, const int b) + { + int box_index = index(i, b, 0); + int row = i / m_side; + return (row + m_out[box_index + m_side * m_side]) / m_side; + } + + double width(const int i, const int b, const float anchor) + { + int box_index = index(i, b, 0); + return std::exp(m_out[box_index + 2 * m_side * m_side]) * anchor / m_side; + } + + double height(const int i, const int b, const float anchor) + { + int box_index = index(i, b, 0); + return std::exp(m_out[box_index + 3 * m_side * m_side]) * anchor / m_side; + } + + float classConf(const int i, const int b, const int label) + { + int class_index = index(i, b, m_lcoords + 1 + label); + return m_out[class_index]; + } + + cv::Rect toBox(const double x, const double y, const double h, const double w, const cv::Size& in_sz) + { + auto h_scale = in_sz.height; + auto w_scale = in_sz.width; + cv::Rect r; + r.x = static_cast((x - w / 2) * w_scale); + r.y = static_cast((y - h / 2) * h_scale); + r.width = static_cast(w * w_scale); + r.height = static_cast(h * h_scale); + return r; + } + +private: + const float* m_out = nullptr; + int m_side = 0, m_lcoords = 0, m_lclasses = 0; + + int index(const int i, const int b, const int entry) + { + return b * m_side * m_side * (m_lcoords + m_lclasses + 1) + entry * m_side * m_side + i; + } +}; + +struct YoloParams +{ + int num = 5; + int coords = 4; +}; + +struct Detection +{ + Detection(const 
cv::Rect& in_rect, const float in_conf, const int in_label) + : rect(in_rect), conf(in_conf), label(in_label) + {} + cv::Rect rect; + float conf = 0.0f; + int label = 0; +}; + +class SSDParser +{ +public: + SSDParser(const cv::MatSize& in_ssd_dims, const cv::Size& in_size, const float* data) + : m_dims(in_ssd_dims), m_maxProp(in_ssd_dims[2]), m_objSize(in_ssd_dims[3]), + m_data(data), m_surface(cv::Rect({0,0}, in_size)), m_size(in_size) + { + GAPI_Assert(in_ssd_dims.dims() == 4u); // Fixed output layout + GAPI_Assert(m_objSize == 7); // Fixed SSD object size + } + + void adjustBoundingBox(cv::Rect& boundingBox) + { + auto w = boundingBox.width; + auto h = boundingBox.height; + + boundingBox.x -= static_cast(0.067 * w); + boundingBox.y -= static_cast(0.028 * h); + + boundingBox.width += static_cast(0.15 * w); + boundingBox.height += static_cast(0.13 * h); + + if (boundingBox.width < boundingBox.height) + { + auto dx = (boundingBox.height - boundingBox.width); + boundingBox.x -= dx / 2; + boundingBox.width += dx; + } + else + { + auto dy = (boundingBox.width - boundingBox.height); + boundingBox.y -= dy / 2; + boundingBox.height += dy; + } + } + + std::tuple extract(const size_t step) + { + const float* it = m_data + step * m_objSize; + float image_id = it[0]; + int label = static_cast(it[1]); + float confidence = it[2]; + float rc_left = it[3]; + float rc_top = it[4]; + float rc_right = it[5]; + float rc_bottom = it[6]; + + cv::Rect rc; // Map relative coordinates to the original image scale + rc.x = static_cast(rc_left * m_size.width); + rc.y = static_cast(rc_top * m_size.height); + rc.width = static_cast(rc_right * m_size.width) - rc.x; + rc.height = static_cast(rc_bottom * m_size.height) - rc.y; + return std::make_tuple(rc, image_id, confidence, label); + } + + int getMaxProposals() + { + return m_maxProp; + } + + cv::Rect getSurface() + { + return m_surface; + } + +private: + const cv::MatSize m_dims; + int m_maxProp = 0, m_objSize = 0; + const float* m_data = 
nullptr; + const cv::Rect m_surface; + const cv::Size m_size; +}; +} // namespace nn +} // namespace gapi + +void parseSSDBL(const cv::Mat& in_ssd_result, + const cv::Size& in_size, + const float confidence_threshold, + const int filter_label, + std::vector& out_boxes, + std::vector& out_labels) +{ + cv::gapi::nn::SSDParser parser(in_ssd_result.size, in_size, in_ssd_result.ptr()); + out_boxes.clear(); + out_labels.clear(); + cv::Rect rc; + float image_id, confidence; + int label; + const size_t range = parser.getMaxProposals(); + for (size_t i = 0; i < range; ++i) + { + std::tie(rc, image_id, confidence, label) = parser.extract(i); + + if (image_id < 0.f) + { + break; // marks end-of-detections + } + + if (confidence < confidence_threshold || + (filter_label != -1 && label != filter_label)) + { + continue; // filter out object classes if filter is specified + } // and skip objects with low confidence + out_boxes.emplace_back(rc & parser.getSurface()); + out_labels.emplace_back(label); + } +} + +void parseSSD(const cv::Mat& in_ssd_result, + const cv::Size& in_size, + const float confidence_threshold, + const bool alignment_to_square, + const bool filter_out_of_bounds, + std::vector& out_boxes) +{ + cv::gapi::nn::SSDParser parser(in_ssd_result.size, in_size, in_ssd_result.ptr()); + out_boxes.clear(); + cv::Rect rc; + float image_id, confidence; + int label; + const size_t range = parser.getMaxProposals(); + for (size_t i = 0; i < range; ++i) + { + std::tie(rc, image_id, confidence, label) = parser.extract(i); + + if (image_id < 0.f) + { + break; // marks end-of-detections + } + if (confidence < confidence_threshold) + { + continue; // skip objects with low confidence + } + + if (alignment_to_square) + { + parser.adjustBoundingBox(rc); + } + + const auto clipped_rc = rc & parser.getSurface(); + if (filter_out_of_bounds) + { + if (clipped_rc.area() != rc.area()) + { + continue; + } + } + out_boxes.emplace_back(clipped_rc); + } +} + +void parseYolo(const cv::Mat& 
in_yolo_result, + const cv::Size& in_size, + const float confidence_threshold, + const float nms_threshold, + const std::vector& anchors, + std::vector& out_boxes, + std::vector& out_labels) +{ + const auto& dims = in_yolo_result.size; + GAPI_Assert(dims.dims() == 4); + GAPI_Assert(dims[0] == 1); + GAPI_Assert(dims[1] == 13); + GAPI_Assert(dims[2] == 13); + GAPI_Assert(dims[3] % 5 == 0); // 5 boxes + const auto num_classes = dims[3] / 5 - 5; + GAPI_Assert(num_classes > 0); + GAPI_Assert(0 < nms_threshold && nms_threshold <= 1); + out_boxes.clear(); + out_labels.clear(); + gapi::nn::YoloParams params; + constexpr auto side = 13; + constexpr auto side_square = side * side; + const auto output = in_yolo_result.ptr(); + + gapi::nn::YoloParser parser(output, side, params.coords, num_classes); + + std::vector detections; + + for (int i = 0; i < side_square; ++i) + { + for (int b = 0; b < params.num; ++b) + { + float scale = parser.scale(i, b); + if (scale < confidence_threshold) + { + continue; + } + double x = parser.x(i, b); + double y = parser.y(i, b); + double height = parser.height(i, b, anchors[2 * b + 1]); + double width = parser.width(i, b, anchors[2 * b]); + + for (int label = 0; label < num_classes; ++label) + { + float prob = scale * parser.classConf(i,b,label); + if (prob < confidence_threshold) + { + continue; + } + auto box = parser.toBox(x, y, height, width, in_size); + detections.emplace_back(gapi::nn::Detection(box, prob, label)); + } + } + } + std::stable_sort(std::begin(detections), std::end(detections), + [](const gapi::nn::Detection& a, const gapi::nn::Detection& b) + { + return a.conf > b.conf; + }); + + if (nms_threshold < 1.0f) + { + for (const auto& d : detections) + { + // Reject boxes which overlap with previously pushed ones + // (They are sorted by confidence, so rejected box + // always has a smaller confidence + if (std::end(out_boxes) == + std::find_if(std::begin(out_boxes), std::end(out_boxes), + [&d, nms_threshold](const cv::Rect& r) + { 
+ float rectOverlap = 1.f - static_cast(jaccardDistance(r, d.rect)); + return rectOverlap > nms_threshold; + })) + { + out_boxes. emplace_back(d.rect); + out_labels.emplace_back(d.label); + } + } + } + else + { + for (const auto& d: detections) + { + out_boxes. emplace_back(d.rect); + out_labels.emplace_back(d.label); + } + } +} +} // namespace cv diff --git a/modules/gapi/src/backends/cpu/gnnparsers.hpp b/modules/gapi/src/backends/cpu/gnnparsers.hpp new file mode 100644 index 0000000000..2ae6693705 --- /dev/null +++ b/modules/gapi/src/backends/cpu/gnnparsers.hpp @@ -0,0 +1,36 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. +// +// Copyright (C) 2020 Intel Corporation + +#include + +#ifndef OPENCV_NNPARSERS_OCV_HPP +#define OPENCV_NNPARSERS_OCV_HPP + +namespace cv +{ +void parseSSDBL(const cv::Mat& in_ssd_result, + const cv::Size& in_size, + const float confidence_threshold, + const int filter_label, + std::vector& out_boxes, + std::vector& out_labels); + +void parseSSD(const cv::Mat& in_ssd_result, + const cv::Size& in_size, + const float confidence_threshold, + const bool alignment_to_square, + const bool filter_out_of_bounds, + std::vector& out_boxes); + +void parseYolo(const cv::Mat& in_yolo_result, + const cv::Size& in_size, + const float confidence_threshold, + const float nms_threshold, + const std::vector& anchors, + std::vector& out_boxes, + std::vector& out_labels); +} +#endif // OPENCV_NNPARSERS_OCV_HPP diff --git a/modules/gapi/test/common/gapi_core_tests.hpp b/modules/gapi/test/common/gapi_core_tests.hpp index 8a786a9ab0..52e184ac69 100644 --- a/modules/gapi/test/common/gapi_core_tests.hpp +++ b/modules/gapi/test/common/gapi_core_tests.hpp @@ -11,6 +11,7 @@ #include #include "gapi_tests_common.hpp" +#include "gapi_parsers_tests_common.hpp" namespace opencv_test { @@ -149,6 +150,15 @@ 
GAPI_TEST_FIXTURE(WarpPerspectiveTest, initMatrixRandU,
 GAPI_TEST_FIXTURE(WarpAffineTest, initMatrixRandU,
         FIXTURE_API(CompareMats, double , double, int, int, cv::Scalar),
         6, cmpF, angle, scale, flags, border_mode, border_value)
+
+GAPI_TEST_EXT_BASE_FIXTURE(ParseSSDBLTest, ParserSSDTest, initNothing,
+        FIXTURE_API(float, int), 2, confidence_threshold, filter_label)
+GAPI_TEST_EXT_BASE_FIXTURE(ParseSSDTest, ParserSSDTest, initNothing,
+        FIXTURE_API(float, bool, bool), 3, confidence_threshold, alignment_to_square, filter_out_of_bounds)
+GAPI_TEST_EXT_BASE_FIXTURE(ParseYoloTest, ParserYoloTest, initNothing,
+        FIXTURE_API(float, float, int), 3, confidence_threshold, nms_threshold, num_classes)
+GAPI_TEST_FIXTURE(SizeTest, initMatrixRandU, <>, 0)
+GAPI_TEST_FIXTURE(SizeRTest, initNothing, <>, 0)
 } // opencv_test
 
 #endif //OPENCV_GAPI_CORE_TESTS_HPP
diff --git a/modules/gapi/test/common/gapi_core_tests_inl.hpp b/modules/gapi/test/common/gapi_core_tests_inl.hpp
index b5ed58c703..7226fa3198 100644
--- a/modules/gapi/test/common/gapi_core_tests_inl.hpp
+++ b/modules/gapi/test/common/gapi_core_tests_inl.hpp
@@ -9,6 +9,7 @@
 #define OPENCV_GAPI_CORE_TESTS_INL_HPP
 
 #include <opencv2/gapi/core.hpp>
+#include <opencv2/gapi/infer/parsers.hpp>
 #include "gapi_core_tests.hpp"
 
 namespace opencv_test
@@ -1578,6 +1579,100 @@ TEST_P(ReInitOutTest, TestWithAdd)
     run_and_compare();
 }
 
+// The label-returning SSD parser kernel must reproduce parseSSDBLref on a generated SSD blob.
+TEST_P(ParseSSDBLTest, ParseTest)
+{
+    cv::Mat in_mat = generateSSDoutput(sz);
+    std::vector<cv::Rect> boxes_gapi, boxes_ref;
+    std::vector<int> labels_gapi, labels_ref;
+
+    // G-API code //////////////////////////////////////////////////////////////
+    cv::GMat in;
+    cv::GOpaque<cv::Size> op_sz;
+    auto out = cv::gapi::parseSSD(in, op_sz, confidence_threshold, filter_label);
+    cv::GComputation c(cv::GIn(in, op_sz), cv::GOut(std::get<0>(out), std::get<1>(out)));
+    c.apply(cv::gin(in_mat, sz), cv::gout(boxes_gapi, labels_gapi), getCompileArgs());
+
+    // Reference code //////////////////////////////////////////////////////////
+    parseSSDBLref(in_mat, sz, confidence_threshold, filter_label, boxes_ref, labels_ref);
+
+    // Comparison (EXPECT_EQ prints both containers when they differ) //////////
+    EXPECT_EQ(boxes_ref, boxes_gapi);
+    EXPECT_EQ(labels_ref, labels_gapi);
+}
+
+// The boxes-only SSD parser kernel must reproduce parseSSDref for every flag combination.
+TEST_P(ParseSSDTest, ParseTest)
+{
+    cv::Mat in_mat = generateSSDoutput(sz);
+    std::vector<cv::Rect> boxes_gapi, boxes_ref;
+
+    // G-API code //////////////////////////////////////////////////////////////
+    cv::GMat in;
+    cv::GOpaque<cv::Size> op_sz;
+    auto out = cv::gapi::parseSSD(in, op_sz, confidence_threshold,
+                                  alignment_to_square, filter_out_of_bounds);
+    cv::GComputation c(cv::GIn(in, op_sz), cv::GOut(out));
+    c.apply(cv::gin(in_mat, sz), cv::gout(boxes_gapi), getCompileArgs());
+
+    // Reference code //////////////////////////////////////////////////////////
+    parseSSDref(in_mat, sz, confidence_threshold, alignment_to_square,
+                filter_out_of_bounds, boxes_ref);
+
+    // Comparison (EXPECT_EQ prints both containers when they differ) //////////
+    EXPECT_EQ(boxes_ref, boxes_gapi);
+}
+
+// The YOLO parser kernel, fed the default anchors, must reproduce parseYoloRef.
+TEST_P(ParseYoloTest, ParseTest)
+{
+    cv::Mat in_mat = generateYoloOutput(num_classes);
+    auto anchors = cv::gapi::nn::parsers::GParseYolo::defaultAnchors();
+    std::vector<cv::Rect> boxes_gapi, boxes_ref;
+    std::vector<int> labels_gapi, labels_ref;
+
+    // G-API code //////////////////////////////////////////////////////////////
+    cv::GMat in;
+    cv::GOpaque<cv::Size> op_sz;
+    auto out = cv::gapi::parseYolo(in, op_sz, confidence_threshold, nms_threshold, anchors);
+    cv::GComputation c(cv::GIn(in, op_sz), cv::GOut(std::get<0>(out), std::get<1>(out)));
+    c.apply(cv::gin(in_mat, sz), cv::gout(boxes_gapi, labels_gapi), getCompileArgs());
+
+    // Reference code //////////////////////////////////////////////////////////
+    parseYoloRef(in_mat, sz, confidence_threshold, nms_threshold, num_classes, anchors, boxes_ref, labels_ref);
+
+    // Comparison (EXPECT_EQ prints both containers when they differ) //////////
+    EXPECT_EQ(boxes_ref, boxes_gapi);
+    EXPECT_EQ(labels_ref, labels_gapi);
+}
+
+// cv::gapi::size(GMat) must report the size of the input matrix.
+TEST_P(SizeTest, ParseTest)
+{
+    cv::GMat in;
+    cv::Size out_sz;
+
+    auto out = cv::gapi::size(in);
+    cv::GComputation c(cv::GIn(in), cv::GOut(out));
+    c.apply(cv::gin(in_mat1), cv::gout(out_sz), getCompileArgs());
+
+    EXPECT_EQ(out_sz, sz);
+}
+
+// cv::gapi::size(GOpaque<Rect>) must report the size of a rectangle anchored at the origin.
+TEST_P(SizeRTest, ParseTest)
+{
+    cv::Rect rect(cv::Point(0,0), sz);
+    cv::Size out_sz;
+
+    cv::GOpaque<cv::Rect> op_rect;
+    auto out = cv::gapi::size(op_rect);
+    cv::GComputation c(cv::GIn(op_rect), cv::GOut(out));
+    c.apply(cv::gin(rect), cv::gout(out_sz), getCompileArgs());
+
+    EXPECT_EQ(out_sz, sz);
+}
+
 } // opencv_test
 
 #endif //OPENCV_GAPI_CORE_TESTS_INL_HPP
diff --git a/modules/gapi/test/common/gapi_parsers_tests_common.hpp b/modules/gapi/test/common/gapi_parsers_tests_common.hpp
new file mode 100644
index 0000000000..127a1c5a5e
--- /dev/null
+++ b/modules/gapi/test/common/gapi_parsers_tests_common.hpp
@@ -0,0 +1,397 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+//
+// Copyright (C) 2020 Intel Corporation
+
+
+#ifndef OPENCV_GAPI_PARSERS_TESTS_COMMON_HPP
+#define OPENCV_GAPI_PARSERS_TESTS_COMMON_HPP
+
+#include <algorithm> // std::find_if, std::stable_sort
+#include <cmath>     // std::exp
+#include <cstdlib>   // std::rand, RAND_MAX
+#include <tuple>     // std::tuple, std::tie
+#include <vector>
+
+#include "gapi_tests_common.hpp"
+#include <opencv2/gapi/infer/parsers.hpp>
+
+namespace opencv_test
+{
+// Mix-in base for SSD parser tests: generates synthetic SSD output blobs and provides the
+// reference parsers the G-API kernels are compared against.
+class ParserSSDTest
+{
+public:
+    // Builds a synthetic SSD output blob of shape 1x1x200x7 (CV_32F). Each of the 200 rows is
+    // [image_id, label, confidence, left, top, right, bottom]; box coordinates are divided by
+    // the matching in_sz dimension. All values come from std::rand().
+    cv::Mat generateSSDoutput(const cv::Size& in_sz)
+    {
+        constexpr int maxN = 200;
+        constexpr int objSize = 7;
+        std::vector<int> dims{ 1, 1, maxN, objSize };
+        cv::Mat mat(dims, CV_32FC1);
+        auto data = mat.ptr<float>();
+
+        for (int i = 0; i < maxN; ++i)
+        {
+            float* it = data + i * objSize;
+            auto ssdIt = generateItem(i, in_sz);
+            it[0] = ssdIt.image_id;
+            it[1] = ssdIt.label;
+            it[2] = ssdIt.confidence;
+            it[3] = ssdIt.rc_left;
+            it[4] = ssdIt.rc_top;
+            it[5] = ssdIt.rc_right;
+            it[6] = ssdIt.rc_bottom;
+        }
+        return mat;
+    }
+
+    // Reference SSD parser returning boxes only. Parsing stops at the first row with a negative
+    // image_id; rows below confidence_threshold are skipped. Boxes are optionally squared and
+    // then clipped to the in_size surface; with filter_out_of_bounds set, a box whose area
+    // changes under clipping is dropped instead of being stored.
+    void parseSSDref(const cv::Mat&  in_ssd_result,
+                     const cv::Size& in_size,
+                     const float     confidence_threshold,
+                     const bool      alignment_to_square,
+                     const bool      filter_out_of_bounds,
+                     std::vector<cv::Rect>& out_boxes)
+    {
+        out_boxes.clear();
+        const auto &in_ssd_dims = in_ssd_result.size;
+        CV_Assert(in_ssd_dims.dims() == 4);
+
+        const int MAX_PROPOSALS = in_ssd_dims[2];
+        const int OBJECT_SIZE   = in_ssd_dims[3];
+        CV_Assert(OBJECT_SIZE == 7); // fixed SSD object size
+
+        const float *data = in_ssd_result.ptr<float>();
+        cv::Rect surface({0,0}, in_size), rc;
+        float image_id, confidence;
+        int label;
+        for (int i = 0; i < MAX_PROPOSALS; ++i)
+        {
+            std::tie(rc, image_id, confidence, label)
+                = extract(data + i*OBJECT_SIZE, in_size);
+            if (image_id < 0.f)
+            {
+                break;    // marks end-of-detections
+            }
+
+            if (confidence < confidence_threshold)
+            {
+                continue; // skip objects with low confidence
+            }
+
+            if (alignment_to_square)
+            {
+                adjustBoundingBox(rc);
+            }
+
+            const auto clipped_rc = rc & surface;
+            if (filter_out_of_bounds)
+            {
+                if (clipped_rc.area() != rc.area())
+                {
+                    continue;
+                }
+            }
+            out_boxes.emplace_back(clipped_rc);
+        }
+    }
+
+    // Reference SSD parser returning boxes with labels. Parsing stops at the first row with a
+    // negative image_id; rows below confidence_threshold are skipped, and when filter_label is
+    // not -1 only rows carrying that label are kept. Stored boxes are clipped to in_size.
+    void parseSSDBLref(const cv::Mat&  in_ssd_result,
+                       const cv::Size& in_size,
+                       const float     confidence_threshold,
+                       const int       filter_label,
+                       std::vector<cv::Rect>& out_boxes,
+                       std::vector<int>&      out_labels)
+    {
+        out_boxes.clear();
+        out_labels.clear();
+        const auto &in_ssd_dims = in_ssd_result.size;
+        CV_Assert(in_ssd_dims.dims() == 4);
+
+        const int MAX_PROPOSALS = in_ssd_dims[2];
+        const int OBJECT_SIZE   = in_ssd_dims[3];
+        CV_Assert(OBJECT_SIZE == 7); // fixed SSD object size
+        cv::Rect surface({0,0}, in_size), rc;
+        float image_id, confidence;
+        int label;
+        const float *data = in_ssd_result.ptr<float>();
+        for (int i = 0; i < MAX_PROPOSALS; i++)
+        {
+            std::tie(rc, image_id, confidence, label)
+                = extract(data + i*OBJECT_SIZE, in_size);
+            if (image_id < 0.f)
+            {
+                break;    // marks end-of-detections
+            }
+
+            if (confidence < confidence_threshold ||
+                (filter_label != -1 && label != filter_label))
+            {
+                continue; // filter out object classes if filter is specified
+            }
+
+            out_boxes.emplace_back(rc & surface);
+            out_labels.emplace_back(label);
+        }
+    }
+
+private:
+    // Grows the box by fixed fractions of its original size, then pads the shorter side
+    // (keeping the box centered on that axis) so that width == height.
+    void adjustBoundingBox(cv::Rect& boundingBox)
+    {
+        auto w = boundingBox.width;
+        auto h = boundingBox.height;
+
+        boundingBox.x -= static_cast<int>(0.067 * w);
+        boundingBox.y -= static_cast<int>(0.028 * h);
+
+        boundingBox.width  += static_cast<int>(0.15 * w);
+        boundingBox.height += static_cast<int>(0.13 * h);
+
+        if (boundingBox.width < boundingBox.height)
+        {
+            auto dx = (boundingBox.height - boundingBox.width);
+            boundingBox.x -= dx / 2;
+            boundingBox.width += dx;
+        }
+        else
+        {
+            auto dy = (boundingBox.width - boundingBox.height);
+            boundingBox.y -= dy / 2;
+            boundingBox.height += dy;
+        }
+    }
+
+    // Decodes one 7-float SSD row into (pixel rectangle, image_id, confidence, label).
+    std::tuple<cv::Rect, float, float, int> extract(const float*    it,
+                                                    const cv::Size& in_size)
+    {
+        float image_id   = it[0];
+        int   label      = static_cast<int>(it[1]);
+        float confidence = it[2];
+        float rc_left    = it[3];
+        float rc_top     = it[4];
+        float rc_right   = it[5];
+        float rc_bottom  = it[6];
+
+        cv::Rect rc;  // map relative coordinates to the original image scale
+        rc.x      = static_cast<int>(rc_left   * in_size.width);
+        rc.y      = static_cast<int>(rc_top    * in_size.height);
+        rc.width  = static_cast<int>(rc_right  * in_size.width)  - rc.x;
+        rc.height = static_cast<int>(rc_bottom * in_size.height) - rc.y;
+        return std::make_tuple(rc, image_id, confidence, label);
+    }
+
+    // Returns a std::rand()-based integer from the inclusive range [start, end].
+    int randInRange(const int start, const int end)
+    {
+        GAPI_Assert(start <= end);
+        return start + std::rand() % (end - start + 1);
+    }
+
+    // Produces a random box of at least minW x minH pixels.
+    cv::Rect generateBox(const cv::Size& in_sz)
+    {
+        // Generated rectangle can reside outside of the initial image by border pixels
+        constexpr int border = 10;
+        constexpr int minW = 16;
+        constexpr int minH = 16;
+        cv::Rect box;
+        box.width  = randInRange(minW, in_sz.width  + 2*border);
+        box.height = randInRange(minH, in_sz.height + 2*border);
+        box.x = randInRange(-border, in_sz.width  + border - box.width);
+        box.y = randInRange(-border, in_sz.height + border - box.height);
+        return box;
+    }
+
+    // One SSD detection row; all fields are stored as float, exactly as written to the blob.
+    struct SSDitem
+    {
+        float image_id = 0.0f;
+        float label = 0.0f;
+        float confidence = 0.0f;
+        float rc_left = 0.0f;
+        float rc_top = 0.0f;
+        float rc_right = 0.0f;
+        float rc_bottom = 0.0f;
+    };
+
+    // Generates row i: image_id is i itself (so it is never negative), label is drawn
+    // from [0, 9], confidence from [0, 1], and the box edges are normalized by in_sz.
+    SSDitem generateItem(const int i, const cv::Size& in_sz)
+    {
+        const auto normalize = [](int v, int range) { return static_cast<float>(v) / range; };
+
+        SSDitem it;
+        it.image_id = static_cast<float>(i);
+        it.label = static_cast<float>(randInRange(0, 9));
+        it.confidence = static_cast<float>(std::rand()) / RAND_MAX;
+        auto box = generateBox(in_sz);
+        it.rc_left   = normalize(box.x, in_sz.width);
+        it.rc_right  = normalize(box.x + box.width, in_sz.width);
+        it.rc_top    = normalize(box.y, in_sz.height);
+        it.rc_bottom = normalize(box.y + box.height, in_sz.height);
+
+        return it;
+    }
+};
+
+// Mix-in base for YOLO parser tests: generates a random YOLO output blob and provides the
+// reference parser the G-API kernel is compared against.
+class ParserYoloTest
+{
+public:
+    // Builds a 1x13x13x((num_classes + 5) * 5) CV_32F blob filled with std::rand() values
+    // scaled to [0, 1].
+    cv::Mat generateYoloOutput(const int num_classes)
+    {
+        std::vector<int> dims = { 1, 13, 13, (num_classes + 5) * 5 };
+        cv::Mat mat(dims, CV_32FC1);
+        auto data = mat.ptr<float>();
+
+        const size_t range = mat.total();
+        for (size_t i = 0; i < range; ++i)
+        {
+            data[i] = static_cast<float>(std::rand()) / RAND_MAX;
+        }
+        return mat;
+    }
+
+    // Reference YOLO parser for a 13x13 grid with 5 boxes per cell.
+    // A (cell, box, label) triple becomes a detection when both the box score and
+    // score * class confidence reach confidence_threshold. Detections are ordered by
+    // descending confidence; when nms_threshold < 1 a detection is kept only if its overlap
+    // (1 - Jaccard distance) with every already kept box does not exceed nms_threshold.
+    // anchors must hold a (width, height) pair for each of the 5 boxes.
+    // out_boxes / out_labels are overwritten.
+    void parseYoloRef(const cv::Mat&  in_yolo_result,
+                      const cv::Size& in_size,
+                      const float     confidence_threshold,
+                      const float     nms_threshold,
+                      const int       num_classes,
+                      const std::vector<float>& anchors,
+                      std::vector<cv::Rect>& out_boxes,
+                      std::vector<int>&      out_labels)
+    {
+        YoloParams params;
+        constexpr auto side_square = 13 * 13;
+
+        // anchors[2*b] / anchors[2*b + 1] are read for every box index b below.
+        GAPI_Assert(anchors.size() >= 2u * static_cast<size_t>(params.num));
+
+        // Same contract as the SSD references: results never mix with stale caller data,
+        // which would otherwise also take part in the suppression step.
+        out_boxes.clear();
+        out_labels.clear();
+
+        this->m_out = in_yolo_result.ptr<float>();
+        this->m_side = 13;
+        this->m_lcoords = params.coords;
+        this->m_lclasses = num_classes;
+
+        std::vector<Detection> detections;
+
+        for (int i = 0; i < side_square; ++i)
+        {
+            for (int b = 0; b < params.num; ++b)
+            {
+                float scale = this->scale(i, b);
+                if (scale < confidence_threshold)
+                {
+                    continue;
+                }
+                double x = this->x(i, b);
+                double y = this->y(i, b);
+                double height = this->height(i, b, anchors[2 * b + 1]);
+                double width  = this->width(i, b, anchors[2 * b]);
+
+                // The rectangle does not depend on the label, so build it once per box.
+                const auto box = toBox(x, y, height, width, in_size);
+
+                for (int label = 0; label < num_classes; ++label)
+                {
+                    float prob = scale * classConf(i,b,label);
+                    if (prob < confidence_threshold)
+                    {
+                        continue;
+                    }
+                    detections.emplace_back(box, prob, label);
+                }
+            }
+        }
+        // Stable sort keeps the generation order among detections of equal confidence.
+        std::stable_sort(std::begin(detections), std::end(detections),
+                         [](const Detection& a, const Detection& b)
+                         {
+                             return a.conf > b.conf;
+                         });
+
+        if (nms_threshold < 1.0f)
+        {
+            for (const auto& d : detections)
+            {
+                // Keep d only if no already accepted box overlaps it too much.
+                if (std::end(out_boxes) ==
+                    std::find_if(std::begin(out_boxes), std::end(out_boxes),
+                                 [&d, nms_threshold](const cv::Rect& r)
+                                 {
+                                     float rectOverlap = 1.f - static_cast<float>(jaccardDistance(r, d.rect));
+                                     return rectOverlap > nms_threshold;
+                                 }))
+                {
+                    out_boxes. emplace_back(d.rect);
+                    out_labels.emplace_back(d.label);
+                }
+            }
+        }
+        else
+        {
+            for (const auto& d: detections)
+            {
+                out_boxes. emplace_back(d.rect);
+                out_labels.emplace_back(d.label);
+            }
+        }
+    }
+
+private:
+    // A candidate detection: rectangle, combined confidence and class label.
+    struct Detection
+    {
+        Detection(const cv::Rect& in_rect, const float in_conf, const int in_label)
+            : rect(in_rect), conf(in_conf), label(in_label)
+        {}
+        cv::Rect rect;
+        float conf = 0.0f;
+        int label = 0;
+    };
+
+    // Fixed layout parameters: boxes per grid cell and coordinates per box.
+    struct YoloParams
+    {
+        int num = 5;
+        int coords = 4;
+    };
+
+    // Score stored right after the coordinates of box b at cell i.
+    float scale(const int i, const int b)
+    {
+        int obj_index = index(i, b, m_lcoords);
+        return m_out[obj_index];
+    }
+
+    // Box center x in [0, 1] units: cell column plus the stored offset, divided by the grid side.
+    double x(const int i, const int b)
+    {
+        int box_index = index(i, b, 0);
+        int col = i % m_side;
+        return (col + m_out[box_index]) / m_side;
+    }
+
+    // Box center y in [0, 1] units: cell row plus the stored offset, divided by the grid side.
+    double y(const int i, const int b)
+    {
+        int box_index = index(i, b, 0);
+        int row = i / m_side;
+        return (row + m_out[box_index + m_side * m_side]) / m_side;
+    }
+
+    // Box width relative to the image: exp(stored value) * anchor / grid side.
+    double width(const int i, const int b, const float anchor)
+    {
+        int box_index = index(i, b, 0);
+        return std::exp(m_out[box_index + 2 * m_side * m_side]) * anchor / m_side;
+    }
+
+    // Box height relative to the image: exp(stored value) * anchor / grid side.
+    double height(const int i, const int b, const float anchor)
+    {
+        int box_index = index(i, b, 0);
+        return std::exp(m_out[box_index + 3 * m_side * m_side]) * anchor / m_side;
+    }
+
+    // Confidence of class `label` for box b at cell i.
+    float classConf(const int i, const int b, const int label)
+    {
+        int class_index = index(i, b, m_lcoords + 1 + label);
+        return m_out[class_index];
+    }
+
+    // Converts a relative center/size description into a pixel rectangle for in_sz.
+    cv::Rect toBox(const double x, const double y, const double h, const double w, const cv::Size& in_sz)
+    {
+        auto h_scale = in_sz.height;
+        auto w_scale = in_sz.width;
+        cv::Rect r;
+        r.x = static_cast<int>((x - w / 2) * w_scale);
+        r.y = static_cast<int>((y - h / 2) * h_scale);
+        r.width  = static_cast<int>(w * w_scale);
+        r.height = static_cast<int>(h * h_scale);
+        return r;
+    }
+
+    // Flat offset of `entry` for box b at cell i; the blob is addressed as [box][entry][cell].
+    int index(const int i, const int b, const int entry)
+    {
+        return b * m_side * m_side * (m_lcoords + m_lclasses + 1) + entry * m_side * m_side + i;
+    }
+
+    // State of the blob currently being parsed; set at the start of parseYoloRef.
+    const float* m_out = nullptr;
+    int m_side = 0, m_lcoords = 0, m_lclasses = 0;
+};
+
+} // namespace opencv_test
+
+#endif //
OPENCV_GAPI_PARSERS_TESTS_COMMON_HPP
diff --git a/modules/gapi/test/common/gapi_tests_common.hpp b/modules/gapi/test/common/gapi_tests_common.hpp
index a21cae460b..144ca5fcc6 100644
--- a/modules/gapi/test/common/gapi_tests_common.hpp
+++ b/modules/gapi/test/common/gapi_tests_common.hpp
@@ -351,6 +351,27 @@ struct TestWithParamsSpecific : public TestWithParamsBase
+/**
+ * @private
+ * @brief Create G-API test fixture with TestWithParams base class and an additional base class.
+ * @param Fixture test fixture name.
+ * @param ExtBase additional base class the fixture publicly derives from.
+ * @param InitF callable invoked from the fixture constructor as InitF(type, sz, dtype).
+ * @param API base class API. Specifies types of user-defined parameters. If there are no such
+ * parameters, empty angle brackets ("<>") must be specified.
+ * @param Number number of user-defined parameters (corresponds to the number of types in API).
+ * if there are no such parameters, 0 must be specified.
+ * @param ... list of names of user-defined parameters. if there are no parameters, the list
+ * must be empty.
+ */
+#define GAPI_TEST_EXT_BASE_FIXTURE(Fixture, ExtBase, InitF, API, Number, ...) \
+    struct Fixture : public TestWithParams API, public ExtBase { \
+        static_assert(Number == AllParams::specific_params_size, \
+            "Number of user-defined parameters doesn't match size of __VA_ARGS__"); \
+        __WRAP_VAARGS(DEFINE_SPECIFIC_PARAMS_##Number(__VA_ARGS__)) \
+        Fixture() { InitF(type, sz, dtype); } \
+    };
+
 /**
  * @private
  * @brief Create G-API test fixture with TestWithParamsSpecific base class
diff --git a/modules/gapi/test/cpu/gapi_core_tests_cpu.cpp b/modules/gapi/test/cpu/gapi_core_tests_cpu.cpp
index 50384ca9f2..121b939736 100644
--- a/modules/gapi/test/cpu/gapi_core_tests_cpu.cpp
+++ b/modules/gapi/test/cpu/gapi_core_tests_cpu.cpp
@@ -496,4 +496,43 @@ INSTANTIATE_TEST_CASE_P(ReInitOutTestCPU, ReInitOutTest,
                                        Values(cv::Size(640, 400),
                                               cv::Size(10, 480))));
 
+INSTANTIATE_TEST_CASE_P(ParseTestCPU, ParseSSDBLTest,
+                        Combine(Values(CV_8UC1, CV_8UC3, CV_32FC1),
+                                Values(cv::Size(1920, 1080)),
+                                Values(-1),
+                                Values(CORE_CPU),
+                                Values(0.3f, 0.5f, 0.7f),
+                                Values(-1, 0, 1)));
+
+INSTANTIATE_TEST_CASE_P(ParseTestCPU, ParseSSDTest,
+                        Combine(Values(CV_8UC1, CV_8UC3, CV_32FC1),
+                                Values(cv::Size(1920, 1080)),
+                                Values(-1),
+                                Values(CORE_CPU),
+                                Values(0.3f, 0.5f, 0.7f),
+                                testing::Bool(),
+                                testing::Bool()));
+
+INSTANTIATE_TEST_CASE_P(ParseTestCPU, ParseYoloTest,
+                        Combine(Values(CV_8UC1, CV_8UC3, CV_32FC1), // presumably the fixture's `type` - TODO confirm
+                                Values(cv::Size(1920, 1080)),       // presumably `sz`, passed to the parser as the image size
+                                Values(-1),                         // presumably `dtype` - TODO confirm
+                                Values(CORE_CPU),                   // presumably the compile args selecting the CPU kernels
+                                Values(0.3f, 0.5f, 0.7f),           // confidence_threshold
+                                Values(0.5f, 1.0f),                 // nms_threshold; 1.0f takes the reference's no-suppression branch
+                                Values(80, 7)));                    // num_classes
+
+INSTANTIATE_TEST_CASE_P(SizeTestCPU, SizeTest,
+                        Combine(Values(CV_8UC1, CV_8UC3, CV_32FC1),
+                                Values(cv::Size(32, 32),
+                                       cv::Size(640, 320)),         // the test expects exactly this size back
+                                Values(-1),
+                                Values(CORE_CPU)));
+
+INSTANTIATE_TEST_CASE_P(SizeRTestCPU, SizeRTest,
+                        Combine(Values(CV_8UC1, CV_8UC3, CV_32FC1),
+                                Values(cv::Size(32, 32),
+                                       cv::Size(640, 320)),         // used as the size of the input rectangle
+                                Values(-1),
+                                Values(CORE_CPU)));
 }