@ -31,7 +31,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, std::size_t Order,
typename std::enable_if<Order == 2 || Order == 3, bool>::type = true> /* Order has been hardcoded; see code */
typename std::enable_if<Order == 1 || Order == 2 || Order == 3, bool>::type = true> /* Order has been hardcoded; see code */
__global__ void max_pooling_with_indices(
Span<T> output, Span<T> indices, View<T> input, size_type channels,
array<size_type, Order> out_spatial_dims, array<size_type, Order> in_spatial_dims,
@ -72,7 +72,22 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
in_spatial_size *= in_spatial_dims[i];
const auto outer_offset = (n * channels + c) * in_spatial_size;
if (Order == 2) {
if (Order == 1) {
array<index_type, Order> idx;
for (idx[0] = start[0]; idx[0] != end[0]; idx[0]++) {
index_type offset = 0;
index_type stride = 1;
for (int i = Order - 1; i >= 0; i--) {
offset += stride * idx[i];
stride *= in_spatial_dims[i];
}
if (input[outer_offset + offset] > max_value) {
max_idx = offset;
max_value = input[outer_offset + offset];
}
}
} else if (Order == 2) {
array<index_type, Order> idx;
for (idx[0] = start[0]; idx[0] != end[0]; idx[0]++) {
for (idx[1] = start[1]; idx[1] != end[1]; idx[1]++) {
@ -206,8 +221,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
out_spatial_dims[i] = output.get_axis_size(2 + i);
}
/* only max_pooling2d and max_pooling3d are supported */
CV_Assert(2 <= order && order <= 3);
CV_Assert(1 <= order && order <= 3);
std::size_t channels = input.get_axis_size(1);
if (order == 3) {
launch_max_pooling_kernel<T, 3>(stream, output, indices, input, channels,
@ -215,6 +229,9 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
} else if (order == 2) {
launch_max_pooling_kernel<T, 2>(stream, output, indices, input, channels,
out_spatial_dims, in_spatial_dims, window_size, strides, padding_left);
} else if (order == 1) {
launch_max_pooling_kernel<T, 1>(stream, output, indices, input, channels,
out_spatial_dims, in_spatial_dims, window_size, strides, padding_left);
}
}