|
|
|
@ -138,9 +138,12 @@ public: |
|
|
|
|
UMat& bnorm_weight = umat_scale; |
|
|
|
|
UMat& bnorm_bias = umat_shift; |
|
|
|
|
|
|
|
|
|
const unsigned LOCAL_SIZE = 128; |
|
|
|
|
bool use_half = (inputs[0].depth() == CV_16S); |
|
|
|
|
String opts = format(" -DT=%s -DT4=%s -Dconvert_T=%s", use_half ? "half" : "float", |
|
|
|
|
use_half ? "half4" : "float4", use_half ? "convert_half4" : "convert_float4"); |
|
|
|
|
String opts = format(" -DT=%s -DT4=%s -Dconvert_T=%s -DLOCAL_SIZE=%u", use_half ? "half" : "float", |
|
|
|
|
use_half ? "half4" : "float4", use_half ? "convert_half4" : "convert_float4", |
|
|
|
|
LOCAL_SIZE |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
int splitDim = (acrossChannels) ? 1 : 2; |
|
|
|
|
for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++) |
|
|
|
@ -155,8 +158,8 @@ public: |
|
|
|
|
float alpha = 1.0f / s[1]; |
|
|
|
|
|
|
|
|
|
String buildopt = "-DNUM=4" + opts; |
|
|
|
|
ocl::Kernel k("mean_fuse4", ocl::dnn::mvn_oclsrc, buildopt); |
|
|
|
|
size_t localsize[] = { 128 }; |
|
|
|
|
ocl::Kernel k("mean_fuse4", ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MEAN_FUSE"); |
|
|
|
|
size_t localsize[] = { LOCAL_SIZE }; |
|
|
|
|
size_t globalsize[] = { (size_t)s[0] / 4 * localsize[0] }; |
|
|
|
|
|
|
|
|
|
int argId = 0; |
|
|
|
@ -165,7 +168,6 @@ public: |
|
|
|
|
k.set(argId++, alpha); |
|
|
|
|
k.set(argId++, ocl::KernelArg::PtrWriteOnly(meanMat)); |
|
|
|
|
k.set(argId++, ocl::KernelArg::PtrWriteOnly(tmpMat)); |
|
|
|
|
k.set(argId++, NULL, localsize[0] * sizeof(cl_float4)); |
|
|
|
|
bool ret = k.run(1, globalsize, localsize, false); |
|
|
|
|
if (!ret) |
|
|
|
|
return false; |
|
|
|
@ -173,7 +175,7 @@ public: |
|
|
|
|
buildopt += format(" %s %s", (fuse_batch_norm) ? "-DFUSE_BATCH_NORM" : "", |
|
|
|
|
(fuse_relu) ? "-DFUSE_RELU" : ""); |
|
|
|
|
|
|
|
|
|
ocl::Kernel k1("mvn_fuse4", ocl::dnn::mvn_oclsrc, buildopt); |
|
|
|
|
ocl::Kernel k1("mvn_fuse4", ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MVN_FUSE"); |
|
|
|
|
argId = 0; |
|
|
|
|
k1.set(argId++, ocl::KernelArg::PtrReadOnly(tmpMat)); |
|
|
|
|
k1.set(argId++, ocl::KernelArg::PtrReadOnly(inpMat)); |
|
|
|
@ -185,7 +187,6 @@ public: |
|
|
|
|
k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_weight)); |
|
|
|
|
k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_bias)); |
|
|
|
|
k1.set(argId++, ocl::KernelArg::PtrWriteOnly(outMat)); |
|
|
|
|
k1.set(argId++, NULL, localsize[0] * sizeof(cl_float4)); |
|
|
|
|
ret = k1.run(1, globalsize, localsize, false); |
|
|
|
|
if (!ret) |
|
|
|
|
return false; |
|
|
|
@ -243,7 +244,7 @@ public: |
|
|
|
|
if (normVariance) |
|
|
|
|
{ |
|
|
|
|
String kname = format("calc_mean%d", number); |
|
|
|
|
ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt); |
|
|
|
|
ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MEAN"); |
|
|
|
|
if (kernel.empty()) |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
@ -263,7 +264,7 @@ public: |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
String kname = format("mvn%d", number); |
|
|
|
|
buildopt += format("%s%s%s", (normVariance) ? " -DNORM_VARIANCE" : "", |
|
|
|
|
buildopt += format("%s%s%s -DKERNEL_MVN", (normVariance) ? " -DNORM_VARIANCE" : "", |
|
|
|
|
(fuse_batch_norm) ? " -DFUSE_BATCH_NORM" : "", |
|
|
|
|
(fuse_relu) ? " -DFUSE_RELU" : ""); |
|
|
|
|
ocl::Kernel kernel1(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt); |
|
|
|
|