diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 70ee37555..0d2a9af1e 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -209,7 +209,7 @@ namespace dnn
     {
     public:
 
-        Size kernel, pad, stride;
+        Size kernel, stride, pad;
     };
 
     class CV_EXPORTS_W ConvolutionLayer : public BaseConvolutionLayer
@@ -232,7 +232,7 @@ namespace dnn
     {
     public:
 
-        enum
+        enum Type
         {
             CHANNEL_NRM,
             SPATIAL_NRM
@@ -241,9 +241,28 @@ namespace dnn
 
         int size;
         double alpha, beta;
+
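+        //! Creates an LRN layer instance; the defaults (size = 5, alpha = 1, beta = 0.75) match Caffe's LRNParameter defaults.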
+        static Ptr<LRNLayer> create(int type = CHANNEL_NRM, int size = 5, double alpha = 1, double beta = 0.75);
     };
 
+    class CV_EXPORTS_W PoolingLayer : public Layer
+    {
+    public:
 
+        enum Type
+        {
+            MAX,
+            AVE,
+            STOCHASTIC
+        };
+
+        int type;
+        Size kernel, stride, pad;
+
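+        //! Creates a pooling layer instance; the defaults give 2x2 max pooling with unit stride and no padding.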
+        static Ptr<PoolingLayer> create(int type = MAX, Size kernel = Size(2, 2), Size pad = Size(0, 0), Size stride = Size(1, 1));
+    };
 
 //! @}
 //! @}
diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp
index ded668ae0..58e6e0d3f 100644
--- a/modules/dnn/src/init.cpp
+++ b/modules/dnn/src/init.cpp
@@ -81,7 +81,7 @@ void initModule()
     REG_RUNTIME_LAYER_CLASS(Split, SplitLayer)
     REG_RUNTIME_LAYER_CLASS(Reshape, ReshapeLayer)
     REG_STATIC_LAYER_FUNC(Flatten, createFlattenLayer)
-    REG_RUNTIME_LAYER_CLASS(Pooling, PoolingLayer)
+    REG_RUNTIME_LAYER_FUNC(Pooling, createPoolingLayerFromCaffe)
     REG_RUNTIME_LAYER_CLASS(MVN, MVNLayer)
     REG_RUNTIME_LAYER_FUNC(LRN, createLRNLayerFromCaffe)
     REG_RUNTIME_LAYER_CLASS(InnerProduct, FullyConnectedLayer)
diff --git a/modules/dnn/src/layers/lrn_layer.cpp b/modules/dnn/src/layers/lrn_layer.cpp
index f846d765f..9694b44cb 100644
--- a/modules/dnn/src/layers/lrn_layer.cpp
+++ b/modules/dnn/src/layers/lrn_layer.cpp
@@ -53,17 +53,18 @@ namespace cv
 namespace dnn
 {
 
-LRNLayerImpl::LRNLayerImpl()
+LRNLayerImpl::LRNLayerImpl(int type_, int size_, double alpha_, double beta_)
 {
-    size = 5;
-    alpha = 1;
-    beta = 0.75;
-    type = CHANNEL_NRM;
+    type = type_;
+    size = size_;
+    alpha = alpha_;
+    beta = beta_;
 }
 
 void LRNLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
 {
     CV_Assert(inputs.size() == 1 && inputs[0]->dims() == 4);
+    CV_Assert(type == CHANNEL_NRM || type == SPATIAL_NRM);
     useOpenCL = cv::ocl::useOpenCL();
 
     if (type == SPATIAL_NRM && !useOpenCL)
@@ -154,6 +155,7 @@ void LRNLayerImpl::channelNoramlization_(Blob &srcBlob, Blob &dstBlob)
 
 bool LRNLayerImpl::channelNoramlization_ocl(const UMat &src, UMat &dst)
 {
+#ifdef HAVE_OPENCL
     if (src.offset != 0 || dst.offset != 0) //TODO: add offset
         return false;
 
@@ -187,6 +189,9 @@ bool LRNLayerImpl::channelNoramlization_ocl(const UMat &src, UMat &dst)
         return false;
 
     return true;
+#else
+    return false;
+#endif
 }
 
 void LRNLayerImpl::spatialNormalization(Blob &src, Blob &dst)
@@ -232,27 +237,31 @@ void LRNLayerImpl::spatialNormalization_(Blob &srcBlob, Blob &dstBlob)
     }
 }
 
-Ptr<Layer> createLRNLayerFromCaffe(LayerParams &params)
+
+Ptr<LRNLayer> LRNLayer::create(int type, int size, double alpha, double beta)
 {
-    LRNLayerImpl *l = new LRNLayerImpl();
+    return Ptr<LRNLayer>(new LRNLayerImpl(type, size, alpha, beta));
+}
 
+Ptr<Layer> createLRNLayerFromCaffe(LayerParams &params)
+{
+    int type;
     String nrmType = params.get<String>("norm_region", "ACROSS_CHANNELS");
     if (nrmType == "ACROSS_CHANNELS")
-        l->type = LRNLayer::CHANNEL_NRM;
+        type = LRNLayer::CHANNEL_NRM;
     else if (nrmType == "WITHIN_CHANNEL")
-        l->type = LRNLayer::SPATIAL_NRM;
+        type = LRNLayer::SPATIAL_NRM;
     else
         CV_Error(Error::StsBadArg, "Unknown region type \"" + nrmType + "\"");
 
     int size = params.get<int>("local_size", 5);
     if (size % 2 != 1 || size <= 0)
         CV_Error(Error::StsBadArg, "LRN layer supports only positive odd values for local_size");
-    l->size = size;
 
-    l->alpha = params.get<double>("alpha", 1);
-    l->beta = params.get<double>("beta", 0.75);
+    double alpha = params.get<double>("alpha", 1);
+    double beta = params.get<double>("beta", 0.75);
 
-    return Ptr<Layer>(l);
+    return Ptr<Layer>(new LRNLayerImpl(type, size, alpha, beta));
 }
 
 }
diff --git a/modules/dnn/src/layers/lrn_layer.hpp b/modules/dnn/src/layers/lrn_layer.hpp
index 29f9252e5..43d652868 100644
--- a/modules/dnn/src/layers/lrn_layer.hpp
+++ b/modules/dnn/src/layers/lrn_layer.hpp
@@ -64,7 +64,7 @@ namespace dnn
 
     public:
 
-        LRNLayerImpl();
+        LRNLayerImpl(int type = CHANNEL_NRM, int size = 5, double alpha = 1, double beta = 0.75);
         void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
         void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
     };
diff --git a/modules/dnn/src/layers/pooling_layer.cpp b/modules/dnn/src/layers/pooling_layer.cpp
index 858cad74d..04e872011 100644
--- a/modules/dnn/src/layers/pooling_layer.cpp
+++ b/modules/dnn/src/layers/pooling_layer.cpp
@@ -42,8 +42,10 @@
 #include "../precomp.hpp"
 #include "layers_common.hpp"
 #include "pooling_layer.hpp"
+#include "opencl_kernels_dnn.hpp"
 #include <float.h>
 #include <algorithm>
+#include <opencv2/core/ocl.hpp>
 using std::max;
 using std::min;
 
@@ -53,154 +55,250 @@ namespace dnn
 {
 //TODO: add ceil_mode param
 
-    PoolingLayer::PoolingLayer(LayerParams &params) : Layer(params)
-    {
-        if (params.has("pool"))
-        {
-            String pool = params.get<String>("pool").toLowerCase();
-            if (pool == "max")
-                type = MAX;
-            else if (pool == "ave")
-                type = AVE;
-            else if (pool == "stochastic")
-                type = STOCHASTIC;
-            else
-                CV_Error(cv::Error::StsBadArg, "Unknown pooling type \"" + pool + "\"");
-        }
-        else
-        {
-            type = MAX;
-        }
+PoolingLayerImpl::PoolingLayerImpl()
+{
 
-        getCaffeConvParams(params, kernel, pad, stride);
-    }
+}
 
-    void PoolingLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
-    {
-        CV_Assert(inputs.size() > 0);
+PoolingLayerImpl::PoolingLayerImpl(int type_, Size kernel_, Size pad_, Size stride_)
+{
+    type = type_;
+    kernel = kernel_;
+    pad = pad_;
+    stride = stride_;
+}
 
-        inp = inputs[0]->size2();
-        computeOutputShape(inp);
+void PoolingLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
+{
+    CV_Assert(inputs.size() > 0);
 
-        outputs.resize(inputs.size());
-        for (size_t i = 0; i < inputs.size(); i++)
-        {
-            CV_Assert(inputs[i]->rows() == inp.height && inputs[i]->cols() == inp.width);
-            outputs[i].create(BlobShape(inputs[i]->num(), inputs[i]->channels(), out.height, out.width));
-        }
+    inp = inputs[0]->size2();
+    computeOutputShape(inp);
+
+    useOpenCL = ocl::useOpenCL();
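+    // cache cv::ocl::useOpenCL() once per allocation; forward() picks the CPU or OCL path from this flag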
+
+    outputs.resize(inputs.size());
+    for (size_t i = 0; i < inputs.size(); i++)
+    {
+        CV_Assert(inputs[i]->rows() == inp.height && inputs[i]->cols() == inp.width);
+        outputs[i].create(BlobShape(inputs[i]->num(), inputs[i]->channels(), out.height, out.width));
     }
+}
 
-    void PoolingLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
+void PoolingLayerImpl::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
+{
+    for (size_t ii = 0; ii < inputs.size(); ii++)
     {
-        for (size_t ii = 0; ii < inputs.size(); ii++)
+        switch (type)
         {
-            switch (type)
-            {
-            case MAX:
-                maxPooling(*inputs[ii], outputs[ii]);
-                break;
-            case AVE:
-                avePooling(*inputs[ii], outputs[ii]);
-                break;
-            default:
-                CV_Error(Error::StsNotImplemented, "Not implemented");
-                break;
-            }
+        case MAX:
+            maxPooling(*inputs[ii], outputs[ii]);
+            break;
+        case AVE:
+            avePooling(*inputs[ii], outputs[ii]);
+            break;
+        default:
+            CV_Error(Error::StsNotImplemented, "Not implemented");
+            break;
         }
     }
+}
+
+void PoolingLayerImpl::maxPooling(Blob &src, Blob &dst)
+{
+    if (!useOpenCL)
+        maxPooling_cpu(src, dst);
+    else
+    {
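+        // no silent CPU fallback: a failed OCL run is treated as an error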
+        CV_Assert(maxPooling_ocl(src, dst));
+    }
+}
+
+bool PoolingLayerImpl::maxPooling_ocl(Blob &src, Blob &dst)
+{
+    return pooling_ocl("MaxPoolForward", src, dst);
+}
 
-    void PoolingLayer::maxPooling(Blob &input, Blob &output)
+void PoolingLayerImpl::avePooling(Blob &src, Blob &dst)
+{
+    if (!useOpenCL)
+        avePooling_cpu(src, dst);
+    else
     {
-        CV_DbgAssert(output.rows() == out.height && output.cols() == out.width);
+        CV_Assert(avePooling_ocl(src, dst));
+    }
+}
+
+bool PoolingLayerImpl::avePooling_ocl(Blob &src, Blob &dst)
+{
+    return pooling_ocl("AvePoolForward", src, dst);
+}
+
+void PoolingLayerImpl::maxPooling_cpu(Blob &src, Blob &dst)
+{
+    CV_DbgAssert(dst.rows() == out.height && dst.cols() == out.width);
 
-        for (int n = 0; n < input.num(); ++n)
+    for (int n = 0; n < src.num(); ++n)
+    {
+        for (int c = 0; c < src.channels(); ++c)
         {
-            for (int c = 0; c < input.channels(); ++c)
-            {
-                float *srcData = input.ptrf(n, c);
-                float *dstData = output.ptrf(n, c);
+            const float *srcData = src.ptrf(n, c);
+            float *dstData = dst.ptrf(n, c);
 
-                for (int ph = 0; ph < out.height; ++ph)
+            for (int ph = 0; ph < out.height; ++ph)
+            {
+                for (int pw = 0; pw < out.width; ++pw)
                 {
-                    for (int pw = 0; pw < out.width; ++pw)
-                    {
-                        int hstart = ph * stride.height - pad.height;
-                        int wstart = pw * stride.width - pad.width;
-                        int hend = min(hstart + kernel.height, inp.height);
-                        int wend = min(wstart + kernel.width, inp.width);
-                        hstart = max(hstart, 0);
-                        wstart = max(wstart, 0);
-                        const int poolIndex = ph * out.width + pw;
-                        float max_val = -FLT_MAX;
-
-                        for (int h = hstart; h < hend; ++h)
-                            for (int w = wstart; w < wend; ++w)
-                            {
-                                const int index = h * inp.width + w;
-                                if (srcData[index] > max_val)
-                                    max_val = srcData[index];
-                            }
-
-                        dstData[poolIndex] = max_val;
-                    }
+                    int hstart = ph * stride.height - pad.height;
+                    int wstart = pw * stride.width - pad.width;
+                    int hend = min(hstart + kernel.height, inp.height);
+                    int wend = min(wstart + kernel.width, inp.width);
+                    hstart = max(hstart, 0);
+                    wstart = max(wstart, 0);
+                    const int poolIndex = ph * out.width + pw;
+                    float max_val = -FLT_MAX;
+
+                    for (int h = hstart; h < hend; ++h)
+                        for (int w = wstart; w < wend; ++w)
+                        {
+                            const int index = h * inp.width + w;
+                            if (srcData[index] > max_val)
+                                max_val = srcData[index];
+                        }
+
+                    dstData[poolIndex] = max_val;
                 }
             }
         }
     }
+}
+
 
-    void PoolingLayer::avePooling(Blob &input, Blob &output)
+#ifdef HAVE_OPENCL
+bool PoolingLayerImpl::pooling_ocl(const char *kname, const Blob &src, Blob &dst, Blob *mask)
+{
+    const UMat &srcMat = src.umatRefConst();
+    UMat &dstMat = dst.umatRef();
+    CV_Assert(mask == NULL && srcMat.offset == 0 && dstMat.offset == 0);
+
+    ocl::Kernel ker(kname, ocl::dnn::pooling_oclsrc, String("-DT=") + ocl::typeToStr(src.type()));
+    if (ker.empty())
+        return false;
+
+    BlobShape s = src.shape();
+    size_t nthreads = dst.total();
+    ker.args((int)nthreads,
+             ocl::KernelArg::PtrReadOnly(srcMat), s[0], s[1], s[2], s[3],
+             out.height, out.width, kernel.height, kernel.width,
+             stride.height, stride.width, pad.height, pad.width,
+             ocl::KernelArg::PtrWriteOnly(dstMat));
+
+    size_t wgSize = ocl::Device::getDefault().maxWorkGroupSize();
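+    // ocl::Kernel::run rounds the global size up to a multiple of wgSize,
+    // which is why the kernels guard with an "index < nthreads" loop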
+    if (!ker.run(1, &nthreads, &wgSize, true))
+        return false;
+
+    return true;
+}
+#else
+bool PoolingLayerImpl::pooling_ocl(const char*, const Blob&, Blob&, Blob*)
+{
+    return false;
+}
+#endif
+
+void PoolingLayerImpl::avePooling_cpu(Blob &src, Blob &dst)
+{
+    for (int n = 0; n < src.num(); ++n)
     {
-        for (int n = 0; n < input.num(); ++n)
+        for (int c = 0; c < src.channels(); ++c)
         {
-            for (int c = 0; c < input.channels(); ++c)
-            {
-                float *srcData = input.ptrf(n, c);
-                float *dstData = output.ptrf(n, c);
+            const float *srcData = src.ptrf(n, c);
+            float *dstData = dst.ptrf(n, c);
 
-                for (int ph = 0; ph < out.height; ++ph)
+            for (int ph = 0; ph < out.height; ++ph)
+            {
+                for (int pw = 0; pw < out.width; ++pw)
                 {
-                    for (int pw = 0; pw < out.width; ++pw)
-                    {
-                        int hstart = ph * stride.height - pad.height;
-                        int wstart = pw * stride.width - pad.width;
-                        int hend = min(hstart + kernel.height, inp.height + pad.height);
-                        int wend = min(wstart + kernel.width, inp.width + pad.width);
-                        int poolSize = (hend - hstart) * (wend - wstart);
-                        hstart = max(hstart, 0);
-                        wstart = max(wstart, 0);
-                        hend = min(hend, inp.height);
-                        wend = min(wend, inp.width);
-
-                        dstData[ph * out.width + pw] = 0.f;
-
-                        for (int h = hstart; h < hend; ++h)
-                            for (int w = wstart; w < wend; ++w)
-                                dstData[ph * out.width + pw] += srcData[h * inp.width + w];
-
-                        dstData[ph * out.width + pw] /= poolSize;
-                    }
+                    int hstart = ph * stride.height - pad.height;
+                    int wstart = pw * stride.width - pad.width;
+                    int hend = min(hstart + kernel.height, inp.height + pad.height);
+                    int wend = min(wstart + kernel.width, inp.width + pad.width);
+                    int poolSize = (hend - hstart) * (wend - wstart);
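+                    // poolSize is taken before clipping, so padded cells count toward the average (Caffe behavior)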
+                    hstart = max(hstart, 0);
+                    wstart = max(wstart, 0);
+                    hend = min(hend, inp.height);
+                    wend = min(wend, inp.width);
+
+                    dstData[ph * out.width + pw] = 0.f;
+
+                    for (int h = hstart; h < hend; ++h)
+                        for (int w = wstart; w < wend; ++w)
+                            dstData[ph * out.width + pw] += srcData[h * inp.width + w];
+
+                    dstData[ph * out.width + pw] /= poolSize;
                 }
-          }
+            }
         }
     }
+}
+
+void PoolingLayerImpl::computeOutputShape(Size inpSz)
+{
+    // Compute the output size using Caffe's scheme: note the ceil rounding (convolution uses floor)
+    out.height = static_cast<int>(ceil(static_cast<float>(inpSz.height + 2 * pad.height - kernel.height) / stride.height)) + 1;
+    out.width = static_cast<int>(ceil(static_cast<float>(inpSz.width + 2 * pad.width - kernel.width) / stride.width)) + 1;
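+    // e.g. inp = 10x10, kernel = 3x3, stride = 2, pad = 0:
+    //      out.height = out.width = ceil((10 - 3) / 2.) + 1 = 4 + 1 = 5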
 
-    void PoolingLayer::computeOutputShape(Size inpSz)
+    if (pad.height || pad.width)
     {
-        //Yeah, something strange Caffe scheme-)
-        out.height = static_cast<int>(ceil(static_cast<float>(inpSz.height + 2 * pad.height - kernel.height) / stride.height)) + 1;
-        out.width = static_cast<int>(ceil(static_cast<float>(inpSz.width + 2 * pad.width - kernel.width) / stride.width)) + 1;
+        // If we have padding, ensure that the last pooling starts strictly
+        // inside the image (instead of at the padding); otherwise clip the last.
+        if ((out.height - 1) * stride.height >= inpSz.height + pad.height)
+            --out.height;
+        if ((out.width - 1) * stride.width >= inpSz.width + pad.width)
+            --out.width;
+        CV_Assert((out.height - 1) * stride.height < inpSz.height + pad.height);
+        CV_Assert((out.width - 1) * stride.width < inpSz.width + pad.width);
+    }
+}
 
-        if (pad.height || pad.width)
-        {
-            // If we have padding, ensure that the last pooling starts strictly
-            // inside the image (instead of at the padding); otherwise clip the last.
-            if ((out.height - 1) * stride.height >= inpSz.height + pad.height)
-                --out.height;
-            if ((out.width - 1) * stride.width >= inpSz.width + pad.width)
-                --out.width;
-            CV_Assert((out.height - 1) * stride.height < inpSz.height + pad.height);
-            CV_Assert((out.width - 1) * stride.width < inpSz.width + pad.width);
-        }
+Ptr<PoolingLayer> PoolingLayer::create(int type, Size kernel, Size pad, Size stride)
+{
+    return Ptr<PoolingLayer>(new PoolingLayerImpl(type, kernel, pad, stride));
+}
+
+Ptr<Layer> createPoolingLayerFromCaffe(LayerParams &params)
+{
+    int type;
+    Size kernel, pad, stride;
+
+    if (params.has("pool"))
+    {
+        String pool = params.get<String>("pool").toLowerCase();
+        if (pool == "max")
+            type = PoolingLayer::MAX;
+        else if (pool == "ave")
+            type = PoolingLayer::AVE;
+        else if (pool == "stochastic")
+            type = PoolingLayer::STOCHASTIC;
+        else
+            CV_Error(Error::StsBadArg, "Unknown pooling type \"" + pool + "\"");
+    }
+    else
+    {
+        type = PoolingLayer::MAX;
     }
+
+    getCaffeConvParams(params, kernel, pad, stride);
+
+    return Ptr<Layer>(new PoolingLayerImpl(type, kernel, pad, stride));
+}
+
 }
 }
diff --git a/modules/dnn/src/layers/pooling_layer.hpp b/modules/dnn/src/layers/pooling_layer.hpp
index 02f1b2b09..b66242e39 100644
--- a/modules/dnn/src/layers/pooling_layer.hpp
+++ b/modules/dnn/src/layers/pooling_layer.hpp
@@ -42,33 +42,40 @@
 #ifndef __OPENCV_DNN_LAYERS_POOLING_LAYER_HPP__
 #define __OPENCV_DNN_LAYERS_POOLING_LAYER_HPP__
 #include "../precomp.hpp"
+#include <opencv2/dnn/all_layers.hpp>
 
 namespace cv
 {
 namespace dnn
 {
-    class PoolingLayer : public Layer
+    class PoolingLayerImpl : public PoolingLayer
     {
-        enum
-        {
-            MAX,
-            AVE,
-            STOCHASTIC
-        };
-
-        int type;
-        Size kernel, pad, stride;
+        bool useOpenCL;
         Size inp, out;
 
         void computeOutputShape(Size inpSz);
-        void maxPooling(Blob &input, Blob &output);
-        void avePooling(Blob &input, Blob &output);
+
+        bool pooling_ocl(const char *kname, const Blob &src, Blob &dst, Blob *mask = NULL);
+
+        void maxPooling(Blob &src, Blob &dst);
+        void maxPooling_cpu(Blob &src, Blob &dst);
+        bool maxPooling_ocl(Blob &src, Blob &dst);
+
+        void avePooling(Blob &src, Blob &dst);
+        void avePooling_cpu(Blob &src, Blob &dst);
+        bool avePooling_ocl(Blob &src, Blob &dst);
 
     public:
-        PoolingLayer(LayerParams &params);
+
+        PoolingLayerImpl();
+        PoolingLayerImpl(int type, Size kernel, Size pad, Size stride);
+
         void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
         void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
     };
+
+    Ptr<Layer> createPoolingLayerFromCaffe(LayerParams &params);
+
 }
 }
 #endif
diff --git a/modules/dnn/src/opencl/pooling.cl b/modules/dnn/src/opencl/pooling.cl
new file mode 100644
index 000000000..aeb70bc55
--- /dev/null
+++ b/modules/dnn/src/opencl/pooling.cl
@@ -0,0 +1,107 @@
+/*************************************************************************************
+ * Copyright (c) 2015, Advanced Micro Devices, Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification,
+ * are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation and/or
+ *  other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+ * IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
+ * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+ * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+ * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ **************************************************************************************/
+
+__kernel void MaxPoolForward(const int nthreads, __global T* bottom_data,
+                             const int num, const int channels,
+                             const int height, const int width,
+                             const int pooled_height, const int pooled_width,
+                             const int kernel_h, const int kernel_w,
+                             const int stride_h, const int stride_w,
+                             const int pad_h, const int pad_w, __global T* top_data
+#ifdef MASK
+  , __global int* mask, __global T* top_mask
+#endif
+) {
+  int index = get_global_id(0);
+  int tmp = get_global_size(0);
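+  // grid-stride loop: each work-item may handle several outputs if the global size is less than nthreads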
+  for (; index < nthreads; index += tmp) {
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    int hstart = ph * stride_h - pad_h;
+    int wstart = pw * stride_w - pad_w;
+    const int hend = min(hstart + kernel_h, height);
+    const int wend = min(wstart + kernel_w, width);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    T maxval = -FLT_MAX;
+    int maxidx = -1;
+    // local slice pointer: reassigning bottom_data would compound offsets across grid-stride iterations
+    __global const T* bottom_slice = bottom_data + (n * channels + c) * height * width;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        if (bottom_slice[h * width + w] > maxval) {
+          maxidx = h * width + w;
+          maxval = bottom_slice[maxidx];
+        }
+      }
+    }
+    top_data[index] = maxval;
+#ifdef MASK
+    if (mask) {
+      mask[index] = maxidx;
+    } else {
+      top_mask[index] = maxidx;
+    }
+#endif
+  }
+}
+
+__kernel void AvePoolForward(const int nthreads, __global T* bottom_data,
+                             const int num, const int channels,
+                             const int height, const int width,
+                             const int pooled_height, const int pooled_width,
+                             const int kernel_h, const int kernel_w,
+                             const int stride_h, const int stride_w,
+                             const int pad_h, const int pad_w, __global T* top_data) {
+  int index = get_global_id(0);
+  int tmp = get_global_size(0);
+  for (; index < nthreads; index += tmp) {
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    int hstart = ph * stride_h - pad_h;
+    int wstart = pw * stride_w - pad_w;
+    int hend = min(hstart + kernel_h, height + pad_h);
+    int wend = min(wstart + kernel_w, width + pad_w);
+    const int pool_size = (hend - hstart) * (wend - wstart);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    hend = min(hend, height);
+    wend = min(wend, width);
+    T aveval = 0;
+    __global const T* bottom_slice = bottom_data + (n * channels + c) * height * width;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        aveval += bottom_slice[h * width + w];
+      }
+    }
+    top_data[index] = aveval / pool_size;
+  }
+}
diff --git a/modules/dnn/test/test_layers.cpp b/modules/dnn/test/test_layers.cpp
index 8871bcb88..9958076e7 100644
--- a/modules/dnn/test/test_layers.cpp
+++ b/modules/dnn/test/test_layers.cpp
@@ -142,12 +142,22 @@ TEST(Layer_Test_InnerProduct, Accuracy)
 
 TEST(Layer_Test_Pooling_max, Accuracy)
 {
-     testLayerUsingCaffeModels("layer_pooling_max");
+     OCL_OFF(testLayerUsingCaffeModels("layer_pooling_max"));
+}
+OCL_TEST(Layer_Test_Pooling_max, Accuracy)
+{
+     OCL_ON(testLayerUsingCaffeModels("layer_pooling_max"));
+     OCL_OFF();
 }
 
 TEST(Layer_Test_Pooling_ave, Accuracy)
 {
-     testLayerUsingCaffeModels("layer_pooling_ave");
+     OCL_OFF(testLayerUsingCaffeModels("layer_pooling_ave"));
+}
+OCL_TEST(Layer_Test_Pooling_ave, Accuracy)
+{
+     OCL_ON(testLayerUsingCaffeModels("layer_pooling_ave"));
+     OCL_OFF();
 }
 
 TEST(Layer_Test_MVN, Accuracy)