Merge pull request #10265 from dkurt:nms_for_region_layer

pull/10399/merge
Vadim Pisarevsky 7 years ago
commit 0742e12f0b
  1. 2
      modules/dnn/src/darknet/darknet_io.cpp
  2. 133
      modules/dnn/src/layers/region_layer.cpp

@ -482,7 +482,7 @@ namespace cv {
} }
else if (layer_type == "region") else if (layer_type == "region")
{ {
float thresh = 0.001; // in the original Darknet is equal to the detection threshold set by the user float thresh = getParam<float>(layer_params, "thresh", 0.001);
int coords = getParam<int>(layer_params, "coords", 4); int coords = getParam<int>(layer_params, "coords", 4);
int classes = getParam<int>(layer_params, "classes", -1); int classes = getParam<int>(layer_params, "classes", -1);
int num_of_anchors = getParam<int>(layer_params, "num", -1); int num_of_anchors = getParam<int>(layer_params, "num", -1);

@ -43,7 +43,7 @@
#include "../precomp.hpp" #include "../precomp.hpp"
#include <opencv2/dnn/shape_utils.hpp> #include <opencv2/dnn/shape_utils.hpp>
#include <opencv2/dnn/all_layers.hpp> #include <opencv2/dnn/all_layers.hpp>
#include <iostream> #include "nms.inl.hpp"
#include "opencl_kernels_dnn.hpp" #include "opencl_kernels_dnn.hpp"
namespace cv namespace cv
@ -173,8 +173,7 @@ public:
if (nmsThreshold > 0) { if (nmsThreshold > 0) {
Mat mat = outBlob.getMat(ACCESS_WRITE); Mat mat = outBlob.getMat(ACCESS_WRITE);
float *dstData = mat.ptr<float>(); float *dstData = mat.ptr<float>();
do_nms_sort(dstData, rows*cols*anchors, nmsThreshold); do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
//do_nms(dstData, rows*cols*anchors, nmsThreshold);
} }
} }
@ -263,128 +262,48 @@ public:
} }
if (nmsThreshold > 0) { if (nmsThreshold > 0) {
do_nms_sort(dstData, rows*cols*anchors, nmsThreshold); do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
//do_nms(dstData, rows*cols*anchors, nmsThreshold);
} }
} }
} }
static inline float rectOverlap(const Rect2f& a, const Rect2f& b)
struct box {
float x, y, w, h;
float *probs;
};
float overlap(float x1, float w1, float x2, float w2)
{ {
float l1 = x1 - w1 / 2; return 1.0f - jaccardDistance(a, b);
float l2 = x2 - w2 / 2;
float left = l1 > l2 ? l1 : l2;
float r1 = x1 + w1 / 2;
float r2 = x2 + w2 / 2;
float right = r1 < r2 ? r1 : r2;
return right - left;
} }
float box_intersection(box a, box b) void do_nms_sort(float *detections, int total, float score_thresh, float nms_thresh)
{ {
float w = overlap(a.x, a.w, b.x, b.w); std::vector<Rect2f> boxes(total);
float h = overlap(a.y, a.h, b.y, b.h); std::vector<float> scores(total);
if (w < 0 || h < 0) return 0;
float area = w*h;
return area;
}
float box_union(box a, box b) for (int i = 0; i < total; ++i)
{ {
float i = box_intersection(a, b); Rect2f &b = boxes[i];
float u = a.w*a.h + b.w*b.h - i; int box_index = i * (classes + coords + 1);
return u; b.width = detections[box_index + 2];
b.height = detections[box_index + 3];
b.x = detections[box_index + 0] - b.width / 2;
b.y = detections[box_index + 1] - b.height / 2;
} }
float box_iou(box a, box b) std::vector<int> indices;
for (int k = 0; k < classes; ++k)
{ {
return box_intersection(a, b) / box_union(a, b); for (int i = 0; i < total; ++i)
}
struct sortable_bbox {
int index;
float *probs;
};
struct nms_comparator {
int k;
nms_comparator(int _k) : k(_k) {}
bool operator ()(sortable_bbox v1, sortable_bbox v2) {
return v2.probs[k] < v1.probs[k];
}
};
void do_nms_sort(float *detections, int total, float nms_thresh)
{ {
std::vector<box> boxes(total);
for (int i = 0; i < total; ++i) {
box &b = boxes[i];
int box_index = i * (classes + coords + 1); int box_index = i * (classes + coords + 1);
b.x = detections[box_index + 0]; int class_index = box_index + 5;
b.y = detections[box_index + 1]; scores[i] = detections[class_index + k];
b.w = detections[box_index + 2]; detections[class_index + k] = 0;
b.h = detections[box_index + 3];
int class_index = i * (classes + 5) + 5;
b.probs = (detections + class_index);
}
std::vector<sortable_bbox> s(total);
for (int i = 0; i < total; ++i) {
s[i].index = i;
int class_index = i * (classes + 5) + 5;
s[i].probs = (detections + class_index);
}
for (int k = 0; k < classes; ++k) {
std::stable_sort(s.begin(), s.end(), nms_comparator(k));
for (int i = 0; i < total; ++i) {
if (boxes[s[i].index].probs[k] == 0) continue;
box a = boxes[s[i].index];
for (int j = i + 1; j < total; ++j) {
box b = boxes[s[j].index];
if (box_iou(a, b) > nms_thresh) {
boxes[s[j].index].probs[k] = 0;
}
}
}
}
} }
NMSFast_(boxes, scores, score_thresh, nms_thresh, 1, 0, indices, rectOverlap);
void do_nms(float *detections, int total, float nms_thresh) for (int i = 0, n = indices.size(); i < n; ++i)
{ {
std::vector<box> boxes(total); int box_index = indices[i] * (classes + coords + 1);
for (int i = 0; i < total; ++i) { int class_index = box_index + 5;
box &b = boxes[i]; detections[class_index + k] = scores[indices[i]];
int box_index = i * (classes + coords + 1);
b.x = detections[box_index + 0];
b.y = detections[box_index + 1];
b.w = detections[box_index + 2];
b.h = detections[box_index + 3];
int class_index = i * (classes + 5) + 5;
b.probs = (detections + class_index);
}
for (int i = 0; i < total; ++i) {
bool any = false;
for (int k = 0; k < classes; ++k) any = any || (boxes[i].probs[k] > 0);
if (!any) {
continue;
}
for (int j = i + 1; j < total; ++j) {
if (box_iou(boxes[i], boxes[j]) > nms_thresh) {
for (int k = 0; k < classes; ++k) {
if (boxes[i].probs[k] < boxes[j].probs[k]) boxes[i].probs[k] = 0;
else boxes[j].probs[k] = 0;
}
}
} }
} }
} }

Loading…
Cancel
Save