Merge pull request #11322 from dkurt:dnn_yolov3

Vadim Pisarevsky, 7 years ago (commit b290bdafb9)
8 changed files:
  1. modules/dnn/perf/perf_net.cpp (10 changed lines)
  2. modules/dnn/src/darknet/darknet_io.cpp (136 changed lines)
  3. modules/dnn/src/dnn.cpp (3 changed lines)
  4. modules/dnn/src/layers/region_layer.cpp (89 changed lines)
  5. modules/dnn/test/test_darknet_importer.cpp (327 changed lines)
  6. modules/dnn/src/layers/resize_nearest_neighbor_layer.cpp (23 changed lines)
  7. samples/dnn/object_detection.cpp (86 changed lines)
  8. samples/dnn/object_detection.py (55 changed lines)

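For orientation: after this PR, YOLOv3 models can be loaded and run through the DNN module. A minimal sketch mirroring the new test and sample code below (file paths are placeholders; the cfg and weights come from the official YOLOv3 release):

#include <opencv2/dnn.hpp>
#include <opencv2/imgcodecs.hpp>
using namespace cv;
using namespace cv::dnn;

int main()
{
    // Load the Darknet model (paths are placeholders).
    Net net = readNetFromDarknet("yolov3.cfg", "yolov3.weights");
    // Darknet models expect RGB input scaled to [0, 1].
    Mat img = imread("dog416.png");
    net.setInput(blobFromImage(img, 1 / 255.F, Size(416, 416), Scalar(), true, false));
    // YOLOv3 has three detection outputs, so collect every unconnected output.
    std::vector<String> names;
    std::vector<int> outLayers = net.getUnconnectedOutLayers();
    std::vector<String> layersNames = net.getLayerNames();
    for (size_t i = 0; i < outLayers.size(); ++i)
        names.push_back(layersNames[outLayers[i] - 1]);
    std::vector<Mat> outs;
    net.forward(outs, names);  // one Mat of detections per output scale
    return 0;
}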
@@ -217,6 +217,16 @@ PERF_TEST_P_(DNNTestNetwork, Inception_v2_SSD_TensorFlow)
Mat(cv::Size(300, 300), CV_32FC3));
}
PERF_TEST_P_(DNNTestNetwork, YOLOv3)
{
if (backend != DNN_BACKEND_DEFAULT)
throw SkipTestException("");
Mat sample = imread(findDataFile("dnn/dog416.png", false));
Mat inp;
sample.convertTo(inp, CV_32FC3);
processNet("dnn/yolov3.cfg", "dnn/yolov3.weights", "", inp / 255);
}
const tuple<DNNBackend, DNNTarget> testCases[] = {
#ifdef HAVE_HALIDE
tuple<DNNBackend, DNNTarget>(DNN_BACKEND_HALIDE, DNN_TARGET_CPU),

@@ -89,6 +89,8 @@ namespace cv {
return init_val;
}
static const std::string kFirstLayerName = "data";
class setLayersParams {
NetParameter *net;
@@ -97,8 +99,8 @@ namespace cv {
std::vector<std::string> fused_layer_names;
public:
setLayersParams(NetParameter *_net, std::string _first_layer = "data") :
net(_net), layer_id(0), last_layer(_first_layer)
setLayersParams(NetParameter *_net) :
net(_net), layer_id(0), last_layer(kFirstLayerName)
{}
void setLayerBlobs(int i, std::vector<cv::Mat> blobs)
@@ -275,7 +277,7 @@ namespace cv {
fused_layer_names.push_back(last_layer);
}
void setPermute()
void setPermute(bool isDarknetLayer = true)
{
cv::dnn::LayerParams permute_params;
permute_params.name = "Permute-name";
@@ -294,8 +296,11 @@ namespace cv {
last_layer = layer_name;
net->layers.push_back(lp);
layer_id++;
fused_layer_names.push_back(last_layer);
if (isDarknetLayer)
{
layer_id++;
fused_layer_names.push_back(last_layer);
}
}
void setRegion(float thresh, int coords, int classes, int anchors, int classfix, int softmax, int softmax_tree, float *biasData)
@@ -327,6 +332,85 @@ namespace cv {
layer_id++;
fused_layer_names.push_back(last_layer);
}
void setYolo(int classes, const std::vector<int>& mask, const std::vector<float>& anchors)
{
cv::dnn::LayerParams region_param;
region_param.name = "Region-name";
region_param.type = "Region";
const int numAnchors = mask.size();
region_param.set<int>("classes", classes);
region_param.set<int>("anchors", numAnchors);
region_param.set<bool>("logistic", true);
std::vector<float> usedAnchors(numAnchors * 2);
for (int i = 0; i < numAnchors; ++i)
{
usedAnchors[i * 2] = anchors[mask[i] * 2];
usedAnchors[i * 2 + 1] = anchors[mask[i] * 2 + 1];
}
cv::Mat biasData_mat = cv::Mat(1, numAnchors * 2, CV_32F, &usedAnchors[0]).clone();
region_param.blobs.push_back(biasData_mat);
darknet::LayerParameter lp;
std::string layer_name = cv::format("yolo_%d", layer_id);
lp.layer_name = layer_name;
lp.layer_type = region_param.type;
lp.layerParams = region_param;
lp.bottom_indexes.push_back(last_layer);
lp.bottom_indexes.push_back(kFirstLayerName);
last_layer = layer_name;
net->layers.push_back(lp);
layer_id++;
fused_layer_names.push_back(last_layer);
}
void setShortcut(int from)
{
cv::dnn::LayerParams shortcut_param;
shortcut_param.name = "Shortcut-name";
shortcut_param.type = "Eltwise";
shortcut_param.set<std::string>("op", "sum");
darknet::LayerParameter lp;
std::string layer_name = cv::format("shortcut_%d", layer_id);
lp.layer_name = layer_name;
lp.layer_type = shortcut_param.type;
lp.layerParams = shortcut_param;
lp.bottom_indexes.push_back(fused_layer_names.at(from));
lp.bottom_indexes.push_back(last_layer);
last_layer = layer_name;
net->layers.push_back(lp);
layer_id++;
fused_layer_names.push_back(last_layer);
}
void setUpsample(int scaleFactor)
{
cv::dnn::LayerParams param;
param.name = "Upsample-name";
param.type = "ResizeNearestNeighbor";
param.set<int>("zoom_factor", scaleFactor);
darknet::LayerParameter lp;
std::string layer_name = cv::format("upsample_%d", layer_id);
lp.layer_name = layer_name;
lp.layer_type = param.type;
lp.layerParams = param;
lp.bottom_indexes.push_back(last_layer);
last_layer = layer_name;
net->layers.push_back(lp);
layer_id++;
fused_layer_names.push_back(last_layer);
}
};
std::string escapeString(const std::string &src)
@@ -464,7 +548,7 @@ namespace cv {
current_channels = 0;
for (size_t k = 0; k < layers_vec.size(); ++k) {
layers_vec[k] += layers_counter;
layers_vec[k] = layers_vec[k] > 0 ? layers_vec[k] : (layers_vec[k] + layers_counter);
current_channels += net->out_channels_vec[layers_vec[k]];
}
@@ -496,9 +580,43 @@ namespace cv {
CV_Assert(classes > 0 && num_of_anchors > 0 && (num_of_anchors * 2) == anchors_vec.size());
setParams.setPermute();
setParams.setPermute(false);
setParams.setRegion(thresh, coords, classes, num_of_anchors, classfix, softmax, softmax_tree, anchors_vec.data());
}
else if (layer_type == "shortcut")
{
std::string bottom_layer = getParam<std::string>(layer_params, "from", "");
CV_Assert(!bottom_layer.empty());
int from = std::atoi(bottom_layer.c_str());
from += layers_counter;
current_channels = net->out_channels_vec[from];
setParams.setShortcut(from);
}
else if (layer_type == "upsample")
{
int scaleFactor = getParam<int>(layer_params, "stride", 1);
setParams.setUpsample(scaleFactor);
}
else if (layer_type == "yolo")
{
int classes = getParam<int>(layer_params, "classes", -1);
int num_of_anchors = getParam<int>(layer_params, "num", -1);
std::string anchors_values = getParam<std::string>(layer_params, "anchors", std::string());
CV_Assert(!anchors_values.empty());
std::vector<float> anchors_vec = getNumbers<float>(anchors_values);
std::string mask_values = getParam<std::string>(layer_params, "mask", std::string());
CV_Assert(!mask_values.empty());
std::vector<int> mask_vec = getNumbers<int>(mask_values);
CV_Assert(classes > 0 && num_of_anchors > 0 && (num_of_anchors * 2) == anchors_vec.size());
setParams.setPermute(false);
setParams.setYolo(classes, mask_vec, anchors_vec);
}
else {
CV_Error(cv::Error::StsParseError, "Unknown layer type: " + layer_type);
}
@@ -598,6 +716,10 @@ namespace cv {
if(activation == "leaky")
++cv_layers_counter;
}
if (layer_type == "region" || layer_type == "yolo")
{
++cv_layers_counter; // For permute.
}
current_channels = net->out_channels_vec[darknet_layers_counter];
}
return true;
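To make the new [yolo] parsing above concrete: the cfg supplies num=9 anchor pairs plus a per-scale mask, and setYolo copies only the masked pairs into the Region layer's bias blob. A standalone sketch of that selection, using the anchor values and the coarsest-scale mask quoted from the public yolov3.cfg (so treat the numbers as illustrative):

#include <vector>
#include <cstdio>

int main()
{
    // All 9 anchor pairs from yolov3.cfg (width, height interleaved).
    std::vector<float> anchors = {10,13, 16,30, 33,23, 30,61, 62,45,
                                  59,119, 116,90, 156,198, 373,326};
    std::vector<int> mask = {6, 7, 8};  // the 13x13 scale uses the 3 largest pairs
    std::vector<float> used(mask.size() * 2);
    for (size_t i = 0; i < mask.size(); ++i)
    {
        used[i * 2]     = anchors[mask[i] * 2];      // anchor width
        used[i * 2 + 1] = anchors[mask[i] * 2 + 1];  // anchor height
    }
    for (size_t i = 0; i < used.size(); i += 2)
        std::printf("anchor: %gx%g\n", used[i], used[i + 1]);
    return 0;
}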

@@ -1527,12 +1527,11 @@ struct Net::Impl
convLayer = downLayerData->layerInstance.dynamicCast<ConvolutionLayer>();
// first input layer is convolution layer
if( !convLayer.empty() )
if( !convLayer.empty() && eltwiseData->consumers.size() == 1 )
{
// fuse eltwise + activation layer
LayerData *firstConvLayerData = downLayerData;
{
CV_Assert(eltwiseData->consumers.size() == 1);
nextData = &layers[eltwiseData->consumers[0].lid];
lpNext = LayerPin(eltwiseData->consumers[0].lid, 0);
Ptr<ActivationLayer> nextActivLayer;

@@ -59,7 +59,7 @@ class RegionLayerImpl CV_FINAL : public RegionLayer
public:
int coords, classes, anchors, classfix;
float thresh, nmsThreshold;
bool useSoftmaxTree, useSoftmax;
bool useSoftmax, useLogistic;
RegionLayerImpl(const LayerParams& params)
{
@@ -71,15 +71,17 @@ public:
classes = params.get<int>("classes", 0);
anchors = params.get<int>("anchors", 5);
classfix = params.get<int>("classfix", 0);
useSoftmaxTree = params.get<bool>("softmax_tree", false);
useSoftmax = params.get<bool>("softmax", false);
useLogistic = params.get<bool>("logistic", false);
nmsThreshold = params.get<float>("nms_threshold", 0.4);
CV_Assert(nmsThreshold >= 0.);
CV_Assert(coords == 4);
CV_Assert(classes >= 1);
CV_Assert(anchors >= 1);
CV_Assert(useSoftmaxTree || useSoftmax);
CV_Assert(useLogistic || useSoftmax);
if (params.get<bool>("softmax_tree", false))
CV_Error(cv::Error::StsNotImplemented, "Yolo9000 is not implemented");
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
@@ -89,7 +91,7 @@ public:
{
CV_Assert(inputs.size() > 0);
CV_Assert(inputs[0][3] == (1 + coords + classes)*anchors);
outputs = std::vector<MatShape>(inputs.size(), shape(inputs[0][1] * inputs[0][2] * anchors, inputs[0][3] / anchors));
outputs = std::vector<MatShape>(1, shape(inputs[0][1] * inputs[0][2] * anchors, inputs[0][3] / anchors));
return false;
}
@@ -124,14 +126,13 @@ public:
std::vector<UMat> inputs;
std::vector<UMat> outputs;
// TODO: implement a logistic activation for classification scores.
if (useLogistic)
return false;
inps.getUMatVector(inputs);
outs.getUMatVector(outputs);
if (useSoftmaxTree) { // Yolo 9000
CV_Error(cv::Error::StsNotImplemented, "Yolo9000 is not implemented");
return false;
}
CV_Assert(inputs.size() >= 1);
int const cell_size = classes + coords + 1;
UMat blob_umat = blobs[0].getUMat(ACCESS_READ);
@@ -203,6 +204,7 @@ public:
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
CV_Assert(inputs.size() >= 1);
CV_Assert(outputs.size() == 1);
int const cell_size = classes + coords + 1;
const float* biasData = blobs[0].ptr<float>();
@@ -214,6 +216,9 @@ public:
int rows = inpBlob.size[1];
int cols = inpBlob.size[2];
CV_Assert(inputs.size() < 2 || inputs[1]->dims == 4);
int hNorm = inputs.size() > 1 ? inputs[1]->size[2] : rows;
int wNorm = inputs.size() > 1 ? inputs[1]->size[3] : cols;
const float *srcData = inpBlob.ptr<float>();
float *dstData = outBlob.ptr<float>();
@@ -225,49 +230,47 @@ public:
dstData[index + 4] = logistic_activate(x); // logistic activation
}
if (useSoftmaxTree) { // Yolo 9000
CV_Error(cv::Error::StsNotImplemented, "Yolo9000 is not implemented");
}
else if (useSoftmax) { // Yolo v2
if (useSoftmax) { // Yolo v2
// softmax activation for class probabilities, for each grid cell (X x Y x Anchor-index)
for (int i = 0; i < rows*cols*anchors; ++i) {
int index = cell_size*i;
softmax_activate(srcData + index + 5, classes, 1, dstData + index + 5);
}
for (int x = 0; x < cols; ++x)
for(int y = 0; y < rows; ++y)
for (int a = 0; a < anchors; ++a) {
int index = (y*cols + x)*anchors + a; // index for each grid-cell & anchor
int p_index = index * cell_size + 4;
float scale = dstData[p_index];
if (classfix == -1 && scale < .5) scale = 0; // if(t0 < 0.5) t0 = 0;
int box_index = index * cell_size;
dstData[box_index + 0] = (x + logistic_activate(srcData[box_index + 0])) / cols;
dstData[box_index + 1] = (y + logistic_activate(srcData[box_index + 1])) / rows;
dstData[box_index + 2] = exp(srcData[box_index + 2]) * biasData[2 * a] / cols;
dstData[box_index + 3] = exp(srcData[box_index + 3]) * biasData[2 * a + 1] / rows;
int class_index = index * cell_size + 5;
if (useSoftmaxTree) {
CV_Error(cv::Error::StsNotImplemented, "Yolo9000 is not implemented");
}
else {
for (int j = 0; j < classes; ++j) {
float prob = scale*dstData[class_index + j]; // prob = IoU(box, object) = t0 * class-probability
dstData[class_index + j] = (prob > thresh) ? prob : 0; // if (IoU < threshold) IoU = 0;
}
}
}
}
else if (useLogistic) { // Yolo v3
for (int i = 0; i < rows*cols*anchors; ++i)
{
int index = cell_size*i;
const float* input = srcData + index + 5;
float* output = dstData + index + 5;
for (int i = 0; i < classes; ++i)
output[i] = logistic_activate(input[i]);
}
}
for (int x = 0; x < cols; ++x)
for(int y = 0; y < rows; ++y)
for (int a = 0; a < anchors; ++a) {
int index = (y*cols + x)*anchors + a; // index for each grid-cell & anchor
int p_index = index * cell_size + 4;
float scale = dstData[p_index];
if (classfix == -1 && scale < .5) scale = 0; // if(t0 < 0.5) t0 = 0;
int box_index = index * cell_size;
dstData[box_index + 0] = (x + logistic_activate(srcData[box_index + 0])) / cols;
dstData[box_index + 1] = (y + logistic_activate(srcData[box_index + 1])) / rows;
dstData[box_index + 2] = exp(srcData[box_index + 2]) * biasData[2 * a] / hNorm;
dstData[box_index + 3] = exp(srcData[box_index + 3]) * biasData[2 * a + 1] / wNorm;
int class_index = index * cell_size + 5;
for (int j = 0; j < classes; ++j) {
float prob = scale*dstData[class_index + j]; // prob = IoU(box, object) = t0 * class-probability
dstData[class_index + j] = (prob > thresh) ? prob : 0; // if (IoU < threshold) IoU = 0;
}
}
if (nmsThreshold > 0) {
do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
}
}
}
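Restating the box decoding performed in the forward pass above: the x/y offsets go through the logistic function and are normalized by the grid size, while w/h scale the anchor priors through exp() and are normalized by hNorm/wNorm (the second input's spatial size for YOLOv3, the grid size otherwise). A compact sketch of the same arithmetic:

#include <cmath>

// Logistic (sigmoid) activation, matching logistic_activate above.
static inline float sigmoid(float x) { return 1.f / (1.f + std::exp(-x)); }

// Decode one box for grid cell (x, y); tx..th are raw network outputs and
// anchorW/anchorH come from the layer's bias blob.
static void decodeBox(int x, int y, int cols, int rows, int hNorm, int wNorm,
                      float tx, float ty, float tw, float th,
                      float anchorW, float anchorH, float box[4])
{
    box[0] = (x + sigmoid(tx)) / cols;        // relative center x
    box[1] = (y + sigmoid(ty)) / rows;        // relative center y
    box[2] = std::exp(tw) * anchorW / hNorm;  // relative width, as in the code above
    box[3] = std::exp(th) * anchorH / wNorm;  // relative height
}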

@@ -16,9 +16,11 @@ public:
ResizeNearestNeighborLayerImpl(const LayerParams& params)
{
setParamsFrom(params);
CV_Assert(params.has("width"), params.has("height"));
outWidth = params.get<float>("width");
outHeight = params.get<float>("height");
CV_Assert(params.has("width") && params.has("height") || params.has("zoom_factor"));
CV_Assert(!params.has("width") && !params.has("height") || !params.has("zoom_factor"));
outWidth = params.get<float>("width", 0);
outHeight = params.get<float>("height", 0);
zoomFactor = params.get<int>("zoom_factor", 1);
alignCorners = params.get<bool>("align_corners", false);
if (alignCorners)
CV_Error(Error::StsNotImplemented, "Nearest neighborhood resize with align_corners=true is not implemented");
@@ -31,12 +33,21 @@ public:
{
CV_Assert(inputs.size() == 1, inputs[0].size() == 4);
outputs.resize(1, inputs[0]);
outputs[0][2] = outHeight;
outputs[0][3] = outWidth;
outputs[0][2] = outHeight > 0 ? outHeight : (outputs[0][2] * zoomFactor);
outputs[0][3] = outWidth > 0 ? outWidth : (outputs[0][3] * zoomFactor);
// We can work in-place (do nothing) if input shape == output shape.
return (outputs[0][2] == inputs[0][2]) && (outputs[0][3] == inputs[0][3]);
}
virtual void finalize(const std::vector<Mat*>& inputs, std::vector<Mat> &outputs) CV_OVERRIDE
{
if (!outWidth && !outHeight)
{
outHeight = outputs[0].size[2];
outWidth = outputs[0].size[3];
}
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
@@ -65,7 +76,7 @@ public:
}
}
private:
int outWidth, outHeight;
int outWidth, outHeight, zoomFactor;
bool alignCorners;
};
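A short sketch of the two mutually exclusive configurations the asserts above allow: either an explicit output size (the pre-existing behaviour), or the new zoom_factor that Darknet's [upsample] layer maps to. Parameter names are as in the code above:

#include <opencv2/dnn.hpp>
using namespace cv::dnn;

int main()
{
    // Explicit output size.
    LayerParams fixedSize;
    fixedSize.set("width", 26);
    fixedSize.set("height", 26);

    // Relative scaling: a 13x13 input becomes a 26x26 output.
    LayerParams zoom;
    zoom.set("zoom_factor", 2);
    return 0;
}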

@@ -42,9 +42,8 @@
//M*/
#include "test_precomp.hpp"
#include "npy_blob.hpp"
#include <opencv2/dnn/shape_utils.hpp>
#include <opencv2/core/ocl.hpp>
#include <opencv2/ts/ocl_test.hpp>
namespace opencv_test { namespace {
@@ -66,238 +65,136 @@ TEST(Test_Darknet, read_yolo_voc)
ASSERT_FALSE(net.empty());
}
OCL_TEST(Reproducibility_TinyYoloVoc, Accuracy)
// Test an object detection network from the Darknet framework.
static void testDarknetModel(const std::string& cfg, const std::string& weights,
const std::vector<cv::String>& outNames,
const std::vector<int>& refClassIds,
const std::vector<float>& refConfidences,
const std::vector<Rect2f>& refBoxes,
int targetId, float confThreshold = 0.24)
{
Net net;
Mat sample = imread(_tf("dog416.png"));
Mat inp = blobFromImage(sample, 1.0/255, Size(416, 416), Scalar(), true, false);
Net net = readNet(findDataFile("dnn/" + cfg, false),
findDataFile("dnn/" + weights, false));
net.setPreferableTarget(targetId);
net.setInput(inp);
std::vector<Mat> outs;
net.forward(outs, outNames);
std::vector<int> classIds;
std::vector<float> confidences;
std::vector<Rect2f> boxes;
for (int i = 0; i < outs.size(); ++i)
{
const string cfg = findDataFile("dnn/tiny-yolo-voc.cfg", false);
const string model = findDataFile("dnn/tiny-yolo-voc.weights", false);
net = readNetFromDarknet(cfg, model);
ASSERT_FALSE(net.empty());
Mat& out = outs[i];
for (int j = 0; j < out.rows; ++j)
{
Mat scores = out.row(j).colRange(5, out.cols);
double confidence;
Point maxLoc;
minMaxLoc(scores, 0, &confidence, 0, &maxLoc);
if (confidence > confThreshold)
{
float* detection = out.ptr<float>(j);
float centerX = detection[0];
float centerY = detection[1];
float width = detection[2];
float height = detection[3];
boxes.push_back(Rect2f(centerX - 0.5 * width, centerY - 0.5 * height,
width, height));
confidences.push_back(confidence);
classIds.push_back(maxLoc.x);
}
}
}
net.setPreferableBackend(DNN_BACKEND_DEFAULT);
net.setPreferableTarget(DNN_TARGET_OPENCL);
// dog416.png is dog.jpg resized to 416x416 and stored in the lossless PNG format
Mat sample = imread(_tf("dog416.png"));
ASSERT_TRUE(!sample.empty());
Size inputSize(416, 416);
if (sample.size() != inputSize)
resize(sample, sample, inputSize);
net.setInput(blobFromImage(sample, 1 / 255.F), "data");
Mat out = net.forward("detection_out");
Mat detection;
const float confidenceThreshold = 0.24;
for (int i = 0; i < out.rows; i++) {
const int probability_index = 5;
const int probability_size = out.cols - probability_index;
float *prob_array_ptr = &out.at<float>(i, probability_index);
size_t objectClass = std::max_element(prob_array_ptr, prob_array_ptr + probability_size) - prob_array_ptr;
float confidence = out.at<float>(i, (int)objectClass + probability_index);
if (confidence > confidenceThreshold)
detection.push_back(out.row(i));
ASSERT_EQ(classIds.size(), refClassIds.size());
ASSERT_EQ(confidences.size(), refConfidences.size());
ASSERT_EQ(boxes.size(), refBoxes.size());
for (int i = 0; i < boxes.size(); ++i)
{
ASSERT_EQ(classIds[i], refClassIds[i]);
ASSERT_LE(std::abs(confidences[i] - refConfidences[i]), 1e-4);
float iou = (boxes[i] & refBoxes[i]).area() / (boxes[i] | refBoxes[i]).area();
ASSERT_LE(std::abs(iou - 1.0f), 1e-4);
}
}
// obtained by: ./darknet detector test ./cfg/voc.data ./cfg/tiny-yolo-voc.cfg ./tiny-yolo-voc.weights -thresh 0.24 ./dog416.png
// There are 2 objects (6-car, 11-dog) with 25 values for each:
// { relative_center_x, relative_center_y, relative_width, relative_height, unused_t0, probability_for_each_class[20] }
float ref_array[] = {
0.736762F, 0.239551F, 0.315440F, 0.160779F, 0.761977F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.761967F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.287486F, 0.653731F, 0.315579F, 0.534527F, 0.782737F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.780595F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F
};
const int number_of_objects = 2;
Mat ref(number_of_objects, sizeof(ref_array) / (number_of_objects * sizeof(float)), CV_32FC1, &ref_array);
typedef testing::TestWithParam<DNNTarget> Test_Darknet_nets;
normAssert(ref, detection);
TEST_P(Test_Darknet_nets, YoloVoc)
{
int targetId = GetParam();
std::vector<cv::String> outNames(1, "detection_out");
std::vector<int> classIds(3);
std::vector<float> confidences(3);
std::vector<Rect2f> boxes(3);
classIds[0] = 6; confidences[0] = 0.750469f; boxes[0] = Rect2f(0.577374, 0.127391, 0.325575, 0.173418); // a car
classIds[1] = 1; confidences[1] = 0.780879f; boxes[1] = Rect2f(0.270762, 0.264102, 0.461713, 0.48131); // a bicycle
classIds[2] = 11; confidences[2] = 0.901615f; boxes[2] = Rect2f(0.1386, 0.338509, 0.282737, 0.60028); // a dog
testDarknetModel("yolo-voc.cfg", "yolo-voc.weights", outNames,
classIds, confidences, boxes, targetId);
}
TEST(Reproducibility_TinyYoloVoc, Accuracy)
TEST_P(Test_Darknet_nets, TinyYoloVoc)
{
Net net;
{
const string cfg = findDataFile("dnn/tiny-yolo-voc.cfg", false);
const string model = findDataFile("dnn/tiny-yolo-voc.weights", false);
net = readNetFromDarknet(cfg, model);
ASSERT_FALSE(net.empty());
}
// dog416.png is dog.jpg resized to 416x416 and stored in the lossless PNG format
Mat sample = imread(_tf("dog416.png"));
ASSERT_TRUE(!sample.empty());
Size inputSize(416, 416);
if (sample.size() != inputSize)
resize(sample, sample, inputSize);
net.setInput(blobFromImage(sample, 1 / 255.F), "data");
Mat out = net.forward("detection_out");
Mat detection;
const float confidenceThreshold = 0.24;
for (int i = 0; i < out.rows; i++) {
const int probability_index = 5;
const int probability_size = out.cols - probability_index;
float *prob_array_ptr = &out.at<float>(i, probability_index);
size_t objectClass = std::max_element(prob_array_ptr, prob_array_ptr + probability_size) - prob_array_ptr;
float confidence = out.at<float>(i, (int)objectClass + probability_index);
if (confidence > confidenceThreshold)
detection.push_back(out.row(i));
}
// obtained by: ./darknet detector test ./cfg/voc.data ./cfg/tiny-yolo-voc.cfg ./tiny-yolo-voc.weights -thresh 0.24 ./dog416.png
// There are 2 objects (6-car, 11-dog) with 25 values for each:
// { relative_center_x, relative_center_y, relative_width, relative_height, unused_t0, probability_for_each_class[20] }
float ref_array[] = {
0.736762F, 0.239551F, 0.315440F, 0.160779F, 0.761977F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.761967F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.287486F, 0.653731F, 0.315579F, 0.534527F, 0.782737F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.780595F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F
};
const int number_of_objects = 2;
Mat ref(number_of_objects, sizeof(ref_array) / (number_of_objects * sizeof(float)), CV_32FC1, &ref_array);
normAssert(ref, detection);
int targetId = GetParam();
std::vector<cv::String> outNames(1, "detection_out");
std::vector<int> classIds(2);
std::vector<float> confidences(2);
std::vector<Rect2f> boxes(2);
classIds[0] = 6; confidences[0] = 0.761967f; boxes[0] = Rect2f(0.579042, 0.159161, 0.31544, 0.160779); // a car
classIds[1] = 11; confidences[1] = 0.780595f; boxes[1] = Rect2f(0.129696, 0.386467, 0.315579, 0.534527); // a dog
testDarknetModel("tiny-yolo-voc.cfg", "tiny-yolo-voc.weights", outNames,
classIds, confidences, boxes, targetId);
}
OCL_TEST(Reproducibility_YoloVoc, Accuracy)
TEST_P(Test_Darknet_nets, YOLOv3)
{
Net net;
{
const string cfg = findDataFile("dnn/yolo-voc.cfg", false);
const string model = findDataFile("dnn/yolo-voc.weights", false);
net = readNetFromDarknet(cfg, model);
ASSERT_FALSE(net.empty());
}
net.setPreferableBackend(DNN_BACKEND_DEFAULT);
net.setPreferableTarget(DNN_TARGET_OPENCL);
// dog416.png is dog.jpg resized to 416x416 and stored in the lossless PNG format
Mat sample = imread(_tf("dog416.png"));
ASSERT_TRUE(!sample.empty());
Size inputSize(416, 416);
if (sample.size() != inputSize)
resize(sample, sample, inputSize);
net.setInput(blobFromImage(sample, 1 / 255.F), "data");
Mat out = net.forward("detection_out");
Mat detection;
const float confidenceThreshold = 0.24;
for (int i = 0; i < out.rows; i++) {
const int probability_index = 5;
const int probability_size = out.cols - probability_index;
float *prob_array_ptr = &out.at<float>(i, probability_index);
size_t objectClass = std::max_element(prob_array_ptr, prob_array_ptr + probability_size) - prob_array_ptr;
float confidence = out.at<float>(i, (int)objectClass + probability_index);
if (confidence > confidenceThreshold)
detection.push_back(out.row(i));
}
// obtained by: ./darknet detector test ./cfg/voc.data ./cfg/yolo-voc.cfg ./yolo-voc.weights -thresh 0.24 ./dog416.png
// There are 3 objects (6-car, 1-bicycle, 11-dog) with 25 values for each:
// { relative_center_x, relative_center_y, relative_width, relative_height, unused_t0, probability_for_each_class[20] }
float ref_array[] = {
0.740161F, 0.214100F, 0.325575F, 0.173418F, 0.750769F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.750469F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.501618F, 0.504757F, 0.461713F, 0.481310F, 0.783550F, 0.000000F, 0.780879F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.279968F, 0.638651F, 0.282737F, 0.600284F, 0.901864F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.901615F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F
};
int targetId = GetParam();
std::vector<cv::String> outNames(3);
outNames[0] = "yolo_82";
outNames[1] = "yolo_94";
outNames[2] = "yolo_106";
std::vector<int> classIds(3);
std::vector<float> confidences(3);
std::vector<Rect2f> boxes(3);
classIds[0] = 7; confidences[0] = 0.952983f; boxes[0] = Rect2f(0.614622, 0.150257, 0.286747, 0.138994); // a truck
classIds[1] = 1; confidences[1] = 0.987908f; boxes[1] = Rect2f(0.150913, 0.221933, 0.591342, 0.524327); // a bicycle
classIds[2] = 16; confidences[2] = 0.998836f; boxes[2] = Rect2f(0.160024, 0.389964, 0.257861, 0.553752); // a dog (COCO)
testDarknetModel("yolov3.cfg", "yolov3.weights", outNames,
classIds, confidences, boxes, targetId);
}
const int number_of_objects = 3;
Mat ref(number_of_objects, sizeof(ref_array) / (number_of_objects * sizeof(float)), CV_32FC1, &ref_array);
INSTANTIATE_TEST_CASE_P(/**/, Test_Darknet_nets, availableDnnTargets());
normAssert(ref, detection);
static void testDarknetLayer(const std::string& name, bool hasWeights = false)
{
std::string cfg = findDataFile("dnn/darknet/" + name + ".cfg", false);
std::string model = "";
if (hasWeights)
model = findDataFile("dnn/darknet/" + name + ".weights", false);
Mat inp = blobFromNPY(findDataFile("dnn/darknet/" + name + "_in.npy", false));
Mat ref = blobFromNPY(findDataFile("dnn/darknet/" + name + "_out.npy", false));
Net net = readNet(cfg, model);
net.setInput(inp);
Mat out = net.forward();
normAssert(out, ref);
}
TEST(Reproducibility_YoloVoc, Accuracy)
TEST(Test_Darknet, shortcut)
{
Net net;
{
const string cfg = findDataFile("dnn/yolo-voc.cfg", false);
const string model = findDataFile("dnn/yolo-voc.weights", false);
net = readNetFromDarknet(cfg, model);
ASSERT_FALSE(net.empty());
}
// dog416.png is dog.jpg resized to 416x416 and stored in the lossless PNG format
Mat sample = imread(_tf("dog416.png"));
ASSERT_TRUE(!sample.empty());
Size inputSize(416, 416);
if (sample.size() != inputSize)
resize(sample, sample, inputSize);
net.setInput(blobFromImage(sample, 1 / 255.F), "data");
Mat out = net.forward("detection_out");
Mat detection;
const float confidenceThreshold = 0.24;
for (int i = 0; i < out.rows; i++) {
const int probability_index = 5;
const int probability_size = out.cols - probability_index;
float *prob_array_ptr = &out.at<float>(i, probability_index);
size_t objectClass = std::max_element(prob_array_ptr, prob_array_ptr + probability_size) - prob_array_ptr;
float confidence = out.at<float>(i, (int)objectClass + probability_index);
if (confidence > confidenceThreshold)
detection.push_back(out.row(i));
}
// obtained by: ./darknet detector test ./cfg/voc.data ./cfg/yolo-voc.cfg ./yolo-voc.weights -thresh 0.24 ./dog416.png
// There are 3 objects (6-car, 1-bicycle, 11-dog) with 25 values for each:
// { relative_center_x, relative_center_y, relative_width, relative_height, unused_t0, probability_for_each_class[20] }
float ref_array[] = {
0.740161F, 0.214100F, 0.325575F, 0.173418F, 0.750769F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.750469F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.501618F, 0.504757F, 0.461713F, 0.481310F, 0.783550F, 0.000000F, 0.780879F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.279968F, 0.638651F, 0.282737F, 0.600284F, 0.901864F, 0.000000F, 0.000000F, 0.000000F, 0.000000F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.901615F,
0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F
};
const int number_of_objects = 3;
Mat ref(number_of_objects, sizeof(ref_array) / (number_of_objects * sizeof(float)), CV_32FC1, &ref_array);
testDarknetLayer("shortcut");
}
normAssert(ref, detection);
TEST(Test_Darknet, upsample)
{
testDarknetLayer("upsample");
}
}} // namespace
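A note on the accuracy check in testDarknetModel above: for cv::Rect2f, operator& is the intersection and operator| the minimal rectangle enclosing both boxes, so the asserted ratio reaches 1.0 only when a detected box coincides with its reference. A standalone illustration:

#include <opencv2/core.hpp>
#include <cstdio>
using namespace cv;

int main()
{
    Rect2f det(0.14f, 0.34f, 0.28f, 0.60f);
    Rect2f ref(0.14f, 0.34f, 0.28f, 0.60f);
    // Intersection area over enclosing-rectangle area: 1.0 iff the boxes coincide.
    float ratio = (det & ref).area() / (det | ref).area();
    std::printf("ratio = %g\n", ratio);
    return 0;
}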

@@ -35,12 +35,14 @@ using namespace dnn;
float confThreshold;
std::vector<std::string> classes;
void postprocess(Mat& frame, const Mat& out, Net& net);
void postprocess(Mat& frame, const std::vector<Mat>& out, Net& net);
void drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame);
void callback(int pos, void* userdata);
std::vector<String> getOutputsNames(const Net& net);
int main(int argc, char** argv)
{
CommandLineParser parser(argc, argv, keys);
@@ -115,9 +117,10 @@ int main(int argc, char** argv)
Mat imInfo = (Mat_<float>(1, 3) << inpSize.height, inpSize.width, 1.6f);
net.setInput(imInfo, "im_info");
}
Mat out = net.forward();
std::vector<Mat> outs;
net.forward(outs, getOutputsNames(net));
postprocess(frame, out, net);
postprocess(frame, outs, net);
// Put efficiency information.
std::vector<double> layersTimes;
@@ -131,18 +134,19 @@ int main(int argc, char** argv)
return 0;
}
void postprocess(Mat& frame, const Mat& out, Net& net)
void postprocess(Mat& frame, const std::vector<Mat>& outs, Net& net)
{
static std::vector<int> outLayers = net.getUnconnectedOutLayers();
static std::string outLayerType = net.getLayer(outLayers[0])->type;
float* data = (float*)out.data;
if (net.getLayer(0)->outputNameToIndex("im_info") != -1) // Faster-RCNN or R-FCN
{
// Network produces an output blob with shape 1x1xNx7, where N is the number of
// detections and every detection is a vector of values
// [batchId, classId, confidence, left, top, right, bottom]
for (size_t i = 0; i < out.total(); i += 7)
CV_Assert(outs.size() == 1);
float* data = (float*)outs[0].data;
for (size_t i = 0; i < outs[0].total(); i += 7)
{
float confidence = data[i + 2];
if (confidence > confThreshold)
@@ -161,7 +165,9 @@ void postprocess(Mat& frame, const Mat& out, Net& net)
// Network produces an output blob with shape 1x1xNx7, where N is the number of
// detections and every detection is a vector of values
// [batchId, classId, confidence, left, top, right, bottom]
for (size_t i = 0; i < out.total(); i += 7)
CV_Assert(outs.size() == 1);
float* data = (float*)outs[0].data;
for (size_t i = 0; i < outs[0].total(); i += 7)
{
float confidence = data[i + 2];
if (confidence > confThreshold)
@@ -177,27 +183,45 @@ void postprocess(Mat& frame, const Mat& out, Net& net)
}
else if (outLayerType == "Region")
{
// Network produces an output blob with shape NxC, where N is the number of
// detected objects and C is the number of classes + 4; the first 4
// numbers are [center_x, center_y, width, height]
for (int i = 0; i < out.rows; ++i, data += out.cols)
std::vector<int> classIds;
std::vector<float> confidences;
std::vector<Rect> boxes;
for (size_t i = 0; i < outs.size(); ++i)
{
Mat confidences = out.row(i).colRange(5, out.cols);
Point classIdPoint;
double confidence;
minMaxLoc(confidences, 0, &confidence, 0, &classIdPoint);
if (confidence > confThreshold)
// Network produces an output blob with shape NxC, where N is the number of
// detected objects and C is the number of classes + 4; the first 4
// numbers are [center_x, center_y, width, height]
float* data = (float*)outs[i].data;
for (int j = 0; j < outs[i].rows; ++j, data += outs[i].cols)
{
int classId = classIdPoint.x;
int centerX = (int)(data[0] * frame.cols);
int centerY = (int)(data[1] * frame.rows);
int width = (int)(data[2] * frame.cols);
int height = (int)(data[3] * frame.rows);
int left = centerX - width / 2;
int top = centerY - height / 2;
drawPred(classId, (float)confidence, left, top, left + width, top + height, frame);
Mat scores = outs[i].row(j).colRange(5, outs[i].cols);
Point classIdPoint;
double confidence;
minMaxLoc(scores, 0, &confidence, 0, &classIdPoint);
if (confidence > confThreshold)
{
int centerX = (int)(data[0] * frame.cols);
int centerY = (int)(data[1] * frame.rows);
int width = (int)(data[2] * frame.cols);
int height = (int)(data[3] * frame.rows);
int left = centerX - width / 2;
int top = centerY - height / 2;
classIds.push_back(classIdPoint.x);
confidences.push_back((float)confidence);
boxes.push_back(Rect(left, top, width, height));
}
}
}
std::vector<int> indices;
NMSBoxes(boxes, confidences, confThreshold, 0.4, indices);
for (size_t i = 0; i < indices.size(); ++i)
{
int idx = indices[i];
Rect box = boxes[idx];
drawPred(classIds[idx], confidences[idx], box.x, box.y,
box.x + box.width, box.y + box.height, frame);
}
}
else
CV_Error(Error::StsNotImplemented, "Unknown output layer type: " + outLayerType);
@@ -227,3 +251,17 @@ void callback(int pos, void*)
{
confThreshold = pos * 0.01f;
}
std::vector<String> getOutputsNames(const Net& net)
{
static std::vector<String> names;
if (names.empty())
{
std::vector<int> outLayers = net.getUnconnectedOutLayers();
std::vector<String> layersNames = net.getLayerNames();
names.resize(outLayers.size());
for (size_t i = 0; i < outLayers.size(); ++i)
names[i] = layersNames[outLayers[i] - 1];
}
return names;
}
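Since YOLOv3 emits detections at three scales, the same object can appear in several output blobs; that is why the Region branch above first gathers all candidates and then calls NMSBoxes. A minimal standalone sketch of the suppression step, with thresholds matching the sample:

#include <opencv2/dnn.hpp>
#include <cstdio>
using namespace cv;
using namespace cv::dnn;

int main()
{
    std::vector<Rect> boxes;
    std::vector<float> confidences;
    boxes.push_back(Rect(10, 10, 100, 100));  confidences.push_back(0.9f);
    boxes.push_back(Rect(12, 12, 100, 100));  confidences.push_back(0.8f);  // overlaps box 0
    boxes.push_back(Rect(300, 300, 50, 50));  confidences.push_back(0.7f);
    std::vector<int> indices;
    NMSBoxes(boxes, confidences, 0.5f /*score threshold*/, 0.4f /*NMS threshold*/, indices);
    for (size_t i = 0; i < indices.size(); ++i)
        std::printf("kept box %d\n", indices[i]);  // expect boxes 0 and 2
    return 0;
}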

@@ -55,7 +55,11 @@ net.setPreferableTarget(args.target)
confThreshold = args.thr
def postprocess(frame, out):
def getOutputsNames(net):
layersNames = net.getLayerNames()
return [layersNames[i[0] - 1] for i in net.getUnconnectedOutLayers()]
def postprocess(frame, outs):
frameHeight = frame.shape[0]
frameWidth = frame.shape[1]
@@ -63,7 +67,7 @@ def postprocess(frame, out):
# Draw a bounding box.
cv.rectangle(frame, (left, top), (right, bottom), (0, 255, 0))
label = '%.2f' % confidence
label = '%.2f' % conf
# Print a label of class.
if classes:
@@ -83,6 +87,8 @@ def postprocess(frame, out):
# Network produces an output blob with shape 1x1xNx7, where N is the number of
# detections and every detection is a vector of values
# [batchId, classId, confidence, left, top, right, bottom]
assert(len(outs) == 1)
out = outs[0]
for detection in out[0, 0]:
confidence = detection[2]
if confidence > confThreshold:
@@ -96,6 +102,8 @@ def postprocess(frame, out):
# Network produces an output blob with shape 1x1xNx7, where N is the number of
# detections and every detection is a vector of values
# [batchId, classId, confidence, left, top, right, bottom]
assert(len(outs) == 1)
out = outs[0]
for detection in out[0, 0]:
confidence = detection[2]
if confidence > confThreshold:
@@ -109,18 +117,33 @@ def postprocess(frame, out):
# Network produces an output blob with shape NxC, where N is the number of
# detected objects and C is the number of classes + 4; the first 4
# numbers are [center_x, center_y, width, height]
for detection in out:
confidences = detection[5:]
classId = np.argmax(confidences)
confidence = confidences[classId]
if confidence > confThreshold:
center_x = int(detection[0] * frameWidth)
center_y = int(detection[1] * frameHeight)
width = int(detection[2] * frameWidth)
height = int(detection[3] * frameHeight)
left = center_x - width / 2
top = center_y - height / 2
drawPred(classId, confidence, left, top, left + width, top + height)
classIds = []
confidences = []
boxes = []
for out in outs:
for detection in out:
scores = detection[5:]
classId = np.argmax(scores)
confidence = scores[classId]
if confidence > confThreshold:
center_x = int(detection[0] * frameWidth)
center_y = int(detection[1] * frameHeight)
width = int(detection[2] * frameWidth)
height = int(detection[3] * frameHeight)
left = center_x - width / 2
top = center_y - height / 2
classIds.append(classId)
confidences.append(float(confidence))
boxes.append([left, top, width, height])
indices = cv.dnn.NMSBoxes(boxes, confidences, confThreshold, 0.4)
for i in indices:
i = i[0]
box = boxes[i]
left = box[0]
top = box[1]
width = box[2]
height = box[3]
drawPred(classIds[i], confidences[i], left, top, left + width, top + height)
# Process inputs
winName = 'Deep learning object detection in OpenCV'
@@ -152,9 +175,9 @@ while cv.waitKey(1) < 0:
if net.getLayer(0).outputNameToIndex('im_info') != -1: # Faster-RCNN or R-FCN
frame = cv.resize(frame, (inpWidth, inpHeight))
net.setInput(np.array([inpHeight, inpWidth, 1.6], dtype=np.float32), 'im_info');
out = net.forward()
outs = net.forward(getOutputsNames(net))
postprocess(frame, out)
postprocess(frame, outs)
# Put efficiency information.
t, _ = net.getPerfProfile()
