diff --git a/modules/dnn/src/layers/detection_output_layer.cpp b/modules/dnn/src/layers/detection_output_layer.cpp index 2dd6f5fb73..8374d74293 100644 --- a/modules/dnn/src/layers/detection_output_layer.cpp +++ b/modules/dnn/src/layers/detection_output_layer.cpp @@ -133,6 +133,12 @@ public: typedef std::map > LabelBBox; + inline int getNumOfTargetClasses() { + unsigned numBackground = + (_backgroundLabelId >= 0 && _backgroundLabelId < _numClasses) ? 1 : 0; + return (_numClasses - numBackground); + } + bool getParameterDict(const LayerParams ¶ms, const std::string ¶meterName, DictValue& result) @@ -584,12 +590,13 @@ public: LabelBBox::const_iterator label_bboxes = decodeBBoxes.find(label); if (label_bboxes == decodeBBoxes.end()) CV_Error_(cv::Error::StsError, ("Could not find location predictions for label %d", label)); + int limit = (getNumOfTargetClasses() == 1) ? _keepTopK : std::numeric_limits::max(); if (_bboxesNormalized) NMSFast_(label_bboxes->second, scores, _confidenceThreshold, _nmsThreshold, 1.0, _topK, - indices[c], util::caffe_norm_box_overlap); + indices[c], util::caffe_norm_box_overlap, limit); else NMSFast_(label_bboxes->second, scores, _confidenceThreshold, _nmsThreshold, 1.0, _topK, - indices[c], util::caffe_box_overlap); + indices[c], util::caffe_box_overlap, limit); numDetections += indices[c].size(); } if (_keepTopK > -1 && numDetections > (size_t)_keepTopK) diff --git a/modules/dnn/src/nms.inl.hpp b/modules/dnn/src/nms.inl.hpp index 89e3adfcf5..7b84839c02 100644 --- a/modules/dnn/src/nms.inl.hpp +++ b/modules/dnn/src/nms.inl.hpp @@ -62,12 +62,15 @@ inline void GetMaxScoreIndex(const std::vector& scores, const float thres // score_threshold: a threshold used to filter detection results. // nms_threshold: a threshold used in non maximum suppression. // top_k: if not > 0, keep at most top_k picked indices. +// limit: early terminate once the # of picked indices has reached it. // indices: the kept indices of bboxes after nms. template inline void NMSFast_(const std::vector& bboxes, const std::vector& scores, const float score_threshold, const float nms_threshold, const float eta, const int top_k, - std::vector& indices, float (*computeOverlap)(const BoxType&, const BoxType&)) + std::vector& indices, + float (*computeOverlap)(const BoxType&, const BoxType&), + int limit = std::numeric_limits::max()) { CV_Assert(bboxes.size() == scores.size()); @@ -86,8 +89,12 @@ inline void NMSFast_(const std::vector& bboxes, float overlap = computeOverlap(bboxes[idx], bboxes[kept_idx]); keep = overlap <= adaptive_threshold; } - if (keep) + if (keep) { indices.push_back(idx); + if (indices.size() >= limit) { + break; + } + } if (keep && eta < 1 && adaptive_threshold > 0.5) { adaptive_threshold *= eta; }