Add OCL implementations and public interfaces for Convolution and LRN

pull/707/head
Vitaliy Lyudvichenko 8 years ago
parent 50c9e1c912
commit 601afeed90
  1. modules/dnn/include/opencv2/dnn/all_layers.hpp (40 changes)
  2. modules/dnn/include/opencv2/dnn/dnn.hpp (3 changes)
  3. modules/dnn/include/opencv2/dnn/shape_utils.hpp (35 changes)
  4. modules/dnn/src/dnn.cpp (7 changes)
  5. modules/dnn/src/init.cpp (8 changes)
  6. modules/dnn/src/layers/convolution_layer.cpp (189 changes)
  7. modules/dnn/src/layers/convolution_layer.hpp (84 changes)
  8. modules/dnn/src/layers/layers_common.cpp (27 changes)
  9. modules/dnn/src/layers/layers_common.hpp (2 changes)
  10. modules/dnn/src/layers/lrn_layer.cpp (267 changes)
  11. modules/dnn/src/layers/lrn_layer.hpp (28 changes)
  12. modules/dnn/src/layers/pooling_layer.cpp (73 changes)
  13. modules/dnn/src/layers/pooling_layer.hpp (10 changes)
  14. modules/dnn/src/opencl/lrn.cl (76 changes)
  15. modules/dnn/test/test_layers.cpp (14 changes)

@@ -205,6 +205,46 @@ namespace dnn
void forward(std::vector<Blob*> &input, std::vector<Blob> &output);
};
class CV_EXPORTS_W BaseConvolutionLayer : public Layer
{
public:
Size kernel, pad, stride;
};
class CV_EXPORTS_W ConvolutionLayer : public BaseConvolutionLayer
{
public:
static Ptr<BaseConvolutionLayer> create();
static Ptr<BaseConvolutionLayer> create(Size kernel = Size(3, 3), Size pad = Size(0, 0), Size stride = Size(1, 1));
};
class CV_EXPORTS_W DeconvolutionLayer : public BaseConvolutionLayer
{
public:
static Ptr<BaseConvolutionLayer> create();
static Ptr<BaseConvolutionLayer> create(Size kernel = Size(3, 3), Size pad = Size(0, 0), Size stride = Size(1, 1));
};
class CV_EXPORTS_W LRNLayer : public Layer
{
public:
enum
{
CHANNEL_NRM,
SPATIAL_NRM
};
int type;
int size;
double alpha, beta;
};
//! @}
//! @}
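The new create() factories are the first public construction points for these layers; together with the fields exposed on BaseConvolutionLayer and LRNLayer they let user code build layers without going through a Caffe importer. A minimal usage sketch (the weights/bias Blob variables and their shapes are illustrative assumptions, not part of this commit):

// Sketch: build a 3x3 convolution with pad 1 and stride 1, then attach weights.
Ptr<BaseConvolutionLayer> conv = ConvolutionLayer::create(Size(3, 3), Size(1, 1), Size(1, 1));
conv->blobs.push_back(weights); // assumed 4D blob: numOutput x (inpCn/group) x 3 x 3
conv->blobs.push_back(bias);    // assumed blob with numOutput elements (bias term)
// allocate()/forward() are then driven by Net as for any other Layer.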

@@ -120,7 +120,8 @@ namespace dnn //! This namespace is used for dnn module functionality.
String type; //!< Type name which was used for creating layer by layer factory.
Layer();
explicit Layer(const LayerParams &params); //!< Initializes only #name, #type and #blobs fields.
void setParamsFrom(const LayerParams &params); //!< Initializes only #name, #type and #blobs fields.
virtual ~Layer();
};

@@ -48,7 +48,10 @@
namespace cv {
namespace dnn {
std::ostream &operator<< (std::ostream &s, cv::Range &r)
//Useful shortcut
typedef BlobShape Shape;
inline std::ostream &operator<< (std::ostream &s, cv::Range &r)
{
return s << "[" << r.start << ", " << r.end << ")";
}
@@ -96,8 +99,6 @@ Mat slice(const Mat &m, const _Range &r0, const _Range &r1)
ranges[i] = Range::all();
ranges[0] = r0;
ranges[1] = r1;
// for (int i = 0; i < m.dims; i++)
// std::cout << ranges[i] << "\n";
return m(&ranges[0]);
}
@@ -128,8 +129,32 @@ Mat slice(const Mat &m, const _Range &r0, const _Range &r1, const _Range &r2, co
return m(&ranges[0]);
}
}
//Traits for switching in polymorphic implementations
template<typename XMat>
struct MatTraits
{
};
}
template<>
struct MatTraits<cv::Mat>
{
enum
{
IS_MAT = 1,
IS_UMAT = 0,
};
};
template<>
struct MatTraits<cv::UMat>
{
enum
{
IS_MAT = 0,
IS_UMAT = 1,
};
};
}
}
#endif
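MatTraits gives templated code a compile-time flag for whether XMat is a Mat or a UMat, which spatialNormalization_ below uses to pick a CPU workaround path. A minimal sketch of the dispatch pattern, assuming the traits above:

// Sketch: the branch condition is a compile-time constant, so the dead path
// is trivially optimized out and no RTTI is needed.
template<typename XMat>
void processPlane(XMat &plane)
{
    if (MatTraits<XMat>::IS_UMAT)
    {
        // OpenCL-backed path (keep data on the device)
    }
    else
    {
        // plain cv::Mat path (may use CPU-only workarounds)
    }
}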

@@ -543,6 +543,13 @@ Layer::Layer(const LayerParams &params)
}
void Layer::setParamsFrom(const LayerParams &params)
{
blobs = params.blobs;
name = params.name;
type = params.type;
}
int Layer::inputNameToIndex(String)
{
return -1;

@@ -81,9 +81,9 @@ void initModule()
REG_RUNTIME_LAYER_CLASS(Split, SplitLayer)
REG_RUNTIME_LAYER_CLASS(Reshape, ReshapeLayer)
REG_STATIC_LAYER_FUNC(Flatten, createFlattenLayer)
REG_RUNTIME_LAYER_CLASS(Pooling, PoolingLayer)
REG_RUNTIME_LAYER_CLASS(Pooling, PoolingLayerImpl)
REG_RUNTIME_LAYER_CLASS(MVN, MVNLayer)
REG_RUNTIME_LAYER_CLASS(LRN, LRNLayer)
REG_RUNTIME_LAYER_FUNC(LRN, createLRNLayerFromCaffe)
REG_RUNTIME_LAYER_CLASS(InnerProduct, FullyConnectedLayer)
REG_RUNTIME_LAYER_CLASS(ReLU, ElementWiseLayer<ReLUFunctor>)
@@ -94,8 +94,8 @@ void initModule()
REG_RUNTIME_LAYER_CLASS(Sigmoid, ElementWiseLayer<SigmoidFunctor>)
REG_RUNTIME_LAYER_CLASS(Dropout, BlankLayer)
REG_RUNTIME_LAYER_CLASS(Convolution, ConvolutionLayer)
REG_RUNTIME_LAYER_CLASS(Deconvolution, DeConvolutionLayer)
REG_RUNTIME_LAYER_FUNC(Convolution, createConvolutionLayerFromCaffe)
REG_RUNTIME_LAYER_FUNC(Deconvolution, createDeconvolutionLayerFromCaffe)
REG_RUNTIME_LAYER_CLASS(Concat, ConcatLayer)
init.status = true;
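The switch from REG_RUNTIME_LAYER_CLASS to REG_RUNTIME_LAYER_FUNC replaces "construct the class from LayerParams" with "call a free importer function", so Caffe-specific parsing moves out of the constructors and into createConvolutionLayerFromCaffe / createLRNLayerFromCaffe below. Assumed semantics of the two macros (a sketch; the actual macro bodies are not shown in this commit):

// REG_RUNTIME_LAYER_CLASS(LRN, LRNLayer)
//   ~ registers a creator equivalent to: Ptr<Layer>(new LRNLayer(params))
// REG_RUNTIME_LAYER_FUNC(LRN, createLRNLayerFromCaffe)
//   ~ registers the function itself, which parses LayerParams, fills the
//     layer's public fields, and returns the ready Ptr<Layer>.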

@@ -55,26 +55,11 @@ namespace dnn
typedef BlobShape Shape;
ConvolutionLayer::ConvolutionLayer(LayerParams &params) : Layer(params)
ConvolutionLayerImpl::ConvolutionLayerImpl()
{
getKernelParams(params, kerH, kerW, padH, padW, strideH, strideW);
numOutput = params.get<int>("num_output");
bias = params.get<bool>("bias_term", true);
group = params.get<int>("group", 1);
CV_Assert(numOutput % group == 0);
CV_Assert(!bias || blobs.size() == 2);
CV_Assert( bias || blobs.size() == 1);
const Blob &wgtBlob = blobs[0];
CV_Assert(wgtBlob.dims() == 4 && wgtBlob.cols() == kerW && wgtBlob.rows() == kerH);
if (bias)
{
Blob &biasBlob = blobs[1];
CV_Assert(biasBlob.total() == (size_t)numOutput);
}
tryUseOpenCL = true;
numOutput = -1;
group = -1;
#if HAVE_CBLAS
if (getBlasThreads() != cv::getThreadNum())
@@ -82,57 +67,71 @@ ConvolutionLayer::ConvolutionLayer(LayerParams &params) : Layer(params)
setBlasThreads(cv::getThreadNum());
}
#endif
}
tryUseOpenCL = true;
void ConvolutionLayerImpl::init()
{
CV_Assert(1 <= blobs.size() && blobs.size() <= 2);
bias = (blobs.size() >= 2);
numOutput = blobs[0].num();
CV_Assert(blobs[0].dims() == 4 && blobs[0].cols() == kernel.width && blobs[0].rows() == kernel.height);
CV_Assert(!bias || blobs[1].total() == (size_t)blobs[0].num());
useOpenCL = ocl::useOpenCL() && tryUseOpenCL;
}
void ConvolutionLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
void ConvolutionLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() > 0);
const Blob &input = *inputs[0];
CV_Assert(input.dims() == 4 && (input.type() == CV_32F || input.type() == CV_64F));
computeInpOutShape(input);
const Blob &inpBlob = *inputs[0];
CV_Assert(inpBlob.dims() == 4 && inpBlob.type() == CV_32F);
computeInpOutShape(inpBlob);
group = inpCn / blobs[0].channels();
CV_Assert(inpCn % group == 0 && outCn % group == 0);
CV_Assert(blobs[0].num() == outCn && blobs[0].channels() == inpCn / group);
outGroupCn = outCn / group;
inpGroupCn = inpCn / group;
ksize = inpGroupCn * kerH * kerW;
ksize = inpGroupCn * kernel.height * kernel.width;
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
CV_Assert(inputs[i]->type() == inpBlob.type());
CV_Assert(inputs[i]->dims() == 4 && inputs[i]->channels() == inpBlob.channels());
CV_Assert(inputs[i]->rows() == inpBlob.rows() && inputs[i]->cols() == inpBlob.cols());
outputs[i].create(Shape(inputs[i]->num(), topCn, topH, topW));
CV_Assert(inputs[i]->type() == input.type());
CV_Assert(inputs[i]->dims() == 4 && inputs[i]->channels() == input.channels());
CV_Assert(inputs[i]->rows() == input.rows() && inputs[i]->cols() == input.cols());
}
useOpenCL = ocl::useOpenCL() && tryUseOpenCL;
int allocFlags = useOpenCL ? Blob::ALLOC_BOTH : Blob::ALLOC_MAT;
if (!is1x1())
{
colBlob.create(Shape(ksize, outH * outW), inpBlob.type(), allocFlags);
colBlob.create(Shape(ksize, outH * outW), input.type(), allocFlags);
}
if (bias)
{
biasOnesBlob.create(Shape(1, topH * topW), inpBlob.type(), allocFlags);
biasOnesBlob.create(Shape(1, topH * topW), input.type(), allocFlags);
biasOnesBlob.matRef().setTo(1);
}
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
outputs[i].create(Shape(inputs[i]->num(), topCn, topH, topW));
}
}
inline bool ConvolutionLayer::is1x1() const
bool ConvolutionLayerImpl::is1x1() const
{
return (kerH == 1 && kerW == 1) && (strideW == 1 && strideH == 1); //hotfix with stride
return (kernel.height == 1 && kernel.width == 1) &&
(stride.height == 1 && stride.width == 1);
}
template<typename XMat>
void ConvolutionLayer::forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
void ConvolutionLayerImpl::forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
XMat weightsMat = reshaped(blobs[0].getRefConst<XMat>(), Shape(outCn, ksize));
XMat biasesMat = reshaped(blobs[1].getRefConst<XMat>(), Shape(outCn, 1));
@@ -167,7 +166,7 @@ void ConvolutionLayer::forward_(std::vector<Blob*> &inputs, std::vector<Blob> &o
}
}
void ConvolutionLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
void ConvolutionLayerImpl::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
if (!useOpenCL)
forward_<Mat>(inputs, outputs);
@@ -175,7 +174,7 @@ void ConvolutionLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &ou
forward_<UMat>(inputs, outputs);
}
void ConvolutionLayer::im2col(const UMat &srcImg, UMat &dstCol)
void ConvolutionLayerImpl::im2col(const UMat &srcImg, UMat &dstCol)
{
if (is1x1())
{
@@ -183,7 +182,7 @@ void ConvolutionLayer::im2col(const UMat &srcImg, UMat &dstCol)
return;
}
#ifdef HAVE_OPENCL
CV_Assert(im2col_ocl(srcImg, inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, this->colBlob.umatRef()));
CV_Assert(im2col_ocl(srcImg, inpGroupCn, inpH, inpW, kernel.height, kernel.width, pad.height, pad.width, stride.height, stride.width, this->colBlob.umatRef()));
dstCol = this->colBlob.umatRefConst();
#else
CV_Error(Error::StsInternal, "");
@@ -191,7 +190,7 @@ void ConvolutionLayer::im2col(const UMat &srcImg, UMat &dstCol)
#endif
}
void ConvolutionLayer::im2col(const Mat &srcImg, Mat &dstCol)
void ConvolutionLayerImpl::im2col(const Mat &srcImg, Mat &dstCol)
{
if (is1x1())
{
@@ -201,43 +200,47 @@ void ConvolutionLayer::im2col(const Mat &srcImg, Mat &dstCol)
Mat &colMat = colBlob.matRef();
if (srcImg.type() == CV_32F)
im2col_CpuPBody<float>::run(srcImg.ptr<float>(), inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, colMat.ptr<float>());
im2col_CpuPBody<float>::run(srcImg.ptr<float>(), inpGroupCn, inpH, inpW, kernel.height, kernel.width, pad.height, pad.width, stride.height, stride.width, colMat.ptr<float>());
if (srcImg.type() == CV_64F)
im2col_CpuPBody<double>::run(srcImg.ptr<double>(), inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, colMat.ptr<double>());
im2col_CpuPBody<double>::run(srcImg.ptr<double>(), inpGroupCn, inpH, inpW, kernel.height, kernel.width, pad.height, pad.width, stride.height, stride.width, colMat.ptr<double>());
dstCol = colMat;
}
void ConvolutionLayer::computeInpOutShape(const Blob &inpBlob)
void ConvolutionLayerImpl::computeInpOutShape(const Blob &input)
{
inpH = inpBlob.rows();
inpW = inpBlob.cols();
inpCn = inpBlob.channels();
inpH = input.rows();
inpW = input.cols();
inpCn = input.channels();
outH = (inpH + 2 * padH - kerH) / strideH + 1;
outW = (inpW + 2 * padW - kerW) / strideW + 1;
outH = (inpH + 2 * pad.height - kernel.height) / stride.height + 1;
outW = (inpW + 2 * pad.width - kernel.width) / stride.width + 1;
outCn = numOutput;
topH = outH; topW = outW; topCn = outCn;
}
DeConvolutionLayer::DeConvolutionLayer(LayerParams &params)
: ConvolutionLayer(params) {}
//Deconvolution
void DeConvolutionLayer::computeInpOutShape(const Blob &inpBlob)
DeConvolutionLayerImpl::DeConvolutionLayerImpl()
{
}
void DeConvolutionLayerImpl::computeInpOutShape(const Blob &inpBlob)
{
outH = inpBlob.rows();
outW = inpBlob.cols();
outCn = inpBlob.channels();
inpH = strideH * (outH - 1) + kerH - 2 * padH;
inpW = strideW * (outW - 1) + kerW - 2 * padW;
inpH = stride.height * (outH - 1) + kernel.height - 2 * pad.height;
inpW = stride.width * (outW - 1) + kernel.width - 2 * pad.width;
inpCn = numOutput;
topH = inpH; topW = inpW; topCn = inpCn;
}
void DeConvolutionLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
void DeConvolutionLayerImpl::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
if (!useOpenCL)
forward_<Mat>(inputs, outputs);
@@ -246,7 +249,7 @@ void DeConvolutionLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &
}
template<typename XMat>
void DeConvolutionLayer::forward_(std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
void DeConvolutionLayerImpl::forward_(std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
{
XMat weightsMat = reshaped(blobs[0].getRefConst<XMat>(), Shape(outCn, ksize));
XMat biasesMat = reshaped(blobs[1].getRefConst<XMat>(), Shape(outCn, 1));
@@ -282,7 +285,7 @@ void DeConvolutionLayer::forward_(std::vector<Blob *> &inputs, std::vector<Blob>
}
}
void DeConvolutionLayer::col2im(const Mat &colMat, Mat &dstImg)
void DeConvolutionLayerImpl::col2im(const Mat &colMat, Mat &dstImg)
{
if (is1x1())
{
@@ -290,12 +293,12 @@ void DeConvolutionLayer::col2im(const Mat &colMat, Mat &dstImg)
return;
}
if (dstImg.type() == CV_32F)
col2im_CpuPBody<float>::run(colMat.ptr<float>(), inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, dstImg.ptr<float>());
col2im_CpuPBody<float>::run(colMat.ptr<float>(), inpGroupCn, inpH, inpW, kernel.height, kernel.width, pad.height, pad.width, stride.height, stride.width, dstImg.ptr<float>());
if (dstImg.type() == CV_64F)
col2im_CpuPBody<double>::run(colMat.ptr<double>(), inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, dstImg.ptr<double>());
col2im_CpuPBody<double>::run(colMat.ptr<double>(), inpGroupCn, inpH, inpW, kernel.height, kernel.width, pad.height, pad.width, stride.height, stride.width, dstImg.ptr<double>());
}
void DeConvolutionLayer::col2im(const UMat &colMat, UMat &dstImg)
void DeConvolutionLayerImpl::col2im(const UMat &colMat, UMat &dstImg)
{
if (is1x1())
{
@@ -303,12 +306,74 @@ void DeConvolutionLayer::col2im(const UMat &colMat, UMat &dstImg)
return;
}
#ifdef HAVE_OPENCL
CV_Assert(col2im_ocl(colMat, inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, dstImg));
CV_Assert(col2im_ocl(colMat, inpGroupCn, inpH, inpW, kernel.height, kernel.width, pad.height, pad.width, stride.height, stride.width, dstImg));
#else
CV_Error(Error::StsInternal, "");
dstImg = colMat;
#endif
}
//Initializers
Ptr<BaseConvolutionLayer> ConvolutionLayer::create()
{
return Ptr<BaseConvolutionLayer>(new ConvolutionLayerImpl());
}
Ptr<BaseConvolutionLayer> ConvolutionLayer::create(Size kernel, Size pad, Size stride)
{
ConvolutionLayerImpl *l = new ConvolutionLayerImpl();
l->kernel = kernel;
l->pad = pad;
l->stride = stride;
return Ptr<BaseConvolutionLayer>(l);
}
Ptr<BaseConvolutionLayer> DeconvolutionLayer::create()
{
return Ptr<BaseConvolutionLayer>(new DeConvolutionLayerImpl());
}
Ptr<BaseConvolutionLayer> DeconvolutionLayer::create(Size kernel, Size pad, Size stride)
{
DeConvolutionLayerImpl *l = new DeConvolutionLayerImpl();
l->kernel = kernel;
l->pad = pad;
l->stride = stride;
return Ptr<BaseConvolutionLayer>(l);
}
//Importers
template<typename CLayer>
static void initConvDeconvLayerFromCaffe(CLayer *l, LayerParams &params)
{
l->setParamsFrom(params);
getCaffeConvParams(params, l->kernel, l->pad, l->stride);
bool bias = params.get<bool>("bias_term", true);
int numOutput = params.get<int>("num_output");
int group = params.get<int>("group", 1);
CV_Assert(numOutput % group == 0);
CV_Assert((bias && l->blobs.size() == 2) || (!bias && l->blobs.size() == 1));
}
Ptr<Layer> createConvolutionLayerFromCaffe(LayerParams &params)
{
ConvolutionLayerImpl *l = new ConvolutionLayerImpl();
initConvDeconvLayerFromCaffe(l, params);
l->init();
return Ptr<Layer>(l);
}
Ptr<Layer> createDeconvolutionLayerFromCaffe(LayerParams &params)
{
ConvolutionLayerImpl *l = new DeConvolutionLayerImpl();
initConvDeconvLayerFromCaffe(l, params);
l->init();
return Ptr<Layer>(l);
}
}
}
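A worked example of computeInpOutShape and the group inference in allocate, with illustrative numbers (not taken from the tests):

// input blob 1 x 96 x 27 x 27, weights blob 256 x 48 x 5 x 5, pad 2, stride 1:
//   group      = inpCn / blobs[0].channels() = 96 / 48 = 2
//   inpGroupCn = 96 / 2 = 48,  outGroupCn = 256 / 2 = 128
//   ksize      = inpGroupCn * kernel.height * kernel.width = 48 * 5 * 5 = 1200
//   outH = (27 + 2*2 - 5) / 1 + 1 = 27  ->  output blob 1 x 256 x 27 x 27
// Deconvolution inverts the spatial relation instead:
//   inpH = stride.height * (outH - 1) + kernel.height - 2 * pad.height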

@@ -42,61 +42,65 @@
#ifndef __OPENCV_DNN_LAYERS_CONVOLUTION_LAYER_HPP__
#define __OPENCV_DNN_LAYERS_CONVOLUTION_LAYER_HPP__
#include "../precomp.hpp"
#include <opencv2/dnn/all_layers.hpp>
namespace cv
{
namespace dnn
{
//TODO: simultaneously convolution and bias addition for cache optimization
class ConvolutionLayer : public Layer
{
protected:
bool bias;
int numOutput, group;
int padH, padW;
int kerH, kerW;
int strideH, strideW;
int inpH, inpW, inpCn;
int outH, outW, outCn;
int topH, topW, topCn; //switched between inp/out on deconv/conv
int inpGroupCn, outGroupCn;
int ksize;
//TODO: simultaneously convolution and bias addition for cache optimization
class ConvolutionLayerImpl : public ConvolutionLayer
{
public:
ConvolutionLayerImpl();
virtual void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
virtual void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
virtual void init();
protected:
int numOutput, group;
int inpH, inpW, inpCn;
int outH, outW, outCn;
int topH, topW, topCn; //switched between inp/out on deconv/conv
int inpGroupCn, outGroupCn;
int ksize;
bool tryUseOpenCL, useOpenCL;
bool bias;
bool tryUseOpenCL, useOpenCL;
Blob colBlob, biasOnesBlob;
Blob colBlob, biasOnesBlob;
inline bool is1x1() const;
virtual void computeInpOutShape(const Blob &inpBlob);
bool is1x1() const;
virtual void computeInpOutShape(const Blob &inpBlob);
void im2col(const Mat &srcImg, Mat &dstCol);
void im2col(const UMat &srcImg, UMat &dstCol);
template<typename XMat>
void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void im2col(const Mat &srcImg, Mat &dstCol);
void im2col(const UMat &srcImg, UMat &dstCol);
};
class DeConvolutionLayerImpl : public ConvolutionLayerImpl
{
public:
DeConvolutionLayerImpl();
virtual void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
public:
ConvolutionLayer() {}
ConvolutionLayer(LayerParams &params);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
protected:
template<typename XMat>
void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
virtual void computeInpOutShape(const Blob &inpBlob);
class DeConvolutionLayer : public ConvolutionLayer
{
protected:
void computeInpOutShape(const Blob &inpBlob);
void col2im(const Mat &colMat, Mat &dstImg);
void col2im(const UMat &colMat, UMat &dstImg);
template<typename XMat>
void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void col2im(const Mat &colMat, Mat &dstImg);
void col2im(const UMat &colMat, UMat &dstImg);
};
public:
DeConvolutionLayer(LayerParams &params);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
//Importers
Ptr<Layer> createConvolutionLayerFromCaffe(LayerParams &params);
Ptr<Layer> createDeconvolutionLayerFromCaffe(LayerParams &params);
template<typename XMat>
void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
}
}
#endif

@@ -46,43 +46,44 @@ namespace cv
namespace dnn
{
void getKernelParams(LayerParams &params, int &kernelH, int &kernelW, int &padH, int &padW, int &strideH, int &strideW)
void getCaffeConvParams(LayerParams &params, Size &kernel, Size &pad, Size &stride)
{
if (params.has("kernel_h") && params.has("kernel_w"))
{
kernelH = params.get<int>("kernel_h");
kernelW = params.get<int>("kernel_w");
kernel.height = params.get<int>("kernel_h");
kernel.width = params.get<int>("kernel_w");
}
else if (params.has("kernel_size"))
{
kernelH = kernelW = params.get<int>("kernel_size");
kernel.height = kernel.width = params.get<int>("kernel_size");
}
else
{
CV_Error(cv::Error::StsBadArg, "kernel_size (or kernel_h and kernel_w) not specified");
CV_Error(Error::StsBadArg, "kernel_size (or kernel_h and kernel_w) not specified");
}
CV_Assert(kernel.height > 0 && kernel.width > 0);
if (params.has("pad_h") && params.has("pad_w"))
{
padH = params.get<int>("pad_h");
padW = params.get<int>("pad_w");
pad.height = params.get<int>("pad_h");
pad.width = params.get<int>("pad_w");
}
else
{
padH = padW = params.get<int>("pad", 0);
pad.height = pad.width = params.get<int>("pad", 0);
}
CV_Assert(pad.height >= 0 && pad.width >= 0);
if (params.has("stride_h") && params.has("stride_w"))
{
strideH = params.get<int>("stride_h");
strideW = params.get<int>("stride_w");
stride.height = params.get<int>("stride_h");
stride.width = params.get<int>("stride_w");
}
else
{
strideH = strideW = params.get<int>("stride", 1);
stride.height = stride.width = params.get<int>("stride", 1);
}
CV_Assert(kernelH > 0 && kernelW > 0 && padH >= 0 && padW >= 0 && strideH > 0 && strideW > 0);
CV_Assert(stride.height > 0 && stride.width > 0);
}
}
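For reference, how typical Caffe parameter sets map through getCaffeConvParams (prototxt values are illustrative; note cv::Size is (width, height)):

// kernel_size: 5   pad: 2   stride: 1
//   -> kernel = Size(5, 5), pad = Size(2, 2), stride = Size(1, 1)
// kernel_h: 3  kernel_w: 1   (no pad/stride keys, so defaults apply)
//   -> kernel = Size(1, 3), pad = Size(0, 0), stride = Size(1, 1)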

@@ -48,7 +48,7 @@ namespace cv
namespace dnn
{
void getKernelParams(LayerParams &params, int &kernelH, int &kernelW, int &padH, int &padW, int &strideH, int &strideW);
void getCaffeConvParams(LayerParams &params, Size &kernel, Size &pad, Size &stride);
}
}

@@ -42,123 +42,218 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "lrn_layer.hpp"
#include "opencl_kernels_dnn.hpp"
#include <opencv2/imgproc.hpp>
#include <opencv2/core/ocl.hpp>
#include <opencv2/dnn/shape_utils.hpp>
#include <algorithm>
namespace cv
{
namespace dnn
{
LRNLayer::LRNLayer(LayerParams &params) : Layer(params)
{
String nrmType = params.get<String>("norm_region", "ACROSS_CHANNELS");
if (nrmType == "ACROSS_CHANNELS")
type = CHANNEL_NRM;
else if (nrmType == "WITHIN_CHANNEL")
type = SPATIAL_NRM;
else
CV_Error(Error::StsBadArg, "Unknown region type \"" + nrmType + "\"");
size = params.get<int>("local_size", 5);
if (size % 2 != 1 || size <= 0)
CV_Error(Error::StsBadArg, "LRN layer supports only positive odd values for local_size");
alpha = params.get<double>("alpha", 1);
beta = params.get<double>("beta", 0.75);
}
void LRNLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1);
outputs.resize(1);
LRNLayerImpl::LRNLayerImpl()
{
size = 5;
alpha = 1;
beta = 0.75;
type = CHANNEL_NRM;
}
Vec4i shape = inputs[0]->shape4();
outputs[0].create(shape);
void LRNLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1 && inputs[0]->dims() == 4);
useOpenCL = cv::ocl::useOpenCL();
shape[0] = 1; //maybe make shape[0] = 1 too
bufBlob.create(shape);
}
if (type == SPATIAL_NRM && !useOpenCL)
buf.create(inputs[0]->shape().slice(2), inputs[0]->type(), Blob::ALLOC_MAT);
if (type == CHANNEL_NRM && useOpenCL)
buf.create(inputs[0]->shape().slice(2), inputs[0]->type(), Blob::ALLOC_UMAT);
void LRNLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
outputs.resize(1);
outputs[0].create(inputs[0]->shape(), inputs[0]->type());
}
void LRNLayerImpl::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
Blob &src = *inputs[0];
Blob &dst = outputs[0];
switch (type)
{
Blob &src = *inputs[0];
Blob &dst = outputs[0];
case CHANNEL_NRM:
channelNoramlization(src, dst);
break;
case SPATIAL_NRM:
spatialNormalization(src, dst);
break;
default:
CV_Error(Error::StsNotImplemented, "Unimplemented mode of LRN layer");
break;
}
}
switch (type)
{
case CHANNEL_NRM:
channelNoramlization(src, dst);
break;
case SPATIAL_NRM:
spatialNormalization(src, dst);
break;
default:
CV_Error(cv::Error::StsNotImplemented, "Unimplemented mode of LRN layer");
break;
}
template<typename XMat>
static XMat getPlane(XMat &m, int n, int cn)
{
return reshaped(slice(m, n, cn), BlobShape::like(m).slice(2));
}
void LRNLayerImpl::channelNoramlization(Blob &src, Blob &dst)
{
if (!useOpenCL)
channelNoramlization_<Mat>(src, dst);
else
{
//channelNoramlization_ocl(src.getRefConst<UMat>(), dst.getRef<UMat>()); //consumes a lot of memory
channelNoramlization_<UMat>(src, dst);
}
}
template<typename XMat>
void LRNLayerImpl::channelNoramlization_(Blob &srcBlob, Blob &dstBlob)
{
int num = srcBlob.num();
int channels = srcBlob.channels();
int ksize = (size - 1) / 2;
XMat srcMat = srcBlob.getRefConst<XMat>();
XMat dstMat = dstBlob.getRef<XMat>();
void LRNLayer::channelNoramlization(Blob &srcBlob, Blob &dstBlob)
for (int n = 0; n < num; n++)
{
CV_DbgAssert(srcBlob.ptr() != dstBlob.ptr());
XMat accum = getPlane(dstMat, n, channels-1); //trick for memory saving
accum.setTo(0);
int num = srcBlob.num();
int channels = srcBlob.channels();
int ksize = (size - 1) / 2;
for (int cn = 0; cn < std::min(ksize, channels); cn++)
cv::accumulateSquare(getPlane(srcMat, n, cn), accum);
for (int n = 0; n < num; n++)
for (int cn = 0; cn < channels; cn++)
{
Mat accum = dstBlob.getPlane(n, channels-1); //trick for memory saving
accum.setTo(0);
for (int cn = 0; cn < std::min(ksize, channels); cn++)
cv::accumulateSquare(srcBlob.getPlane(n, cn), accum);
if (cn + ksize < channels)
{
cv::accumulateSquare(getPlane(srcMat, n, cn + ksize), accum);
}
for (int cn = 0; cn < channels; cn++)
if (cn - ksize - 1 >= 0)
{
if (cn + ksize < channels)
{
cv::accumulateSquare(srcBlob.getPlane(n, cn + ksize), accum);
}
if (cn - ksize - 1 >= 0)
{
Mat left = srcBlob.getPlane(n, cn - ksize - 1);
cv::subtract(accum, left.mul(left), accum); //subtractSquare
}
Mat dst = dstBlob.getPlane(n, cn);
accum.convertTo(dst, dst.type(), alpha/size, 1);
cv::pow(dst, beta, dst);
cv::divide(srcBlob.getPlane(n, cn), dst, dst);
//subtractSquare
XMat left = getPlane(srcMat, n, cn - ksize - 1);
cv::pow(left, 2, left);
cv::subtract(accum, left, accum);
}
XMat dst = getPlane(dstMat, n, cn);
accum.convertTo(dst, dst.type(), alpha/size, 1);
cv::pow(dst, beta, dst);
cv::divide(getPlane(srcMat, n, cn), dst, dst);
}
}
}
void LRNLayer::spatialNormalization(Blob &srcBlob, Blob &dstBlob)
{
int num = srcBlob.num();
int channels = srcBlob.channels();
bool LRNLayerImpl::channelNoramlization_ocl(const UMat &src, UMat &dst)
{
if (src.offset != 0 || dst.offset != 0) //TODO: add offset
return false;
String buildOpts = String("-DT=") + ocl::typeToStr(src.type());
ocl::Kernel kerScale("LRNFillScale", ocl::dnn::lrn_oclsrc, buildOpts);
if (kerScale.empty())
return false;
ocl::Kernel kerOutput("LRNComputeOutput", ocl::dnn::lrn_oclsrc, buildOpts);
if (kerOutput.empty())
return false;
Shape shape = Shape::like(src);
int ksize = (size - 1) / 2;
size_t wgSize = ocl::Device::getDefault().maxWorkGroupSize();
UMat &scaleBuf = buf.umatRef();
size_t nthreads = (size_t)(shape.total() / shape[1]);
kerScale.args((int)nthreads,
ocl::KernelArg::PtrReadOnly(src), shape[0], shape[1], shape[2], shape[3],
size, (float)(alpha/size), (float)ksize, ocl::KernelArg::PtrWriteOnly(scaleBuf));
if (!kerScale.run(1, &nthreads, &wgSize, true))
return false;
nthreads = (size_t)shape.total();
kerOutput.args((int)nthreads,
ocl::KernelArg::PtrReadOnly(src), ocl::KernelArg::PtrReadOnly(scaleBuf),
-beta, ocl::KernelArg::PtrWriteOnly(dst) );
if (!kerOutput.run(1, &nthreads, &wgSize, true))
return false;
return true;
}
void LRNLayerImpl::spatialNormalization(Blob &src, Blob &dst)
{
if (!useOpenCL)
spatialNormalization_<Mat>(src, dst);
else
spatialNormalization_<UMat>(src, dst);
}
template<typename XMat>
void LRNLayerImpl::spatialNormalization_(Blob &srcBlob, Blob &dstBlob)
{
int num = srcBlob.num();
int channels = srcBlob.channels();
for (int n = 0; n < num; n++)
XMat srcMat = srcBlob.getRefConst<XMat>();
XMat dstMat = dstBlob.getRef<XMat>();
for (int n = 0; n < num; n++)
{
for (int cn = 0; cn < channels; cn++)
{
for (int cn = 0; cn < channels; cn++)
XMat src = getPlane(srcMat, n, cn);
XMat dst = getPlane(dstMat, n, cn);
if (MatTraits<XMat>::IS_UMAT)
{
Mat src = srcBlob.getPlane(n, cn);
Mat dst = dstBlob.getPlane(n, cn);
uchar *dataDst0 = dst.data;
cv::pow(srcBlob.getPlane(n, cn), 2, dst);
//TODO: check border type
cv::boxFilter(dst, dst, dst.depth(), cv::Size(size, size), cv::Point(-1, -1), false, cv::BORDER_CONSTANT);
dst.convertTo(dst, dst.type(), alpha/(size*size), 1);
cv::pow(dst, beta, dst);
cv::divide(src, dst, dst);
CV_Assert(dataDst0 == dst.data); //debug
cv::sqrBoxFilter(src, dst, dst.depth(), Size(size, size), Point(-1, -1), false, BORDER_CONSTANT | BORDER_ISOLATED);
}
else
{
//TODO: fix cv::boxFilter with BORDER_ISOLATED flag in CPU mode
Mat bufMat = buf.getRef<Mat>();
src.copyTo(bufMat);
cv::sqrBoxFilter(bufMat, dst, dst.depth(), Size(size, size), Point(-1, -1), false, BORDER_CONSTANT);
}
dst.convertTo(dst, dst.type(), alpha/(size*size), 1);
cv::pow(dst, beta, dst);
cv::divide(src, dst, dst);
}
}
}
Ptr<Layer> createLRNLayerFromCaffe(LayerParams &params)
{
LRNLayerImpl *l = new LRNLayerImpl();
String nrmType = params.get<String>("norm_region", "ACROSS_CHANNELS");
if (nrmType == "ACROSS_CHANNELS")
l->type = LRNLayer::CHANNEL_NRM;
else if (nrmType == "WITHIN_CHANNEL")
l->type = LRNLayer::SPATIAL_NRM;
else
CV_Error(Error::StsBadArg, "Unknown region type \"" + nrmType + "\"");
int size = params.get<int>("local_size", 5);
if (size % 2 != 1 || size <= 0)
CV_Error(Error::StsBadArg, "LRN layer supports only positive odd values for local_size");
l->size = size;
l->alpha = params.get<double>("alpha", 1);
l->beta = params.get<double>("beta", 0.75);
return Ptr<Layer>(l);
}
}
}
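In formula form, the sliding-window code above implements Caffe's ACROSS_CHANNELS normalization (reconstructed from the code, with ksize = (size - 1) / 2):

// accum(c) = sum of src(j)^2 over j in [c - ksize, c + ksize], clipped to valid channels
// dst(c)   = src(c) / (1 + (alpha / size) * accum(c))^beta
// The loop maintains accum incrementally: add plane c + ksize, subtract plane
// c - ksize - 1, costing O(1) image operations per channel instead of O(size).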

@@ -42,34 +42,36 @@
#ifndef __OPENCV_DNN_LAYERS_LRN_LAYER_HPP__
#define __OPENCV_DNN_LAYERS_LRN_LAYER_HPP__
#include "../precomp.hpp"
#include <opencv2/dnn/all_layers.hpp>
namespace cv
{
namespace dnn
{
class LRNLayer : public Layer
class LRNLayerImpl : public LRNLayer
{
enum
{
CHANNEL_NRM,
SPATIAL_NRM,
SPATIAL_CONTRAST_NRM //cuda-convnet feature
} type;
int size;
double alpha, beta;
Blob bufBlob;
bool useOpenCL;
Blob buf;
void channelNoramlization(Blob &src, Blob &dst);
template<typename XMat>
void channelNoramlization_(Blob &src, Blob &dst);
bool channelNoramlization_ocl(const UMat &src, UMat &dst);
void spatialNormalization(Blob &src, Blob &dst);
template<typename XMat>
void spatialNormalization_(Blob &src, Blob &dst);
public:
LRNLayer(LayerParams &params);
LRNLayerImpl();
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
Ptr<Layer> createLRNLayerFromCaffe(LayerParams &params);
}
}
#endif

@@ -72,22 +72,21 @@ namespace dnn
type = MAX;
}
getKernelParams(params, kernelH, kernelW, padH, padW, strideH, strideW);
getCaffeConvParams(params, kernel, pad, stride);
}
void PoolingLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() > 0);
inpW = inputs[0]->cols();
inpH = inputs[0]->rows();
computeOutputShape(inpH, inpW);
inp = inputs[0]->size2();
computeOutputShape(inp);
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
CV_Assert(inputs[i]->rows() == inpH && inputs[i]->cols() == inpW);
outputs[i].create(BlobShape(inputs[i]->num(), inputs[i]->channels(), outH, outW));
CV_Assert(inputs[i]->rows() == inp.height && inputs[i]->cols() == inp.width);
outputs[i].create(BlobShape(inputs[i]->num(), inputs[i]->channels(), out.height, out.width));
}
}
@@ -104,7 +103,7 @@ namespace dnn
avePooling(*inputs[ii], outputs[ii]);
break;
default:
CV_Error(cv::Error::StsNotImplemented, "Not implemented");
CV_Error(Error::StsNotImplemented, "Not implemented");
break;
}
}
@@ -112,7 +111,7 @@ namespace dnn
void PoolingLayer::maxPooling(Blob &input, Blob &output)
{
CV_DbgAssert(output.rows() == outH && output.cols() == outW);
CV_DbgAssert(output.rows() == out.height && output.cols() == out.width);
for (int n = 0; n < input.num(); ++n)
{
@@ -121,23 +120,23 @@ namespace dnn
float *srcData = input.ptrf(n, c);
float *dstData = output.ptrf(n, c);
for (int ph = 0; ph < outH; ++ph)
for (int ph = 0; ph < out.height; ++ph)
{
for (int pw = 0; pw < outW; ++pw)
for (int pw = 0; pw < out.width; ++pw)
{
int hstart = ph * strideH - padH;
int wstart = pw * strideW - padW;
int hend = min(hstart + kernelH, inpH);
int wend = min(wstart + kernelW, inpW);
int hstart = ph * stride.height - pad.height;
int wstart = pw * stride.width - pad.width;
int hend = min(hstart + kernel.height, inp.height);
int wend = min(wstart + kernel.width, inp.width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
const int poolIndex = ph * outW + pw;
const int poolIndex = ph * out.width + pw;
float max_val = -FLT_MAX;
for (int h = hstart; h < hend; ++h)
for (int w = wstart; w < wend; ++w)
{
const int index = h * inpW + w;
const int index = h * inp.width + w;
if (srcData[index] > max_val)
max_val = srcData[index];
}
@@ -158,49 +157,49 @@ namespace dnn
float *srcData = input.ptrf(n, c);
float *dstData = output.ptrf(n, c);
for (int ph = 0; ph < outH; ++ph)
for (int ph = 0; ph < out.height; ++ph)
{
for (int pw = 0; pw < outW; ++pw)
for (int pw = 0; pw < out.width; ++pw)
{
int hstart = ph * strideH - padH;
int wstart = pw * strideW - padW;
int hend = min(hstart + kernelH, inpH + padH);
int wend = min(wstart + kernelW, inpW + padW);
int hstart = ph * stride.height - pad.height;
int wstart = pw * stride.width - pad.width;
int hend = min(hstart + kernel.height, inp.height + pad.height);
int wend = min(wstart + kernel.width, inp.width + pad.width);
int poolSize = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
hend = min(hend, inpH);
wend = min(wend, inpW);
hend = min(hend, inp.height);
wend = min(wend, inp.width);
dstData[ph * outW + pw] = 0.f;
dstData[ph * out.width + pw] = 0.f;
for (int h = hstart; h < hend; ++h)
for (int w = wstart; w < wend; ++w)
dstData[ph * outW + pw] += srcData[h * inpW + w];
dstData[ph * out.width + pw] += srcData[h * inp.width + w];
dstData[ph * outW + pw] /= poolSize;
dstData[ph * out.width + pw] /= poolSize;
}
}
}
}
}
void PoolingLayer::computeOutputShape(int inH, int inW)
void PoolingLayer::computeOutputShape(Size inpSz)
{
//Yeah, something strange Caffe scheme-)
outH = static_cast<int>(ceil(static_cast<float>(inH + 2 * padH - kernelH) / strideH)) + 1;
outW = static_cast<int>(ceil(static_cast<float>(inW + 2 * padW - kernelW) / strideW)) + 1;
out.height = static_cast<int>(ceil(static_cast<float>(inpSz.height + 2 * pad.height - kernel.height) / stride.height)) + 1;
out.width = static_cast<int>(ceil(static_cast<float>(inpSz.width + 2 * pad.width - kernel.width) / stride.width)) + 1;
if (padH || padW)
if (pad.height || pad.width)
{
// If we have padding, ensure that the last pooling starts strictly
// inside the image (instead of at the padding); otherwise clip the last.
if ((outH - 1) * strideH >= inH + padH)
--outH;
if ((outW - 1) * strideW >= inW + padW)
--outW;
CV_Assert((outH - 1) * strideH < inH + padH);
CV_Assert((outW - 1) * strideW < inW + padW);
if ((out.height - 1) * stride.height >= inpSz.height + pad.height)
--out.height;
if ((out.width - 1) * stride.width >= inpSz.width + pad.width)
--out.width;
CV_Assert((out.height - 1) * stride.height < inpSz.height + pad.height);
CV_Assert((out.width - 1) * stride.width < inpSz.width + pad.width);
}
}
}
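Pooling keeps Caffe's ceil-based rounding (the convolution shape above uses floor), plus the clipping rule for padded inputs; a quick worked example:

// inp 13x13, kernel 3x3, stride 2, pad 0:
//   out.height = ceil((13 + 0 - 3) / 2.0) + 1 = 6      (same for out.width)
// inp 13x13, kernel 3x3, stride 2, pad 1:
//   out.height = ceil((13 + 2 - 3) / 2.0) + 1 = 7
//   clip check: (7 - 1) * 2 = 12 < 13 + 1, so the last window already starts
//   inside the image and out.height stays 7.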

@@ -57,14 +57,10 @@ namespace dnn
};
int type;
int padH, padW;
int strideH, strideW;
int kernelH, kernelW;
Size kernel, pad, stride;
Size inp, out;
int inpH, inpW;
int outH, outW;
void computeOutputShape(int inpH, int inpW);
void computeOutputShape(Size inpSz);
void maxPooling(Blob &input, Blob &output);
void avePooling(Blob &input, Blob &output);

@@ -0,0 +1,76 @@
/*************************************************************************************
* Copyright (c) 2015, Advanced Micro Devices, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation and/or
* other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
* IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
* INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
* OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
**************************************************************************************/
__kernel void LRNComputeOutput(const int nthreads, __global T* in, __global T* scale, const T negative_beta, __global T* out) {
int index = get_global_id(0);
int tmp = get_global_size(0);
for(index; index < nthreads; index += tmp)
out[index] = in[index] * pow(scale[index], negative_beta);
}
__kernel void LRNFillScale(const int nthreads, __global T* in, const int num, const int channels, const int height, const int width, const int size, const T alpha_over_size, const T k, __global T* scale) {
int index = get_global_id(0);
int tmp = get_global_size(0);
for(index; index < nthreads; index += tmp) {
// find out the local offset
const int w = index % width;
const int h = (index / width) % height;
const int n = index / width / height;
const int offset = (n * channels * height + h) * width + w;
const int step = height * width;
in = in + offset;
scale = scale + offset;
int head = 0;
const int pre_pad = (size - 1) / 2;
const int post_pad = size - pre_pad - 1;
T accum_scale = 0;
// fill the scale at [n, :, h, w]
// accumulate values
while (head < post_pad && head < channels) {
accum_scale += in[head * step] * in[head * step];
++head;
}
// both add and subtract
while (head < channels) {
accum_scale += in[head * step] * in[head * step];
if (head - size >= 0) {
accum_scale -= in[(head - size) * step]
* in[(head - size) * step];
}
scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
++head;
}
// subtract only
while (head < channels + post_pad) {
if (head - size >= 0) {
accum_scale -= in[(head - size) * step]
* in[(head - size) * step];
}
scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
++head;
}
}
}
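Taken together with the host code in channelNoramlization_ocl, the two kernels compute (pre_pad = (size - 1) / 2, post_pad = size - pre_pad - 1):

// LRNFillScale:     scale[n,c,h,w] = k + alpha_over_size * sum of in[n,j,h,w]^2
//                   over j in [c - pre_pad, c + post_pad], kept as a running sum
// LRNComputeOutput: out[i] = in[i] * pow(scale[i], negative_beta), negative_beta = -beta
// Each work-item strides by get_global_size(0), so the grid-stride loop stays
// correct for any launch size the host picks.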

@@ -97,12 +97,22 @@ OCL_TEST(Layer_Test_Softmax, Accuracy)
TEST(Layer_Test_LRN_spatial, Accuracy)
{
testLayerUsingCaffeModels("layer_lrn_spatial");
OCL_OFF(testLayerUsingCaffeModels("layer_lrn_spatial"));
}
OCL_TEST(Layer_Test_LRN_spatial, Accuracy)
{
OCL_ON(testLayerUsingCaffeModels("layer_lrn_spatial"));
OCL_OFF();
}
TEST(Layer_Test_LRN_channels, Accuracy)
{
testLayerUsingCaffeModels("layer_lrn_channels");
OCL_OFF(testLayerUsingCaffeModels("layer_lrn_channels"));
}
OCL_TEST(Layer_Test_LRN_channels, Accuracy)
{
OCL_ON(testLayerUsingCaffeModels("layer_lrn_channels"));
OCL_OFF();
}
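The OCL_ON / OCL_OFF pairing runs the same Caffe-model comparison with OpenCL forced on, then restores the default; presumably the macros wrap cv::ocl::setUseOpenCL, along the lines of this assumed sketch (the macro definitions are not part of this diff):

// #define OCL_ON(expr)   { cv::ocl::setUseOpenCL(true); expr; }
// #define OCL_OFF(expr)  { cv::ocl::setUseOpenCL(false); expr; }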
TEST(Layer_Test_Convolution, Accuracy)
