parent
06f949a590
commit
23d3ede6ab
7 changed files with 114 additions and 14 deletions
@ -0,0 +1,41 @@ |
|||||||
|
#include "../precomp.hpp" |
||||||
|
#include <opencv2/core/ocl.hpp> |
||||||
|
#include "im2col.hpp" |
||||||
|
#include "opencl_kernels_dnn.hpp" |
||||||
|
|
||||||
|
namespace cv |
||||||
|
{ |
||||||
|
namespace dnn |
||||||
|
{ |
||||||
|
|
||||||
|
void im2col_ocl(UMat &img, |
||||||
|
int channels, int height, int width,
|
||||||
|
int kernel_h, int kernel_w, |
||||||
|
int pad_h, int pad_w, |
||||||
|
int stride_h, int stride_w,
|
||||||
|
UMat &col) |
||||||
|
{ |
||||||
|
int h_out = (height + 2 * pad_h - kernel_h) / stride_h + 1; |
||||||
|
int w_out = (width + 2 * pad_w - kernel_w) / stride_w + 1; |
||||||
|
|
||||||
|
CV_Assert(img.isContinuous() && col.isContinuous()); |
||||||
|
CV_Assert(img.total() == (size_t)channels * height * width); |
||||||
|
CV_Assert(col.total() == (size_t)h_out * w_out * kernel_h * kernel_w); |
||||||
|
|
||||||
|
ocl::Kernel im2col_ker("im2col", ocl::dnn::im2col_oclsrc); |
||||||
|
|
||||||
|
im2col_ker.args(ocl::KernelArg::PtrReadOnly(img),
|
||||||
|
channels, height, width, |
||||||
|
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, |
||||||
|
h_out, w_out, |
||||||
|
ocl::KernelArg::PtrWriteOnly(col) |
||||||
|
); |
||||||
|
|
||||||
|
size_t globalSize[] = { (size_t)channels * h_out * w_out }; |
||||||
|
size_t localSize[] = { ocl::Device::getDefault().maxWorkGroupSize() }; |
||||||
|
|
||||||
|
CV_Assert(im2col_ker.run(1, globalSize, localSize, false)); |
||||||
|
} |
||||||
|
|
||||||
|
} |
||||||
|
} |
@ -0,0 +1,30 @@ |
|||||||
|
__kernel void im2col(__global const float *im_src, |
||||||
|
int channels, int height_inp, int width_inp, |
||||||
|
int kernel_h, int kernel_w, int pad_h, int pad_w, int stride_h, int stride_w, |
||||||
|
int height_out, int width_out, |
||||||
|
__global float *im_col |
||||||
|
) |
||||||
|
{ |
||||||
|
int index = get_global_id(0); |
||||||
|
int j_out = index % width_out; |
||||||
|
int i_out = (index / width_out) % height_out; |
||||||
|
int c_inp = (index / width_out) / height_out; |
||||||
|
|
||||||
|
int c_out = c_inp * kernel_h * kernel_w; |
||||||
|
int i_inp = i_out * stride_h - pad_h; |
||||||
|
int j_inp = j_out * stride_w - pad_w; |
||||||
|
|
||||||
|
im_col += (c_out * height_out + i_out) * width_out + j_out; |
||||||
|
im_src += (c_inp * height_inp + i_inp) * width_inp + j_inp; |
||||||
|
|
||||||
|
for (int ki = 0; ki < kernel_h; ++ki) |
||||||
|
for (int kj = 0; kj < kernel_w; ++kj) { |
||||||
|
int i = i_inp + ki; |
||||||
|
int j = j_inp + kj; |
||||||
|
*im_col = (h >= 0 && w >= 0 && h < height_inp && w < width_inp) ? |
||||||
|
im_src[i * width_inp + j] : 0; |
||||||
|
im_col += height_out * width_out; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
} |
Loading…
Reference in new issue