From 48eecafc89650ba8971681b4e07a1b909458a17c Mon Sep 17 00:00:00 2001 From: YashasSamaga Date: Mon, 30 Dec 2019 23:02:17 +0530 Subject: [PATCH] simplify code to help MSVC 19.10 and lower --- modules/dnn/src/cuda/grid_stride_range.hpp | 42 +++++++++++--------- modules/dnn/src/cuda4dnn/csl/cudnn/cudnn.hpp | 7 ++-- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/modules/dnn/src/cuda/grid_stride_range.hpp b/modules/dnn/src/cuda/grid_stride_range.hpp index 4b61a0f574..4693547f3a 100644 --- a/modules/dnn/src/cuda/grid_stride_range.hpp +++ b/modules/dnn/src/cuda/grid_stride_range.hpp @@ -12,25 +12,29 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device { namespace detail { - template __device__ auto getGridDim()->decltype(dim3::x); - template <> inline __device__ auto getGridDim<0>()->decltype(dim3::x) { return gridDim.x; } - template <> inline __device__ auto getGridDim<1>()->decltype(dim3::x) { return gridDim.y; } - template <> inline __device__ auto getGridDim<2>()->decltype(dim3::x) { return gridDim.z; } - - template __device__ auto getBlockDim()->decltype(dim3::x); - template <> inline __device__ auto getBlockDim<0>()->decltype(dim3::x) { return blockDim.x; } - template <> inline __device__ auto getBlockDim<1>()->decltype(dim3::x) { return blockDim.y; } - template <> inline __device__ auto getBlockDim<2>()->decltype(dim3::x) { return blockDim.z; } - - template __device__ auto getBlockIdx()->decltype(uint3::x); - template <> inline __device__ auto getBlockIdx<0>()->decltype(uint3::x) { return blockIdx.x; } - template <> inline __device__ auto getBlockIdx<1>()->decltype(uint3::x) { return blockIdx.y; } - template <> inline __device__ auto getBlockIdx<2>()->decltype(uint3::x) { return blockIdx.z; } - - template __device__ auto getThreadIdx()->decltype(uint3::x); - template <> inline __device__ auto getThreadIdx<0>()->decltype(uint3::x) { return threadIdx.x; } - template <> inline __device__ auto getThreadIdx<1>()->decltype(uint3::x) { return threadIdx.y; } - template <> inline __device__ auto getThreadIdx<2>()->decltype(uint3::x) { return threadIdx.z; } + using dim3_member_type = decltype(dim3::x); + + template __device__ dim3_member_type getGridDim(); + template <> inline __device__ dim3_member_type getGridDim<0>() { return gridDim.x; } + template <> inline __device__ dim3_member_type getGridDim<1>() { return gridDim.y; } + template <> inline __device__ dim3_member_type getGridDim<2>() { return gridDim.z; } + + template __device__ dim3_member_type getBlockDim(); + template <> inline __device__ dim3_member_type getBlockDim<0>() { return blockDim.x; } + template <> inline __device__ dim3_member_type getBlockDim<1>() { return blockDim.y; } + template <> inline __device__ dim3_member_type getBlockDim<2>() { return blockDim.z; } + + using uint3_member_type = decltype(uint3::x); + + template __device__ uint3_member_type getBlockIdx(); + template <> inline __device__ uint3_member_type getBlockIdx<0>() { return blockIdx.x; } + template <> inline __device__ uint3_member_type getBlockIdx<1>() { return blockIdx.y; } + template <> inline __device__ uint3_member_type getBlockIdx<2>() { return blockIdx.z; } + + template __device__ uint3_member_type getThreadIdx(); + template <> inline __device__ uint3_member_type getThreadIdx<0>() { return threadIdx.x; } + template <> inline __device__ uint3_member_type getThreadIdx<1>() { return threadIdx.y; } + template <> inline __device__ uint3_member_type getThreadIdx<2>() { return threadIdx.z; } } template diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/cudnn.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/cudnn.hpp index 06879448d7..19b46a9b36 100644 --- a/modules/dnn/src/cuda4dnn/csl/cudnn/cudnn.hpp +++ b/modules/dnn/src/cuda4dnn/csl/cudnn/cudnn.hpp @@ -37,9 +37,10 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu } /** get_data_type returns the equivalent cudnn enumeration constant for type T */ - template auto get_data_type()->decltype(CUDNN_DATA_FLOAT); - template <> inline auto get_data_type()->decltype(CUDNN_DATA_HALF) { return CUDNN_DATA_HALF; } - template <> inline auto get_data_type()->decltype(CUDNN_DATA_FLOAT) { return CUDNN_DATA_FLOAT; } + using cudnn_data_enum_type = decltype(CUDNN_DATA_FLOAT); + template cudnn_data_enum_type get_data_type(); + template <> inline cudnn_data_enum_type get_data_type() { return CUDNN_DATA_HALF; } + template <> inline cudnn_data_enum_type get_data_type() { return CUDNN_DATA_FLOAT; } } /** @brief noncopyable cuDNN smart handle