parent
06f949a590
commit
23d3ede6ab
7 changed files with 114 additions and 14 deletions
@ -0,0 +1,41 @@ |
||||
#include "../precomp.hpp" |
||||
#include <opencv2/core/ocl.hpp> |
||||
#include "im2col.hpp" |
||||
#include "opencl_kernels_dnn.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
/**
 * Launches the OpenCL "im2col" kernel, which unrolls every kernel_h x kernel_w
 * patch of a multi-channel image into the columns of a matrix.
 *
 * @param img       Continuous input buffer holding channels * height * width elements.
 * @param channels  Number of input channels.
 * @param height    Input height.
 * @param width     Input width.
 * @param kernel_h  Patch (convolution kernel) height.
 * @param kernel_w  Patch (convolution kernel) width.
 * @param pad_h     Zero padding applied above and below the input.
 * @param pad_w     Zero padding applied left and right of the input.
 * @param stride_h  Vertical step between patches; must be positive.
 * @param stride_w  Horizontal step between patches; must be positive.
 * @param col       Continuous output buffer holding
 *                  channels * kernel_h * kernel_w * h_out * w_out elements.
 *
 * Raises an OpenCV assertion error if a precondition is violated, the kernel
 * cannot be created, or the launch fails.
 */
void im2col_ocl(UMat &img,
                int channels, int height, int width,
                int kernel_h, int kernel_w,
                int pad_h, int pad_w,
                int stride_h, int stride_w,
                UMat &col)
{
    // The output extent is computed by dividing by the strides.
    CV_Assert(stride_h > 0 && stride_w > 0);

    int h_out = (height + 2 * pad_h - kernel_h) / stride_h + 1;
    int w_out = (width + 2 * pad_w - kernel_w) / stride_w + 1;

    CV_Assert(img.isContinuous() && col.isContinuous());
    CV_Assert(img.total() == (size_t)channels * height * width);
    // The kernel emits kernel_h * kernel_w rows per input channel, each row
    // holding h_out * w_out elements, so the channel count is part of the size.
    CV_Assert(col.total() == (size_t)channels * kernel_h * kernel_w * h_out * w_out);

    ocl::Kernel im2col_ker("im2col", ocl::dnn::im2col_oclsrc);
    // Report a kernel that failed to build separately from a failed launch.
    CV_Assert(!im2col_ker.empty());

    im2col_ker.args(ocl::KernelArg::PtrReadOnly(img),
                    channels, height, width,
                    kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
                    h_out, w_out,
                    ocl::KernelArg::PtrWriteOnly(col)
        );

    // One work-item per (channel, output row, output column) triple.
    size_t globalSize[] = { (size_t)channels * h_out * w_out };
    size_t localSize[] = { ocl::Device::getDefault().maxWorkGroupSize() };

    CV_Assert(im2col_ker.run(1, globalSize, localSize, false));
}
||||
|
||||
} |
||||
} |
@ -0,0 +1,30 @@ |
||||
/*
 * im2col: each work-item handles one (input channel, output row, output column)
 * triple and copies the corresponding kernel_h x kernel_w input patch into
 * im_col, writing zero for taps that fall into the padding area.
 *
 * Layout implied by the indexing below:
 *   im_src : [channels][height_inp][width_inp]
 *   im_col : [channels * kernel_h * kernel_w][height_out][width_out]
 */
__kernel void im2col(__global const float *im_src,
                     int channels, int height_inp, int width_inp,
                     int kernel_h, int kernel_w, int pad_h, int pad_w, int stride_h, int stride_w,
                     int height_out, int width_out,
                     __global float *im_col
                    )
{
    int index = get_global_id(0);

    // The host may launch more work-items than there are output positions
    // (global size rounded up to the work-group size); the extras must not write.
    if (index >= channels * height_out * width_out)
        return;

    int j_out = index % width_out;
    int i_out = (index / width_out) % height_out;
    int c_inp = (index / width_out) / height_out;

    int c_out = c_inp * kernel_h * kernel_w;
    int i_inp = i_out * stride_h - pad_h;
    int j_inp = j_out * stride_w - pad_w;

    im_col += (c_out * height_out + i_out) * width_out + j_out;

    // Offset of the patch's top-left tap; it may point into the padding area,
    // so it is kept as an index and only dereferenced after the bounds test.
    int src_base = (c_inp * height_inp + i_inp) * width_inp + j_inp;

    for (int ki = 0; ki < kernel_h; ++ki)
        for (int kj = 0; kj < kernel_w; ++kj) {
            int i = i_inp + ki;
            int j = j_inp + kj;
            // src_base already contains i_inp/j_inp, so index relative to the
            // patch origin with ki/kj rather than the absolute i/j.
            *im_col = (i >= 0 && j >= 0 && i < height_inp && j < width_inp) ?
                im_src[src_base + ki * width_inp + kj] : 0;
            // Consecutive kernel taps are one output plane apart in im_col.
            im_col += height_out * width_out;
        }
}
Loading…
Reference in new issue