|
|
|
@ -96,30 +96,31 @@ public: |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
int number = (s[1] % 8 == 0) ? 8 : ((s[1] % 4 == 0) ? 4 : 1); |
|
|
|
|
String buildopt = format("-DNUM=%d ", number); |
|
|
|
|
String kname = format("calc_mean%d", number); |
|
|
|
|
ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt); |
|
|
|
|
if (kernel.empty()) |
|
|
|
|
return false; |
|
|
|
|
size_t global[] = { (size_t)s[0], (size_t)(s[1] / number) }; |
|
|
|
|
kernel.set(0, ocl::KernelArg::PtrReadOnly(inpMat)); |
|
|
|
|
kernel.set(1, (int)s[0]); |
|
|
|
|
kernel.set(2, (int)s[1]); |
|
|
|
|
kernel.set(3, ocl::KernelArg::PtrReadOnly(meanMat)); |
|
|
|
|
kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmpMat)); |
|
|
|
|
ret = kernel.run(2, global, NULL, false); |
|
|
|
|
if (!ret) |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
String buildopt = format("-DNUM=%d ", number); |
|
|
|
|
if (normVariance) |
|
|
|
|
{ |
|
|
|
|
String kname = format("calc_mean%d", number); |
|
|
|
|
ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt); |
|
|
|
|
if (kernel.empty()) |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
kernel.set(0, ocl::KernelArg::PtrReadOnly(inpMat)); |
|
|
|
|
kernel.set(1, (int)s[0]); |
|
|
|
|
kernel.set(2, (int)s[1]); |
|
|
|
|
kernel.set(3, ocl::KernelArg::PtrReadOnly(meanMat)); |
|
|
|
|
kernel.set(4, ocl::KernelArg::PtrWriteOnly(tmpMat)); |
|
|
|
|
ret = kernel.run(2, global, NULL, false); |
|
|
|
|
if (!ret) |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
ret = ocl4dnn::ocl4dnnGEMV<float>(ocl4dnn::CblasNoTrans, s[0], s[1], alpha, |
|
|
|
|
tmpMat, 0, oneMat, 0, 0.0f, devMat, 0); |
|
|
|
|
if (!ret) |
|
|
|
|
return false; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
kname = format("mvn%d", number); |
|
|
|
|
String kname = format("mvn%d", number); |
|
|
|
|
if (normVariance) |
|
|
|
|
buildopt += "-DNORM_VARIANCE"; |
|
|
|
|
ocl::Kernel kernel1(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt); |
|
|
|
|