diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 70ee37555..0d2a9af1e 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -209,7 +209,7 @@ namespace dnn
     {
     public:

-        Size kernel, pad, stride;
+        Size kernel, stride, pad;
     };

     class CV_EXPORTS_W ConvolutionLayer : public BaseConvolutionLayer
@@ -232,7 +232,7 @@ namespace dnn
     {
     public:

-        enum
+        enum Type
         {
             CHANNEL_NRM,
             SPATIAL_NRM
@@ -241,9 +241,26 @@ namespace dnn

         int size;
         double alpha, beta;
+
+        static Ptr<LRNLayer> create(int type = CHANNEL_NRM, int size = 5, double alpha = 1, double beta = 0.75);
     };

+    class CV_EXPORTS_W PoolingLayer : public Layer
+    {
+    public:
+        enum Type
+        {
+            MAX,
+            AVE,
+            STOCHASTIC
+        };
+
+        int type;
+        Size kernel, stride, pad;
+
+        static Ptr<PoolingLayer> create(int type = MAX, Size kernel = Size(2, 2), Size pad = Size(0, 0), Size stride = Size(1, 1));
+    };

 //! @}
 //! @}
diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp
index ded668ae0..58e6e0d3f 100644
--- a/modules/dnn/src/init.cpp
+++ b/modules/dnn/src/init.cpp
@@ -81,7 +81,7 @@ void initModule()
     REG_RUNTIME_LAYER_CLASS(Split, SplitLayer)
     REG_RUNTIME_LAYER_CLASS(Reshape, ReshapeLayer)
     REG_STATIC_LAYER_FUNC(Flatten, createFlattenLayer)
-    REG_RUNTIME_LAYER_CLASS(Pooling, PoolingLayerImpl)
+    REG_RUNTIME_LAYER_FUNC(Pooling, createPoolingLayerFromCaffe)
     REG_RUNTIME_LAYER_CLASS(MVN, MVNLayer)
     REG_RUNTIME_LAYER_FUNC(LRN, createLRNLayerFromCaffe)
     REG_RUNTIME_LAYER_CLASS(InnerProduct, FullyConnectedLayer)
diff --git a/modules/dnn/src/layers/lrn_layer.cpp b/modules/dnn/src/layers/lrn_layer.cpp
index f846d765f..9694b44cb 100644
--- a/modules/dnn/src/layers/lrn_layer.cpp
+++ b/modules/dnn/src/layers/lrn_layer.cpp
@@ -53,17 +53,18 @@ namespace cv
 namespace dnn
 {

-LRNLayerImpl::LRNLayerImpl()
+LRNLayerImpl::LRNLayerImpl(int type_, int size_, double alpha_, double beta_)
 {
-    size = 5;
-    alpha = 1;
-    beta = 0.75;
-    type = CHANNEL_NRM;
+    type = type_;
+    size = size_;
+    alpha = alpha_;
+    beta = beta_;
 }

 void LRNLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
 {
     CV_Assert(inputs.size() == 1 && inputs[0]->dims() == 4);
+    CV_Assert(type == CHANNEL_NRM || type == SPATIAL_NRM);
     useOpenCL = cv::ocl::useOpenCL();

     if (type == SPATIAL_NRM && !useOpenCL)
@@ -154,6 +155,7 @@ void LRNLayerImpl::channelNoramlization_(Blob &srcBlob, Blob &dstBlob)

 bool LRNLayerImpl::channelNoramlization_ocl(const UMat &src, UMat &dst)
 {
+#ifdef HAVE_OPENCL
     if (src.offset != 0 || dst.offset != 0) //TODO: add offset
         return false;

@@ -187,6 +189,9 @@ bool LRNLayerImpl::channelNoramlization_ocl(const UMat &src, UMat &dst)
         return false;

     return true;
+#else
+    return false;
+#endif
 }

 void LRNLayerImpl::spatialNormalization(Blob &src, Blob &dst)
@@ -232,27 +237,31 @@ void LRNLayerImpl::spatialNormalization_(Blob &srcBlob, Blob &dstBlob)
     }
 }

-Ptr<Layer> createLRNLayerFromCaffe(LayerParams &params)
+
+Ptr<LRNLayer> LRNLayer::create(int type, int size, double alpha, double beta)
 {
-    LRNLayerImpl *l = new LRNLayerImpl();
+    return Ptr<LRNLayer>(new LRNLayerImpl(type, size, alpha, beta));
+}

+Ptr<Layer> createLRNLayerFromCaffe(LayerParams &params)
+{
+    int type;
     String nrmType = params.get<String>("norm_region", "ACROSS_CHANNELS");
     if (nrmType == "ACROSS_CHANNELS")
-        l->type = LRNLayer::CHANNEL_NRM;
+        type = LRNLayer::CHANNEL_NRM;
     else if (nrmType == "WITHIN_CHANNEL")
-        l->type = LRNLayer::SPATIAL_NRM;
+        type = LRNLayer::SPATIAL_NRM;
     else
         CV_Error(Error::StsBadArg, "Unknown region type \"" + nrmType + "\"");

     int size = params.get<int>("local_size", 5);
     if (size % 2 != 1 || size <= 0)
         CV_Error(Error::StsBadArg, "LRN layer supports only positive odd values for local_size");

-    l->size = size;
-    l->alpha = params.get<double>("alpha", 1);
-    l->beta = params.get<double>("beta", 0.75);
+    double alpha = params.get<double>("alpha", 1);
+    double beta = params.get<double>("beta", 0.75);

-    return Ptr<Layer>(l);
+    return Ptr<Layer>(new LRNLayerImpl(type, size, alpha, beta));
 }

 }
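With LRNLayer::create (and PoolingLayer::create declared in all_layers.hpp above), layers can now be constructed directly instead of only through the Caffe importer. A minimal usage sketch, assuming this branch's public headers; this snippet is not part of the patch:

    #include <opencv2/dnn/all_layers.hpp>

    using namespace cv;
    using namespace cv::dnn;

    int main()
    {
        // Cross-channel LRN with the same defaults the Caffe importer applies.
        Ptr<LRNLayer> lrn = LRNLayer::create(LRNLayer::CHANNEL_NRM,
                                             /*size*/ 5, /*alpha*/ 1.0, /*beta*/ 0.75);

        // Note the argument order of the pooling factory: kernel, pad, stride.
        Ptr<PoolingLayer> pool = PoolingLayer::create(PoolingLayer::MAX,
                                                      Size(2, 2),   // kernel
                                                      Size(0, 0),   // pad
                                                      Size(1, 1)); // stride
        return 0;
    }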
CV_Error(Error::StsBadArg, "Unknown region type \"" + nrmType + "\""); int size = params.get<int>("local_size", 5); if (size % 2 != 1 || size <= 0) CV_Error(Error::StsBadArg, "LRN layer supports only positive odd values for local_size"); - l->size = size; - l->alpha = params.get<double>("alpha", 1); - l->beta = params.get<double>("beta", 0.75); + double alpha = params.get<double>("alpha", 1); + double beta = params.get<double>("beta", 0.75); - return Ptr<Layer>(l); + return Ptr<Layer>(new LRNLayerImpl(type, size, alpha, beta)); } } diff --git a/modules/dnn/src/layers/lrn_layer.hpp b/modules/dnn/src/layers/lrn_layer.hpp index 29f9252e5..43d652868 100644 --- a/modules/dnn/src/layers/lrn_layer.hpp +++ b/modules/dnn/src/layers/lrn_layer.hpp @@ -64,7 +64,7 @@ namespace dnn public: - LRNLayerImpl(); + LRNLayerImpl(int type = CHANNEL_NRM, int size = 5, double alpha = 1, double beta = 0.75); void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs); void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs); }; diff --git a/modules/dnn/src/layers/pooling_layer.cpp b/modules/dnn/src/layers/pooling_layer.cpp index 858cad74d..04e872011 100644 --- a/modules/dnn/src/layers/pooling_layer.cpp +++ b/modules/dnn/src/layers/pooling_layer.cpp @@ -42,8 +42,10 @@ #include "../precomp.hpp" #include "layers_common.hpp" #include "pooling_layer.hpp" +#include "opencl_kernels_dnn.hpp" #include <float.h> #include <algorithm> +#include <opencv2/core/ocl.hpp> using std::max; using std::min; @@ -53,154 +55,243 @@ namespace dnn { //TODO: add ceil_mode param - PoolingLayer::PoolingLayer(LayerParams ¶ms) : Layer(params) - { - if (params.has("pool")) - { - String pool = params.get<String>("pool").toLowerCase(); - if (pool == "max") - type = MAX; - else if (pool == "ave") - type = AVE; - else if (pool == "stochastic") - type = STOCHASTIC; - else - CV_Error(cv::Error::StsBadArg, "Unknown pooling type \"" + pool + "\""); - } - else - { - type = MAX; - } +PoolingLayerImpl::PoolingLayerImpl() +{ - getCaffeConvParams(params, kernel, pad, stride); - } +} - void PoolingLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs) - { - CV_Assert(inputs.size() > 0); +PoolingLayerImpl::PoolingLayerImpl(int type_, Size kernel_, Size pad_, Size stride_) +{ + type = type_; + kernel = kernel_; + pad = pad_; + stride = stride_; +} - inp = inputs[0]->size2(); - computeOutputShape(inp); +void PoolingLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs) +{ + CV_Assert(inputs.size() > 0); - outputs.resize(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) - { - CV_Assert(inputs[i]->rows() == inp.height && inputs[i]->cols() == inp.width); - outputs[i].create(BlobShape(inputs[i]->num(), inputs[i]->channels(), out.height, out.width)); - } + inp = inputs[0]->size2(); + computeOutputShape(inp); + + useOpenCL = ocl::useOpenCL(); + + outputs.resize(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) + { + CV_Assert(inputs[i]->rows() == inp.height && inputs[i]->cols() == inp.width); + outputs[i].create(BlobShape(inputs[i]->num(), inputs[i]->channels(), out.height, out.width)); } +} - void PoolingLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs) +void PoolingLayerImpl::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs) +{ + for (size_t ii = 0; ii < inputs.size(); ii++) { - for (size_t ii = 0; ii < inputs.size(); ii++) + switch (type) { - switch (type) - { - case MAX: - maxPooling(*inputs[ii], outputs[ii]); - break; - case AVE: 
- avePooling(*inputs[ii], outputs[ii]); - break; - default: - CV_Error(Error::StsNotImplemented, "Not implemented"); - break; - } + case MAX: + maxPooling(*inputs[ii], outputs[ii]); + break; + case AVE: + avePooling(*inputs[ii], outputs[ii]); + break; + default: + CV_Error(Error::StsNotImplemented, "Not implemented"); + break; } } +} + +void PoolingLayerImpl::maxPooling(Blob &src, Blob &dst) +{ + if (!useOpenCL) + maxPooling_cpu(src, dst); + else + { + CV_Assert(maxPooling_ocl(src, dst)); + } +} + +bool PoolingLayerImpl::maxPooling_ocl(Blob &src, Blob &dst) +{ + return pooling_ocl("MaxPoolForward", src, dst); +} - void PoolingLayer::maxPooling(Blob &input, Blob &output) +void PoolingLayerImpl::avePooling(Blob &src, Blob &dst) +{ + if (!useOpenCL) + avePooling_cpu(src, dst); + else { - CV_DbgAssert(output.rows() == out.height && output.cols() == out.width); + CV_Assert(avePooling_ocl(src, dst)); + } +} + +bool PoolingLayerImpl::avePooling_ocl(Blob &src, Blob &dst) +{ + return pooling_ocl("AvePoolForward", src, dst); +} + +void PoolingLayerImpl::maxPooling_cpu(Blob &src, Blob &dst) +{ + CV_DbgAssert(dst.rows() == out.height && dst.cols() == out.width); - for (int n = 0; n < input.num(); ++n) + for (int n = 0; n < src.num(); ++n) + { + for (int c = 0; c < src.channels(); ++c) { - for (int c = 0; c < input.channels(); ++c) - { - float *srcData = input.ptrf(n, c); - float *dstData = output.ptrf(n, c); + const float *srcData = src.ptrf(n, c); + float *dstData = dst.ptrf(n, c); - for (int ph = 0; ph < out.height; ++ph) + for (int ph = 0; ph < out.height; ++ph) + { + for (int pw = 0; pw < out.width; ++pw) { - for (int pw = 0; pw < out.width; ++pw) - { - int hstart = ph * stride.height - pad.height; - int wstart = pw * stride.width - pad.width; - int hend = min(hstart + kernel.height, inp.height); - int wend = min(wstart + kernel.width, inp.width); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - const int poolIndex = ph * out.width + pw; - float max_val = -FLT_MAX; - - for (int h = hstart; h < hend; ++h) - for (int w = wstart; w < wend; ++w) - { - const int index = h * inp.width + w; - if (srcData[index] > max_val) - max_val = srcData[index]; - } - - dstData[poolIndex] = max_val; - } + int hstart = ph * stride.height - pad.height; + int wstart = pw * stride.width - pad.width; + int hend = min(hstart + kernel.height, inp.height); + int wend = min(wstart + kernel.width, inp.width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + const int poolIndex = ph * out.width + pw; + float max_val = -FLT_MAX; + + for (int h = hstart; h < hend; ++h) + for (int w = wstart; w < wend; ++w) + { + const int index = h * inp.width + w; + if (srcData[index] > max_val) + max_val = srcData[index]; + } + + dstData[poolIndex] = max_val; } } } } +} + - void PoolingLayer::avePooling(Blob &input, Blob &output) +#ifdef HAVE_OPENCL +bool PoolingLayerImpl::pooling_ocl(const char *kname, const Blob &src, Blob &dst, Blob *mask) +{ + const UMat &srcMat = src.umatRefConst(); + UMat &dstMat = dst.umatRef(); + CV_Assert(mask == NULL && srcMat.offset == 0 && dstMat.offset == 0); + + ocl::Kernel ker(kname, ocl::dnn::pooling_oclsrc, String("-DT=") + ocl::typeToStr(src.type())); + if (ker.empty()) + return false; + + BlobShape s = src.shape(); + size_t nthreads = dst.total(); + ker.args((int)nthreads, + ocl::KernelArg::PtrReadOnly(srcMat), s[0], s[1], s[2], s[3], + out.height, out.width, kernel.height, kernel.width, + stride.height, stride.width, pad.height, pad.width, + ocl::KernelArg::PtrWriteOnly(dstMat)); + + size_t 
wgSize = ocl::Device::getDefault().maxWorkGroupSize(); + if (!ker.run(1, &nthreads, &wgSize, true)) + return false; + + return true; +} +#else +bool PoolingLayerImpl::pooling_ocl(const char*, const Blob&, Blob&, Blob*) +{ + return false; +} +#endif + +void PoolingLayerImpl::avePooling_cpu(Blob &src, Blob &dst) +{ + for (int n = 0; n < src.num(); ++n) { - for (int n = 0; n < input.num(); ++n) + for (int c = 0; c < src.channels(); ++c) { - for (int c = 0; c < input.channels(); ++c) - { - float *srcData = input.ptrf(n, c); - float *dstData = output.ptrf(n, c); + const float *srcData = src.ptrf(n, c); + float *dstData = dst.ptrf(n, c); - for (int ph = 0; ph < out.height; ++ph) + for (int ph = 0; ph < out.height; ++ph) + { + for (int pw = 0; pw < out.width; ++pw) { - for (int pw = 0; pw < out.width; ++pw) - { - int hstart = ph * stride.height - pad.height; - int wstart = pw * stride.width - pad.width; - int hend = min(hstart + kernel.height, inp.height + pad.height); - int wend = min(wstart + kernel.width, inp.width + pad.width); - int poolSize = (hend - hstart) * (wend - wstart); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - hend = min(hend, inp.height); - wend = min(wend, inp.width); - - dstData[ph * out.width + pw] = 0.f; - - for (int h = hstart; h < hend; ++h) - for (int w = wstart; w < wend; ++w) - dstData[ph * out.width + pw] += srcData[h * inp.width + w]; - - dstData[ph * out.width + pw] /= poolSize; - } + int hstart = ph * stride.height - pad.height; + int wstart = pw * stride.width - pad.width; + int hend = min(hstart + kernel.height, inp.height + pad.height); + int wend = min(wstart + kernel.width, inp.width + pad.width); + int poolSize = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, inp.height); + wend = min(wend, inp.width); + + dstData[ph * out.width + pw] = 0.f; + + for (int h = hstart; h < hend; ++h) + for (int w = wstart; w < wend; ++w) + dstData[ph * out.width + pw] += srcData[h * inp.width + w]; + + dstData[ph * out.width + pw] /= poolSize; } - } + } } } +} + +void PoolingLayerImpl::computeOutputShape(Size inpSz) +{ + //Yeah, something strange Caffe scheme-) + out.height = static_cast<int>(ceil(static_cast<float>(inpSz.height + 2 * pad.height - kernel.height) / stride.height)) + 1; + out.width = static_cast<int>(ceil(static_cast<float>(inpSz.width + 2 * pad.width - kernel.width) / stride.width)) + 1; - void PoolingLayer::computeOutputShape(Size inpSz) + if (pad.height || pad.width) { - //Yeah, something strange Caffe scheme-) - out.height = static_cast<int>(ceil(static_cast<float>(inpSz.height + 2 * pad.height - kernel.height) / stride.height)) + 1; - out.width = static_cast<int>(ceil(static_cast<float>(inpSz.width + 2 * pad.width - kernel.width) / stride.width)) + 1; + // If we have padding, ensure that the last pooling starts strictly + // inside the image (instead of at the padding); otherwise clip the last. + if ((out.height - 1) * stride.height >= inpSz.height + pad.height) + --out.height; + if ((out.width - 1) * stride.width >= inpSz.width + pad.width) + --out.width; + CV_Assert((out.height - 1) * stride.height < inpSz.height + pad.height); + CV_Assert((out.width - 1) * stride.width < inpSz.width + pad.width); + } +} - if (pad.height || pad.width) - { - // If we have padding, ensure that the last pooling starts strictly - // inside the image (instead of at the padding); otherwise clip the last. 
-        if ((out.height - 1) * stride.height >= inpSz.height + pad.height)
-            --out.height;
-        if ((out.width - 1) * stride.width >= inpSz.width + pad.width)
-            --out.width;
-        CV_Assert((out.height - 1) * stride.height < inpSz.height + pad.height);
-        CV_Assert((out.width - 1) * stride.width < inpSz.width + pad.width);
-        }
-    }
+
+Ptr<PoolingLayer> PoolingLayer::create(int type, Size kernel, Size pad, Size stride)
+{
+    return Ptr<PoolingLayer>(new PoolingLayerImpl(type, kernel, pad, stride));
+}
+
+Ptr<Layer> createPoolingLayerFromCaffe(LayerParams &params)
+{
+    int type;
+    Size kernel, pad, stride;
+
+    if (params.has("pool"))
+    {
+        String pool = params.get<String>("pool").toLowerCase();
+        if (pool == "max")
+            type = PoolingLayer::MAX;
+        else if (pool == "ave")
+            type = PoolingLayer::AVE;
+        else if (pool == "stochastic")
+            type = PoolingLayer::STOCHASTIC;
+        else
+            CV_Error(Error::StsBadArg, "Unknown pooling type \"" + pool + "\"");
+    }
+    else
+    {
+        type = PoolingLayer::MAX;
+    }
+
+    getCaffeConvParams(params, kernel, pad, stride);
+
+    return Ptr<Layer>(new PoolingLayerImpl(type, kernel, pad, stride));
+}
+
 }
 }
diff --git a/modules/dnn/src/layers/pooling_layer.hpp b/modules/dnn/src/layers/pooling_layer.hpp
index 02f1b2b09..b66242e39 100644
--- a/modules/dnn/src/layers/pooling_layer.hpp
+++ b/modules/dnn/src/layers/pooling_layer.hpp
@@ -1,4 +1,4 @@
-/*M///////////////////////////////////////////////////////////////////////////////////////
+/*M///////////////////////////////////////////////////////////////////////////////////////
 //
 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
 //
@@ -42,33 +42,40 @@
 #ifndef __OPENCV_DNN_LAYERS_POOLING_LAYER_HPP__
 #define __OPENCV_DNN_LAYERS_POOLING_LAYER_HPP__
 #include "../precomp.hpp"
+#include <opencv2/dnn/all_layers.hpp>

 namespace cv
 {
 namespace dnn
 {
-    class PoolingLayer : public Layer
+    class PoolingLayerImpl : public PoolingLayer
     {
-        enum
-        {
-            MAX,
-            AVE,
-            STOCHASTIC
-        };
-
-        int type;
-        Size kernel, pad, stride;
+        bool useOpenCL;
         Size inp, out;

         void computeOutputShape(Size inpSz);
-        void maxPooling(Blob &input, Blob &output);
-        void avePooling(Blob &input, Blob &output);
+
+        bool pooling_ocl(const char *kname, const Blob &src, Blob &dst, Blob *mask = NULL);
+
+        void maxPooling(Blob &src, Blob &dst);
+        void maxPooling_cpu(Blob &src, Blob &dst);
+        bool maxPooling_ocl(Blob &src, Blob &dst);
+
+        void avePooling(Blob &src, Blob &dst);
+        void avePooling_cpu(Blob &src, Blob &dst);
+        bool avePooling_ocl(Blob &src, Blob &dst);

     public:
-        PoolingLayer(LayerParams &params);
+
+        PoolingLayerImpl();
+        PoolingLayerImpl(int type, Size kernel, Size pad, Size stride);
+
         void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
         void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
     };
+
+Ptr<Layer> createPoolingLayerFromCaffe(LayerParams &params);
+
 }
 }
 #endif
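A note on the shape arithmetic before the new OpenCL kernels: computeOutputShape above keeps Caffe's ceil-based pooling rule (convolution uses floor), then clips a trailing output cell whose window would start entirely inside the padding. A standalone sketch of the same computation for one dimension, with a worked example; the helper name is hypothetical and the snippet is not part of the patch:

    #include <cmath>
    #include <cstdio>

    // Mirrors PoolingLayerImpl::computeOutputShape for a single dimension.
    static int poolOutSize(int inp, int kernel, int stride, int pad)
    {
        int out = (int)std::ceil((float)(inp + 2 * pad - kernel) / stride) + 1;
        // With padding, drop a last cell whose window starts in the padding.
        if (pad && (out - 1) * stride >= inp + pad)
            --out;
        return out;
    }

    int main()
    {
        // 224-wide input, 3x3 kernel, stride 2, no padding:
        // ceil((224 - 3) / 2) + 1 = 111 + 1 = 112
        std::printf("%d\n", poolOutSize(224, 3, 2, 0)); // prints 112
        return 0;
    }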
diff --git a/modules/dnn/src/opencl/pooling.cl b/modules/dnn/src/opencl/pooling.cl
new file mode 100644
index 000000000..aeb70bc55
--- /dev/null
+++ b/modules/dnn/src/opencl/pooling.cl
@@ -0,0 +1,94 @@
+/*************************************************************************************
+ * Copyright (c) 2015, Advanced Micro Devices, Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification,
+ * are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation and/or
+ * other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+ * IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
+ * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+ * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+ * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ **************************************************************************************/
+
+__kernel void MaxPoolForward(const int nthreads, __global T* bottom_data, const int num, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, __global T* top_data
+#ifdef MASK
+    , __global int* mask, __global T* top_mask
+#endif
+) {
+    int index = get_global_id(0);
+    int tmp = get_global_size(0);
+    for (; index < nthreads; index += tmp) {
+        int pw = index % pooled_width;
+        int ph = (index / pooled_width) % pooled_height;
+        int c = (index / pooled_width / pooled_height) % channels;
+        int n = index / pooled_width / pooled_height / channels;
+        int hstart = ph * stride_h - pad_h;
+        int wstart = pw * stride_w - pad_w;
+        const int hend = min(hstart + kernel_h, height);
+        const int wend = min(wstart + kernel_w, width);
+        hstart = max(hstart, 0);
+        wstart = max(wstart, 0);
+        T maxval = -FLT_MAX;
+        int maxidx = -1;
+        bottom_data =
+            bottom_data + (n * channels + c) * height * width;
+        for (int h = hstart; h < hend; ++h) {
+            for (int w = wstart; w < wend; ++w) {
+                if (bottom_data[h * width + w] > maxval) {
+                    maxidx = h * width + w;
+                    maxval = bottom_data[maxidx];
+                }
+            }
+        }
+        top_data[index] = maxval;
+#ifdef MASK
+        if (mask) {
+            mask[index] = maxidx;
+        } else {
+            top_mask[index] = maxidx;
+        }
+#endif
+    }
+}
+
+__kernel void AvePoolForward(const int nthreads, __global T* bottom_data, const int num, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, __global T* top_data) {
+    int index = get_global_id(0);
+    int tmp = get_global_size(0);
+    for (; index < nthreads; index += tmp) {
+        int pw = index % pooled_width;
+        int ph = (index / pooled_width) % pooled_height;
+        int c = (index / pooled_width / pooled_height) % channels;
+        int n = index / pooled_width / pooled_height / channels;
+        int hstart = ph * stride_h - pad_h;
+        int wstart = pw * stride_w - pad_w;
+        int hend = min(hstart + kernel_h, height + pad_h);
+        int wend = min(wstart + kernel_w, width + pad_w);
+        const int pool_size = (hend - hstart) * (wend - wstart);
+        hstart = max(hstart, 0);
+        wstart = max(wstart, 0);
+        hend = min(hend, height);
+        wend = min(wend, width);
+        T aveval = 0;
+        bottom_data =
+            bottom_data + (n * channels + c) *
height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + aveval += bottom_data[h * width + w]; + } + } + top_data[index] = aveval / pool_size; + } + +} diff --git a/modules/dnn/test/test_layers.cpp b/modules/dnn/test/test_layers.cpp index 8871bcb88..9958076e7 100644 --- a/modules/dnn/test/test_layers.cpp +++ b/modules/dnn/test/test_layers.cpp @@ -142,12 +142,24 @@ TEST(Layer_Test_InnerProduct, Accuracy) TEST(Layer_Test_Pooling_max, Accuracy) { - testLayerUsingCaffeModels("layer_pooling_max"); + OCL_OFF(testLayerUsingCaffeModels("layer_pooling_max")); + OCL_ON(); +} +OCL_TEST(Layer_Test_Pooling_max, Accuracy) +{ + OCL_ON(testLayerUsingCaffeModels("layer_pooling_max")); + OCL_OFF(); } TEST(Layer_Test_Pooling_ave, Accuracy) { - testLayerUsingCaffeModels("layer_pooling_ave"); + OCL_OFF(testLayerUsingCaffeModels("layer_pooling_ave")); + OCL_ON(); +} +OCL_TEST(Layer_Test_Pooling_ave, Accuracy) +{ + OCL_ON(testLayerUsingCaffeModels("layer_pooling_ave")); + OCL_OFF(); } TEST(Layer_Test_MVN, Accuracy)
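The OCL_OFF/OCL_ON wrappers above come from the dnn test harness; outside the tests, the equivalent switch is cv::ocl::setUseOpenCL(), which PoolingLayerImpl::allocate reads and caches in useOpenCL. A rough sketch of exercising both code paths (the function name is hypothetical; the Blob plumbing is assumed from this branch):

    #include <vector>
    #include <opencv2/core/ocl.hpp>
    #include <opencv2/dnn/all_layers.hpp>

    using namespace cv;
    using namespace cv::dnn;

    // Runs one pooling layer on the CPU path, then on the OpenCL path.
    // allocate() must be called again after toggling, because the flag is
    // cached there rather than read on every forward().
    void comparePoolingPaths(Ptr<PoolingLayer> pool, std::vector<Blob*> &inputs)
    {
        std::vector<Blob> cpuOut, oclOut;

        ocl::setUseOpenCL(false);      // force maxPooling_cpu/avePooling_cpu
        pool->allocate(inputs, cpuOut);
        pool->forward(inputs, cpuOut);

        ocl::setUseOpenCL(true);       // force pooling_ocl (if OpenCL is available)
        pool->allocate(inputs, oclOut);
        pool->forward(inputs, oclOut);

        // The two outputs should agree to within floating-point tolerance.
    }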