backport YOLOv4x-mish new_coords CUDA implementation

pull/20818/head
Smirnov Egor 4 years ago
parent 13c6eb42e9
commit 9c84749e2c
  1. 78
      modules/dnn/src/cuda/region.cu
  2. 2
      modules/dnn/src/cuda4dnn/kernels/region.hpp
  3. 6
      modules/dnn/src/cuda4dnn/primitives/region.hpp
  4. 5
      modules/dnn/src/layers/region_layer.cpp
  5. 10
      modules/dnn/test/test_darknet_importer.cpp

@ -31,7 +31,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
size_type boxes_per_cell, size_type box_size, size_type boxes_per_cell, size_type box_size,
size_type rows, size_type cols, T scale_x_y, size_type rows, size_type cols, T scale_x_y,
size_type height_norm, size_type width_norm, size_type height_norm, size_type width_norm,
T object_prob_cutoff) T object_prob_cutoff, bool new_coords)
{ {
using vector2_type = get_vector_type_t<T, 2>; using vector2_type = get_vector_type_t<T, 2>;
auto bias_vPtr = vector2_type::get_pointer(bias.data()); auto bias_vPtr = vector2_type::get_pointer(bias.data());
@ -47,22 +47,43 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
const auto y = (box_index % batch_inner_size) / row_inner_size; const auto y = (box_index % batch_inner_size) / row_inner_size;
const auto x = (box_index % row_inner_size) / col_inner_size; const auto x = (box_index % row_inner_size) / col_inner_size;
using device::fast_sigmoid; /* When new_coords is true, we shouldn't use logistic activation again */
const auto tmp_x = (fast_sigmoid(input[box_offset + 0]) - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5); T objectness_prob;
const auto tmp_y = (fast_sigmoid(input[box_offset + 1]) - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5); if (new_coords)
output[box_offset + 0] = (T(x) + tmp_x) / T(cols); {
output[box_offset + 1] = (T(y) + tmp_y) / T(rows); const auto tmp_x = (input[box_offset + 0] - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5);
const auto tmp_y = (input[box_offset + 1] - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5);
vector2_type bias_xy; output[box_offset + 0] = fast_divide_ftz(static_cast<T>(x) + tmp_x, static_cast<T>(cols));
v_load(bias_xy, bias_vPtr[box_of_the_cell]); output[box_offset + 1] = fast_divide_ftz(static_cast<T>(y) + tmp_y, static_cast<T>(rows));
using device::fast_exp; vector2_type bias_xy;
output[box_offset + 2] = fast_exp(input[box_offset + 2]) * bias_xy.data[0] / T(width_norm); v_load(bias_xy, bias_vPtr[box_of_the_cell]);
output[box_offset + 3] = fast_exp(input[box_offset + 3]) * bias_xy.data[1] / T(height_norm);
/* squash objectness score into a probability */ output[box_offset + 2] = input[box_offset + 2] * input[box_offset + 2] *
using device::fast_sigmoid; static_cast<T>(4) * bias_xy.data[0] / static_cast<T>(width_norm);
T objectness_prob = fast_sigmoid(input[box_offset + 4]); output[box_offset + 3] = input[box_offset + 3] * input[box_offset + 3] *
static_cast<T>(4) * bias_xy.data[1] / static_cast<T>(height_norm);
objectness_prob = input[box_offset + 4];
}
else
{
const auto tmp_x = (fast_sigmoid(input[box_offset + 0]) - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5);
const auto tmp_y = (fast_sigmoid(input[box_offset + 1]) - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5);
output[box_offset + 0] = fast_divide_ftz(static_cast<T>(x) + tmp_x, static_cast<T>(cols));
output[box_offset + 1] = fast_divide_ftz(static_cast<T>(y) + tmp_y, static_cast<T>(rows));
vector2_type bias_xy;
v_load(bias_xy, bias_vPtr[box_of_the_cell]);
output[box_offset + 2] = fast_exp(input[box_offset + 2]) * bias_xy.data[0] / static_cast<T>(width_norm);
output[box_offset + 3] = fast_exp(input[box_offset + 3]) * bias_xy.data[1] / static_cast<T>(height_norm);
/* squash objectness score into a probability */
objectness_prob = fast_sigmoid(input[box_offset + 4]);
}
/* ignore prediction if the objectness probability is less than the cutoff */ /* ignore prediction if the objectness probability is less than the cutoff */
if (objectness_prob < object_prob_cutoff) if (objectness_prob < object_prob_cutoff)
@ -73,7 +94,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
} }
template <class T> template <class T>
__global__ void region_sigmoid_class_score(Span<T> output, View<T> input, T class_prob_cutoff, size_type box_size) __global__ void region_sigmoid_class_score(Span<T> output, View<T> input, T class_prob_cutoff,
size_type box_size, bool new_coords)
{ {
for (auto idx : grid_stride_range(output.size())) { for (auto idx : grid_stride_range(output.size())) {
const index_type box_no = idx / box_size; const index_type box_no = idx / box_size;
@ -92,9 +114,20 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
* *
* to obtain the actual class probability, we multiply the conditional probability * to obtain the actual class probability, we multiply the conditional probability
* with the object probability * with the object probability
*
* when new_coords is true, we shouldn't use logistic activation again.
*/ */
using device::fast_sigmoid;
auto actual_class_prob = objectness_prob * fast_sigmoid(input[idx]); T actual_class_prob;
if (new_coords)
{
actual_class_prob = objectness_prob * input[idx];
}
else
{
actual_class_prob = objectness_prob * fast_sigmoid(input[idx]);
}
if (actual_class_prob <= class_prob_cutoff) if (actual_class_prob <= class_prob_cutoff)
actual_class_prob = T(0); actual_class_prob = T(0);
output[idx] = actual_class_prob; output[idx] = actual_class_prob;
@ -147,7 +180,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
std::size_t boxes_per_cell, std::size_t box_size, std::size_t boxes_per_cell, std::size_t box_size,
std::size_t rows, std::size_t cols, T scale_x_y, std::size_t rows, std::size_t cols, T scale_x_y,
std::size_t height_norm, std::size_t width_norm, std::size_t height_norm, std::size_t width_norm,
bool if_true_sigmoid_else_softmax /* true = sigmoid, false = softmax */) bool if_true_sigmoid_else_softmax, /* true = sigmoid, false = softmax */
bool new_coords)
{ {
CV_Assert(output.size() == input.size()); CV_Assert(output.size() == input.size());
CV_Assert(output.size() % box_size == 0); CV_Assert(output.size() % box_size == 0);
@ -158,12 +192,12 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
launch_kernel(box_kernel, box_policy, launch_kernel(box_kernel, box_policy,
output, input, bias, boxes_per_cell, box_size, output, input, bias, boxes_per_cell, box_size,
rows, cols, scale_x_y, height_norm, width_norm, rows, cols, scale_x_y, height_norm, width_norm,
object_prob_cutoff); object_prob_cutoff, new_coords);
if (if_true_sigmoid_else_softmax) { if (if_true_sigmoid_else_softmax) {
auto kernel_score = raw::region_sigmoid_class_score<T>; auto kernel_score = raw::region_sigmoid_class_score<T>;
auto policy_score = make_policy(kernel_score, output.size(), 0, stream); auto policy_score = make_policy(kernel_score, output.size(), 0, stream);
launch_kernel(kernel_score, policy_score, output, input, class_prob_cutoff, box_size); launch_kernel(kernel_score, policy_score, output, input, class_prob_cutoff, box_size, new_coords);
} else { } else {
auto kernel_score = raw::region_softmax_class_score<T>; auto kernel_score = raw::region_softmax_class_score<T>;
auto policy_score = make_policy(kernel_score, output.size(), 0, stream); auto policy_score = make_policy(kernel_score, output.size(), 0, stream);
@ -173,10 +207,10 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530) #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void region(const Stream&, Span<__half>, View<__half>, View<__half>, template void region(const Stream&, Span<__half>, View<__half>, View<__half>,
__half, __half, std::size_t, std::size_t, std::size_t, std::size_t, __half, std::size_t, std::size_t, bool); __half, __half, std::size_t, std::size_t, std::size_t, std::size_t, __half, std::size_t, std::size_t, bool, bool);
#endif #endif
template void region(const Stream&, Span<float>, View<float>, View<float>, template void region(const Stream&, Span<float>, View<float>, View<float>,
float, float, std::size_t, std::size_t, std::size_t, std::size_t, float, std::size_t, std::size_t, bool); float, float, std::size_t, std::size_t, std::size_t, std::size_t, float, std::size_t, std::size_t, bool, bool);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */ }}}} /* namespace cv::dnn::cuda4dnn::kernels */

@ -18,7 +18,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
std::size_t boxes_per_cell, std::size_t box_size, std::size_t boxes_per_cell, std::size_t box_size,
std::size_t rows, std::size_t cols, T scale_x_y, std::size_t rows, std::size_t cols, T scale_x_y,
std::size_t height_norm, std::size_t width_norm, std::size_t height_norm, std::size_t width_norm,
bool if_true_sigmoid_else_softmax); bool if_true_sigmoid_else_softmax, bool new_coords);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */ }}}} /* namespace cv::dnn::cuda4dnn::kernels */

@ -60,6 +60,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
T class_prob_cutoff; T class_prob_cutoff;
T nms_iou_threshold; T nms_iou_threshold;
bool new_coords;
}; };
template <class T> template <class T>
@ -87,6 +88,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
class_prob_cutoff = config.class_prob_cutoff; class_prob_cutoff = config.class_prob_cutoff;
nms_iou_threshold = config.nms_iou_threshold; nms_iou_threshold = config.nms_iou_threshold;
new_coords = config.new_coords;
} }
void forward( void forward(
@ -115,7 +117,8 @@ namespace cv { namespace dnn { namespace cuda4dnn {
boxes_per_cell, cell_box_size, boxes_per_cell, cell_box_size,
rows, cols, scale_x_y, rows, cols, scale_x_y,
height_norm, width_norm, height_norm, width_norm,
if_true_sigmoid_else_softmax if_true_sigmoid_else_softmax,
new_coords
); );
if (nms_iou_threshold > 0) { if (nms_iou_threshold > 0) {
@ -176,6 +179,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
T object_prob_cutoff, class_prob_cutoff; T object_prob_cutoff, class_prob_cutoff;
T nms_iou_threshold; T nms_iou_threshold;
bool new_coords;
}; };
}}} /* namespace cv::dnn::cuda4dnn */ }}} /* namespace cv::dnn::cuda4dnn */

@ -125,7 +125,7 @@ public:
#endif #endif
#ifdef HAVE_CUDA #ifdef HAVE_CUDA
if (backendId == DNN_BACKEND_CUDA) if (backendId == DNN_BACKEND_CUDA)
return new_coords == 0; return true;
#endif #endif
return backendId == DNN_BACKEND_OPENCV; return backendId == DNN_BACKEND_OPENCV;
} }
@ -437,11 +437,12 @@ public:
config.scale_x_y = scale_x_y; config.scale_x_y = scale_x_y;
config.object_prob_cutoff = (classfix == -1) ? 0.5 : 0.0; config.object_prob_cutoff = (classfix == -1) ? thresh : 0.f;
config.class_prob_cutoff = thresh; config.class_prob_cutoff = thresh;
config.nms_iou_threshold = nmsThreshold; config.nms_iou_threshold = nmsThreshold;
config.new_coords = (new_coords == 1);
return make_cuda_node<cuda4dnn::RegionOp>(preferableTarget, std::move(context->stream), blobs[0], config); return make_cuda_node<cuda4dnn::RegionOp>(preferableTarget, std::move(context->stream), blobs[0], config);
} }
#endif #endif

@ -745,8 +745,14 @@ TEST_P(Test_Darknet_nets, YOLOv4x_mish)
}; };
Mat ref(N0 + N1, 7, CV_32FC1, (void*)ref_); Mat ref(N0 + N1, 7, CV_32FC1, (void*)ref_);
double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.006 : 8e-5; double scoreDiff = 8e-5;
double iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.042 : 3e-4; double iouDiff = 3e-4;
if (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD || target == DNN_TARGET_CUDA_FP16)
{
scoreDiff = 0.006;
iouDiff = 0.042;
}
std::string config_file = "yolov4x-mish.cfg"; std::string config_file = "yolov4x-mish.cfg";
std::string weights_file = "yolov4x-mish.weights"; std::string weights_file = "yolov4x-mish.weights";

Loading…
Cancel
Save