Merge pull request #19613 from WeiChungChang:NMS_refine

pull/19703/head
Alexander Alekhin 4 years ago
commit e4692ac079
  1. 11
      modules/dnn/src/layers/detection_output_layer.cpp
  2. 11
      modules/dnn/src/nms.inl.hpp

@ -133,6 +133,12 @@ public:
typedef std::map<int, std::vector<util::NormalizedBBox> > LabelBBox; typedef std::map<int, std::vector<util::NormalizedBBox> > LabelBBox;
inline int getNumOfTargetClasses() {
unsigned numBackground =
(_backgroundLabelId >= 0 && _backgroundLabelId < _numClasses) ? 1 : 0;
return (_numClasses - numBackground);
}
bool getParameterDict(const LayerParams &params, bool getParameterDict(const LayerParams &params,
const std::string &parameterName, const std::string &parameterName,
DictValue& result) DictValue& result)
@ -584,12 +590,13 @@ public:
LabelBBox::const_iterator label_bboxes = decodeBBoxes.find(label); LabelBBox::const_iterator label_bboxes = decodeBBoxes.find(label);
if (label_bboxes == decodeBBoxes.end()) if (label_bboxes == decodeBBoxes.end())
CV_Error_(cv::Error::StsError, ("Could not find location predictions for label %d", label)); CV_Error_(cv::Error::StsError, ("Could not find location predictions for label %d", label));
int limit = (getNumOfTargetClasses() == 1) ? _keepTopK : std::numeric_limits<int>::max();
if (_bboxesNormalized) if (_bboxesNormalized)
NMSFast_(label_bboxes->second, scores, _confidenceThreshold, _nmsThreshold, 1.0, _topK, NMSFast_(label_bboxes->second, scores, _confidenceThreshold, _nmsThreshold, 1.0, _topK,
indices[c], util::caffe_norm_box_overlap); indices[c], util::caffe_norm_box_overlap, limit);
else else
NMSFast_(label_bboxes->second, scores, _confidenceThreshold, _nmsThreshold, 1.0, _topK, NMSFast_(label_bboxes->second, scores, _confidenceThreshold, _nmsThreshold, 1.0, _topK,
indices[c], util::caffe_box_overlap); indices[c], util::caffe_box_overlap, limit);
numDetections += indices[c].size(); numDetections += indices[c].size();
} }
if (_keepTopK > -1 && numDetections > (size_t)_keepTopK) if (_keepTopK > -1 && numDetections > (size_t)_keepTopK)

@ -62,12 +62,15 @@ inline void GetMaxScoreIndex(const std::vector<float>& scores, const float thres
// score_threshold: a threshold used to filter detection results. // score_threshold: a threshold used to filter detection results.
// nms_threshold: a threshold used in non maximum suppression. // nms_threshold: a threshold used in non maximum suppression.
// top_k: if not > 0, keep at most top_k picked indices. // top_k: if not > 0, keep at most top_k picked indices.
// limit: early terminate once the # of picked indices has reached it.
// indices: the kept indices of bboxes after nms. // indices: the kept indices of bboxes after nms.
template <typename BoxType> template <typename BoxType>
inline void NMSFast_(const std::vector<BoxType>& bboxes, inline void NMSFast_(const std::vector<BoxType>& bboxes,
const std::vector<float>& scores, const float score_threshold, const std::vector<float>& scores, const float score_threshold,
const float nms_threshold, const float eta, const int top_k, const float nms_threshold, const float eta, const int top_k,
std::vector<int>& indices, float (*computeOverlap)(const BoxType&, const BoxType&)) std::vector<int>& indices,
float (*computeOverlap)(const BoxType&, const BoxType&),
int limit = std::numeric_limits<int>::max())
{ {
CV_Assert(bboxes.size() == scores.size()); CV_Assert(bboxes.size() == scores.size());
@ -86,8 +89,12 @@ inline void NMSFast_(const std::vector<BoxType>& bboxes,
float overlap = computeOverlap(bboxes[idx], bboxes[kept_idx]); float overlap = computeOverlap(bboxes[idx], bboxes[kept_idx]);
keep = overlap <= adaptive_threshold; keep = overlap <= adaptive_threshold;
} }
if (keep) if (keep) {
indices.push_back(idx); indices.push_back(idx);
if (indices.size() >= limit) {
break;
}
}
if (keep && eta < 1 && adaptive_threshold > 0.5) { if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta; adaptive_threshold *= eta;
} }

Loading…
Cancel
Save