diff --git a/modules/dnn/src/layers/pooling_layer.cpp b/modules/dnn/src/layers/pooling_layer.cpp index f27df8591..9aaee31ea 100644 --- a/modules/dnn/src/layers/pooling_layer.cpp +++ b/modules/dnn/src/layers/pooling_layer.cpp @@ -132,7 +132,7 @@ void PoolingLayerImpl::maxPooling(Blob &src, Blob &dst, Blob &mask) bool PoolingLayerImpl::maxPooling_ocl(Blob &src, Blob &dst, Blob &mask) { - return pooling_ocl("MaxPoolForward", src, dst); + return pooling_ocl("MaxPoolForward", src, dst, &mask); } void PoolingLayerImpl::avePooling(Blob &src, Blob &dst) @@ -201,22 +201,36 @@ bool PoolingLayerImpl::pooling_ocl(const char *kname, const Blob &src, Blob &dst { const UMat &srcMat = src.umatRefConst(); UMat &dstMat = dst.umatRef(); - UMat* indexesMat = mask == NULL ? NULL : &dst.umatRef(); + UMat *maskUMat = mask == NULL ? NULL : &mask->umatRef(); + CV_Assert(maskUMat == NULL || maskUMat->type() == CV_32FC1); // FIXIT CV_32SC1 + CV_Assert(maskUMat == NULL || maskUMat->offset == 0); CV_Assert(srcMat.offset == 0 && dstMat.offset == 0); - ocl::Kernel ker(kname, ocl::dnn::pooling_oclsrc, String("-DT=") + ocl::typeToStr(src.type())); + ocl::Kernel ker(kname, ocl::dnn::pooling_oclsrc, + cv::format("-DT=%s%s", ocl::typeToStr(src.type()), maskUMat ? " -DMASK=1" : "")); if (ker.empty()) return false; BlobShape s = src.shape(); size_t nthreads = dst.total(); - ker.args((int)nthreads, + if (maskUMat) + { + ker.args((int)nthreads, ocl::KernelArg::PtrReadOnly(srcMat), s[0], s[1], s[2], s[3], out.height, out.width, kernel.height, kernel.width, stride.height, stride.width, pad.height, pad.width, ocl::KernelArg::PtrWriteOnly(dstMat), - ocl::KernelArg(ocl::KernelArg::PTR_ONLY + ocl::KernelArg::WRITE_ONLY, indexesMat)); + ocl::KernelArg::PtrWriteOnly(*maskUMat)); + } + else + { + ker.args((int)nthreads, + ocl::KernelArg::PtrReadOnly(srcMat), s[0], s[1], s[2], s[3], + out.height, out.width, kernel.height, kernel.width, + stride.height, stride.width, pad.height, pad.width, + ocl::KernelArg::PtrWriteOnly(dstMat)); + } size_t wgSize = ocl::Device::getDefault().maxWorkGroupSize(); if (!ker.run(1, &nthreads, &wgSize, true)) diff --git a/modules/dnn/src/opencl/pooling.cl b/modules/dnn/src/opencl/pooling.cl index 80c96f5ae..adfd59e6d 100644 --- a/modules/dnn/src/opencl/pooling.cl +++ b/modules/dnn/src/opencl/pooling.cl @@ -24,8 +24,16 @@ * POSSIBILITY OF SUCH DAMAGE. **************************************************************************************/ -__kernel void MaxPoolForward(const int nthreads, __global T* bottom_data, const int num, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, __global T* top_data, __global int* mask -) { +__kernel void MaxPoolForward(const int nthreads, + __global T* bottom_data, const int num, const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_h, const int pad_w, + __global T* top_data +#ifdef MASK + , __global float* mask +#endif + ) +{ int index = get_global_id(0); int tmp = get_global_size(0); for(index; index < nthreads; index += tmp) { @@ -51,15 +59,25 @@ __kernel void MaxPoolForward(const int nthreads, __global T* bottom_data, const } } } + top_data[index] = maxval; - if (mask) { - mask[index] = maxidx; - } +#ifdef MASK + mask[index] = maxidx; +#endif } } -__kernel void AvePoolForward(const int nthreads, __global T* bottom_data, const int num, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w,__global T* top_data) { +__kernel void AvePoolForward(const int nthreads, + __global T* bottom_data, const int num, const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_h, const int pad_w, + __global T* top_data +#ifdef MASK + , __global float* mask // NOT USED +#endif + ) +{ int index = get_global_id(0); int tmp = get_global_size(0); for(index; index < nthreads; index+=tmp) {