From 9c84749e2c28d0744376e515a159c3cf30bb43c4 Mon Sep 17 00:00:00 2001 From: Smirnov Egor Date: Wed, 6 Oct 2021 16:09:20 +0300 Subject: [PATCH] backport YOLOv4x-mish new_coords CUDA implementation --- modules/dnn/src/cuda/region.cu | 78 +++++++++++++------ modules/dnn/src/cuda4dnn/kernels/region.hpp | 2 +- .../dnn/src/cuda4dnn/primitives/region.hpp | 6 +- modules/dnn/src/layers/region_layer.cpp | 5 +- modules/dnn/test/test_darknet_importer.cpp | 10 ++- 5 files changed, 73 insertions(+), 28 deletions(-) diff --git a/modules/dnn/src/cuda/region.cu b/modules/dnn/src/cuda/region.cu index 06b44abe9c..3700a93d99 100644 --- a/modules/dnn/src/cuda/region.cu +++ b/modules/dnn/src/cuda/region.cu @@ -31,7 +31,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { size_type boxes_per_cell, size_type box_size, size_type rows, size_type cols, T scale_x_y, size_type height_norm, size_type width_norm, - T object_prob_cutoff) + T object_prob_cutoff, bool new_coords) { using vector2_type = get_vector_type_t; auto bias_vPtr = vector2_type::get_pointer(bias.data()); @@ -47,22 +47,43 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { const auto y = (box_index % batch_inner_size) / row_inner_size; const auto x = (box_index % row_inner_size) / col_inner_size; - using device::fast_sigmoid; - const auto tmp_x = (fast_sigmoid(input[box_offset + 0]) - static_cast(0.5)) * scale_x_y + static_cast(0.5); - const auto tmp_y = (fast_sigmoid(input[box_offset + 1]) - static_cast(0.5)) * scale_x_y + static_cast(0.5); - output[box_offset + 0] = (T(x) + tmp_x) / T(cols); - output[box_offset + 1] = (T(y) + tmp_y) / T(rows); + /* When new_coords is true, we shouldn't use logistic activation again */ + T objectness_prob; + if (new_coords) + { + const auto tmp_x = (input[box_offset + 0] - static_cast(0.5)) * scale_x_y + static_cast(0.5); + const auto tmp_y = (input[box_offset + 1] - static_cast(0.5)) * scale_x_y + static_cast(0.5); - vector2_type bias_xy; - v_load(bias_xy, bias_vPtr[box_of_the_cell]); + output[box_offset + 0] = fast_divide_ftz(static_cast(x) + tmp_x, static_cast(cols)); + output[box_offset + 1] = fast_divide_ftz(static_cast(y) + tmp_y, static_cast(rows)); - using device::fast_exp; - output[box_offset + 2] = fast_exp(input[box_offset + 2]) * bias_xy.data[0] / T(width_norm); - output[box_offset + 3] = fast_exp(input[box_offset + 3]) * bias_xy.data[1] / T(height_norm); + vector2_type bias_xy; + v_load(bias_xy, bias_vPtr[box_of_the_cell]); - /* squash objectness score into a probability */ - using device::fast_sigmoid; - T objectness_prob = fast_sigmoid(input[box_offset + 4]); + output[box_offset + 2] = input[box_offset + 2] * input[box_offset + 2] * + static_cast(4) * bias_xy.data[0] / static_cast(width_norm); + output[box_offset + 3] = input[box_offset + 3] * input[box_offset + 3] * + static_cast(4) * bias_xy.data[1] / static_cast(height_norm); + + objectness_prob = input[box_offset + 4]; + } + else + { + const auto tmp_x = (fast_sigmoid(input[box_offset + 0]) - static_cast(0.5)) * scale_x_y + static_cast(0.5); + const auto tmp_y = (fast_sigmoid(input[box_offset + 1]) - static_cast(0.5)) * scale_x_y + static_cast(0.5); + + output[box_offset + 0] = fast_divide_ftz(static_cast(x) + tmp_x, static_cast(cols)); + output[box_offset + 1] = fast_divide_ftz(static_cast(y) + tmp_y, static_cast(rows)); + + vector2_type bias_xy; + v_load(bias_xy, bias_vPtr[box_of_the_cell]); + + output[box_offset + 2] = fast_exp(input[box_offset + 2]) * bias_xy.data[0] / static_cast(width_norm); + output[box_offset + 3] = fast_exp(input[box_offset + 3]) * bias_xy.data[1] / static_cast(height_norm); + + /* squash objectness score into a probability */ + objectness_prob = fast_sigmoid(input[box_offset + 4]); + } /* ignore prediction if the objectness probability is less than the cutoff */ if (objectness_prob < object_prob_cutoff) @@ -73,7 +94,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { } template - __global__ void region_sigmoid_class_score(Span output, View input, T class_prob_cutoff, size_type box_size) + __global__ void region_sigmoid_class_score(Span output, View input, T class_prob_cutoff, + size_type box_size, bool new_coords) { for (auto idx : grid_stride_range(output.size())) { const index_type box_no = idx / box_size; @@ -92,9 +114,20 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { * * to obtain the actual class probability, we multiply the conditional probability * with the object probability + * + * when new_coords is true, we shouldn't use logistic activation again. */ - using device::fast_sigmoid; - auto actual_class_prob = objectness_prob * fast_sigmoid(input[idx]); + + T actual_class_prob; + if (new_coords) + { + actual_class_prob = objectness_prob * input[idx]; + } + else + { + actual_class_prob = objectness_prob * fast_sigmoid(input[idx]); + } + if (actual_class_prob <= class_prob_cutoff) actual_class_prob = T(0); output[idx] = actual_class_prob; @@ -147,7 +180,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { std::size_t boxes_per_cell, std::size_t box_size, std::size_t rows, std::size_t cols, T scale_x_y, std::size_t height_norm, std::size_t width_norm, - bool if_true_sigmoid_else_softmax /* true = sigmoid, false = softmax */) + bool if_true_sigmoid_else_softmax, /* true = sigmoid, false = softmax */ + bool new_coords) { CV_Assert(output.size() == input.size()); CV_Assert(output.size() % box_size == 0); @@ -158,12 +192,12 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { launch_kernel(box_kernel, box_policy, output, input, bias, boxes_per_cell, box_size, rows, cols, scale_x_y, height_norm, width_norm, - object_prob_cutoff); + object_prob_cutoff, new_coords); if (if_true_sigmoid_else_softmax) { auto kernel_score = raw::region_sigmoid_class_score; auto policy_score = make_policy(kernel_score, output.size(), 0, stream); - launch_kernel(kernel_score, policy_score, output, input, class_prob_cutoff, box_size); + launch_kernel(kernel_score, policy_score, output, input, class_prob_cutoff, box_size, new_coords); } else { auto kernel_score = raw::region_softmax_class_score; auto policy_score = make_policy(kernel_score, output.size(), 0, stream); @@ -173,10 +207,10 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530) template void region(const Stream&, Span<__half>, View<__half>, View<__half>, - __half, __half, std::size_t, std::size_t, std::size_t, std::size_t, __half, std::size_t, std::size_t, bool); + __half, __half, std::size_t, std::size_t, std::size_t, std::size_t, __half, std::size_t, std::size_t, bool, bool); #endif template void region(const Stream&, Span, View, View, - float, float, std::size_t, std::size_t, std::size_t, std::size_t, float, std::size_t, std::size_t, bool); + float, float, std::size_t, std::size_t, std::size_t, std::size_t, float, std::size_t, std::size_t, bool, bool); }}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda4dnn/kernels/region.hpp b/modules/dnn/src/cuda4dnn/kernels/region.hpp index 87742d2f81..b815fb11c9 100644 --- a/modules/dnn/src/cuda4dnn/kernels/region.hpp +++ b/modules/dnn/src/cuda4dnn/kernels/region.hpp @@ -18,7 +18,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { std::size_t boxes_per_cell, std::size_t box_size, std::size_t rows, std::size_t cols, T scale_x_y, std::size_t height_norm, std::size_t width_norm, - bool if_true_sigmoid_else_softmax); + bool if_true_sigmoid_else_softmax, bool new_coords); }}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git a/modules/dnn/src/cuda4dnn/primitives/region.hpp b/modules/dnn/src/cuda4dnn/primitives/region.hpp index 7813a47bc7..d22d44214e 100644 --- a/modules/dnn/src/cuda4dnn/primitives/region.hpp +++ b/modules/dnn/src/cuda4dnn/primitives/region.hpp @@ -60,6 +60,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { T class_prob_cutoff; T nms_iou_threshold; + bool new_coords; }; template @@ -87,6 +88,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { class_prob_cutoff = config.class_prob_cutoff; nms_iou_threshold = config.nms_iou_threshold; + new_coords = config.new_coords; } void forward( @@ -115,7 +117,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { boxes_per_cell, cell_box_size, rows, cols, scale_x_y, height_norm, width_norm, - if_true_sigmoid_else_softmax + if_true_sigmoid_else_softmax, + new_coords ); if (nms_iou_threshold > 0) { @@ -176,6 +179,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { T object_prob_cutoff, class_prob_cutoff; T nms_iou_threshold; + bool new_coords; }; }}} /* namespace cv::dnn::cuda4dnn */ diff --git a/modules/dnn/src/layers/region_layer.cpp b/modules/dnn/src/layers/region_layer.cpp index 242e5e6f88..73ed53974f 100644 --- a/modules/dnn/src/layers/region_layer.cpp +++ b/modules/dnn/src/layers/region_layer.cpp @@ -125,7 +125,7 @@ public: #endif #ifdef HAVE_CUDA if (backendId == DNN_BACKEND_CUDA) - return new_coords == 0; + return true; #endif return backendId == DNN_BACKEND_OPENCV; } @@ -437,11 +437,12 @@ public: config.scale_x_y = scale_x_y; - config.object_prob_cutoff = (classfix == -1) ? 0.5 : 0.0; + config.object_prob_cutoff = (classfix == -1) ? thresh : 0.f; config.class_prob_cutoff = thresh; config.nms_iou_threshold = nmsThreshold; + config.new_coords = (new_coords == 1); return make_cuda_node(preferableTarget, std::move(context->stream), blobs[0], config); } #endif diff --git a/modules/dnn/test/test_darknet_importer.cpp b/modules/dnn/test/test_darknet_importer.cpp index c85981a8bc..d822efa889 100644 --- a/modules/dnn/test/test_darknet_importer.cpp +++ b/modules/dnn/test/test_darknet_importer.cpp @@ -745,8 +745,14 @@ TEST_P(Test_Darknet_nets, YOLOv4x_mish) }; Mat ref(N0 + N1, 7, CV_32FC1, (void*)ref_); - double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.006 : 8e-5; - double iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.042 : 3e-4; + double scoreDiff = 8e-5; + double iouDiff = 3e-4; + + if (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD || target == DNN_TARGET_CUDA_FP16) + { + scoreDiff = 0.006; + iouDiff = 0.042; + } std::string config_file = "yolov4x-mish.cfg"; std::string weights_file = "yolov4x-mish.weights";