DNN: add Winograd FP16 support (#23654)

* add Winograd FP16 implementation

* fixed dispatching of the FP16 code paths in dnn: use the dynamic dispatcher only when NEON_FP16 is enabled in the build and the feature is present on the host CPU at runtime

* fixed some warnings

* hopefully fixed Winograd on x64 (and maybe other platforms)

---------

Co-authored-by: Vadim Pisarevsky <vadim.pisarevsky@gmail.com>
pull/24566/head
Author: zihaomu (committed by GitHub)
parent a478757483
commit b913e73d04
22 changed files:
  1. cmake/OpenCVCompilerOptimizations.cmake (2 changed lines)
  2. cmake/checks/cpu_fp16.cpp (10 changed lines)
  3. modules/core/include/opencv2/core/cv_cpu_helper.h (42 changed lines)
  4. modules/core/test/ocl/test_image2d.cpp (2 changed lines)
  5. modules/dnn/CMakeLists.txt (4 changed lines)
  6. modules/dnn/include/opencv2/dnn/all_layers.hpp (2 changed lines)
  7. modules/dnn/include/opencv2/dnn/dnn.hpp (3 changed lines)
  8. modules/dnn/src/layers/cpu_kernels/conv_block.simd.hpp (251 changed lines)
  9. modules/dnn/src/layers/cpu_kernels/conv_depthwise.cpp (2 changed lines)
  10. modules/dnn/src/layers/cpu_kernels/conv_winograd_f63.cpp (155 changed lines)
  11. modules/dnn/src/layers/cpu_kernels/conv_winograd_f63.neon.cpp (476 changed lines)
  12. modules/dnn/src/layers/cpu_kernels/conv_winograd_f63.simd.hpp (489 changed lines)
  13. modules/dnn/src/layers/cpu_kernels/convolution.cpp (206 changed lines)
  14. modules/dnn/src/layers/cpu_kernels/convolution.hpp (39 changed lines)
  15. modules/dnn/src/model.cpp (9 changed lines)
  16. modules/dnn/test/test_backends.cpp (7 changed lines)
  17. modules/dnn/test/test_caffe_importer.cpp (21 changed lines)
  18. modules/dnn/test/test_darknet_importer.cpp (17 changed lines)
  19. modules/dnn/test/test_model.cpp (2 changed lines)
  20. modules/dnn/test/test_onnx_importer.cpp (32 changed lines)
  21. modules/dnn/test/test_tf_importer.cpp (12 changed lines)
  22. modules/dnn/test/test_torch_importer.cpp (7 changed lines)

@@ -49,7 +49,7 @@
set(CPU_ALL_OPTIMIZATIONS "SSE;SSE2;SSE3;SSSE3;SSE4_1;SSE4_2;POPCNT;AVX;FP16;AVX2;FMA3;AVX_512F")
list(APPEND CPU_ALL_OPTIMIZATIONS "AVX512_COMMON;AVX512_KNL;AVX512_KNM;AVX512_SKX;AVX512_CNL;AVX512_CLX;AVX512_ICL")
list(APPEND CPU_ALL_OPTIMIZATIONS NEON VFPV3 FP16 NEON_DOTPROD)
list(APPEND CPU_ALL_OPTIMIZATIONS NEON VFPV3 FP16 NEON_DOTPROD NEON_FP16 NEON_BF16)
list(APPEND CPU_ALL_OPTIMIZATIONS MSA)
list(APPEND CPU_ALL_OPTIMIZATIONS VSX VSX3)
list(APPEND CPU_ALL_OPTIMIZATIONS RVV)

@@ -15,12 +15,12 @@ int test()
#include "arm_neon.h"
int test()
{
const float src[] = { 0.0f, 0.0f, 0.0f, 0.0f };
short dst[8];
float32x4_t v_src = *(float32x4_t*)src;
const float src[] = { 0.0f, 1.0f, 2.0f, 3.0f };
short dst[4];
float32x4_t v_src = vld1q_f32(src);
float16x4_t v_dst = vcvt_f16_f32(v_src);
*(float16x4_t*)dst = v_dst;
return (int)dst[0];
vst1_f16((__fp16*)dst, v_dst);
return dst[0] + dst[1] + dst[2] + dst[3];
}
#else
#error "FP16 is not supported"
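(For context, cpu_fp16.cpp is the compile-time probe that CMake builds to detect FP16 support; the old code loaded and stored through raw vector-pointer casts, which the patch replaces with proper intrinsics. A minimal self-contained sketch of the corrected NEON branch, with the compiler-specific #if guards abbreviated and assuming an AArch64/NEON toolchain:)

#include <arm_neon.h>

int test()
{
    const float src[] = { 0.0f, 1.0f, 2.0f, 3.0f };
    short dst[4];                               // holds raw fp16 bit patterns
    float32x4_t v_src = vld1q_f32(src);         // intrinsic load instead of *(float32x4_t*)src
    float16x4_t v_dst = vcvt_f16_f32(v_src);    // narrow fp32 -> fp16
    vst1_f16((__fp16*)dst, v_dst);              // intrinsic store instead of *(float16x4_t*)dst
    return dst[0] + dst[1] + dst[2] + dst[3];   // depends on all lanes, so the conversion is not elided
}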

@@ -441,6 +441,48 @@
#endif
#define __CV_CPU_DISPATCH_CHAIN_NEON_DOTPROD(fn, args, mode, ...) CV_CPU_CALL_NEON_DOTPROD(fn, args); __CV_EXPAND(__CV_CPU_DISPATCH_CHAIN_ ## mode(fn, args, __VA_ARGS__))
#if !defined CV_DISABLE_OPTIMIZATION && defined CV_ENABLE_INTRINSICS && defined CV_CPU_COMPILE_NEON_FP16
# define CV_TRY_NEON_FP16 1
# define CV_CPU_FORCE_NEON_FP16 1
# define CV_CPU_HAS_SUPPORT_NEON_FP16 1
# define CV_CPU_CALL_NEON_FP16(fn, args) return (cpu_baseline::fn args)
# define CV_CPU_CALL_NEON_FP16_(fn, args) return (opt_NEON_FP16::fn args)
#elif !defined CV_DISABLE_OPTIMIZATION && defined CV_ENABLE_INTRINSICS && defined CV_CPU_DISPATCH_COMPILE_NEON_FP16
# define CV_TRY_NEON_FP16 1
# define CV_CPU_FORCE_NEON_FP16 0
# define CV_CPU_HAS_SUPPORT_NEON_FP16 (cv::checkHardwareSupport(CV_CPU_NEON_FP16))
# define CV_CPU_CALL_NEON_FP16(fn, args) if (CV_CPU_HAS_SUPPORT_NEON_FP16) return (opt_NEON_FP16::fn args)
# define CV_CPU_CALL_NEON_FP16_(fn, args) if (CV_CPU_HAS_SUPPORT_NEON_FP16) return (opt_NEON_FP16::fn args)
#else
# define CV_TRY_NEON_FP16 0
# define CV_CPU_FORCE_NEON_FP16 0
# define CV_CPU_HAS_SUPPORT_NEON_FP16 0
# define CV_CPU_CALL_NEON_FP16(fn, args)
# define CV_CPU_CALL_NEON_FP16_(fn, args)
#endif
#define __CV_CPU_DISPATCH_CHAIN_NEON_FP16(fn, args, mode, ...) CV_CPU_CALL_NEON_FP16(fn, args); __CV_EXPAND(__CV_CPU_DISPATCH_CHAIN_ ## mode(fn, args, __VA_ARGS__))
#if !defined CV_DISABLE_OPTIMIZATION && defined CV_ENABLE_INTRINSICS && defined CV_CPU_COMPILE_NEON_BF16
# define CV_TRY_NEON_BF16 1
# define CV_CPU_FORCE_NEON_BF16 1
# define CV_CPU_HAS_SUPPORT_NEON_BF16 1
# define CV_CPU_CALL_NEON_BF16(fn, args) return (cpu_baseline::fn args)
# define CV_CPU_CALL_NEON_BF16_(fn, args) return (opt_NEON_BF16::fn args)
#elif !defined CV_DISABLE_OPTIMIZATION && defined CV_ENABLE_INTRINSICS && defined CV_CPU_DISPATCH_COMPILE_NEON_BF16
# define CV_TRY_NEON_BF16 1
# define CV_CPU_FORCE_NEON_BF16 0
# define CV_CPU_HAS_SUPPORT_NEON_BF16 (cv::checkHardwareSupport(CV_CPU_NEON_BF16))
# define CV_CPU_CALL_NEON_BF16(fn, args) if (CV_CPU_HAS_SUPPORT_NEON_BF16) return (opt_NEON_BF16::fn args)
# define CV_CPU_CALL_NEON_BF16_(fn, args) if (CV_CPU_HAS_SUPPORT_NEON_BF16) return (opt_NEON_BF16::fn args)
#else
# define CV_TRY_NEON_BF16 0
# define CV_CPU_FORCE_NEON_BF16 0
# define CV_CPU_HAS_SUPPORT_NEON_BF16 0
# define CV_CPU_CALL_NEON_BF16(fn, args)
# define CV_CPU_CALL_NEON_BF16_(fn, args)
#endif
#define __CV_CPU_DISPATCH_CHAIN_NEON_BF16(fn, args, mode, ...) CV_CPU_CALL_NEON_BF16(fn, args); __CV_EXPAND(__CV_CPU_DISPATCH_CHAIN_ ## mode(fn, args, __VA_ARGS__))
#if !defined CV_DISABLE_OPTIMIZATION && defined CV_ENABLE_INTRINSICS && defined CV_CPU_COMPILE_MSA
# define CV_TRY_MSA 1
# define CV_CPU_FORCE_MSA 1
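(These generated macros feed OpenCV's CV_CPU_DISPATCH machinery; in the CV_CPU_DISPATCH_COMPILE_NEON_FP16 case the call is guarded by a runtime checkHardwareSupport test, matching the commit message. A sketch of how a dispatched kernel consumes them; the kernel and file names are illustrative, not part of this patch:)

// some_kernel.cpp, dispatcher side (illustrative names).
// "some_kernel.simd.hpp" defines the kernel inside
// CV_CPU_OPTIMIZATION_NAMESPACE; ocv_add_dispatched_file generates
// "some_kernel.simd_declarations.hpp" with the per-ISA opt_* namespaces.
#include "some_kernel.simd.hpp"
#include "some_kernel.simd_declarations.hpp"

void someKernel(const char* a, const char* b, char* c)
{
    // With NEON_FP16 in CPU_DISPATCH this expands (via CV_CPU_CALL_NEON_FP16) to:
    //   if (cv::checkHardwareSupport(CV_CPU_NEON_FP16))
    //       return opt_NEON_FP16::someKernel(a, b, c);
    // and falls through to the next variant or the baseline otherwise.
    CV_CPU_DISPATCH(someKernel, (a, b, c), CV_CPU_DISPATCH_MODES_ALL);
}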

@@ -83,7 +83,7 @@ TEST(Image2D, turnOffOpenCL)
}
else
std::cout << "CV_8UC1 is not supported for OpenCL images. Test skipped." << std::endl;
// reset state to the previous one
cv::ocl::setUseOpenCL(useOCL);
}

@@ -6,9 +6,9 @@ set(the_description "Deep neural network module. It allows to load models from d
ocv_add_dispatched_file_force_all("layers/layers_common" AVX AVX2 AVX512_SKX RVV LASX)
ocv_add_dispatched_file_force_all("int8layers/layers_common" AVX2 AVX512_SKX LASX)
ocv_add_dispatched_file_force_all("layers/cpu_kernels/conv_block" AVX AVX2)
ocv_add_dispatched_file_force_all("layers/cpu_kernels/conv_block" AVX AVX2 NEON NEON_FP16)
ocv_add_dispatched_file_force_all("layers/cpu_kernels/conv_depthwise" AVX AVX2 RVV LASX)
ocv_add_dispatched_file_force_all("layers/cpu_kernels/conv_winograd_f63" AVX AVX2)
ocv_add_dispatched_file_force_all("layers/cpu_kernels/conv_winograd_f63" AVX AVX2 NEON_FP16)
ocv_add_dispatched_file_force_all("layers/cpu_kernels/fast_gemm_kernels" AVX AVX2 NEON LASX)
ocv_add_module(dnn opencv_core opencv_imgproc WRAP python java objc js)

@@ -303,7 +303,7 @@ CV__DNN_INLINE_NS_BEGIN
// Quantization type flag. The perChannel default is true, meaning the layer contains
// per-channel quantization parameters; otherwise it contains per-tensor quantized parameters.
bool per_channel;
bool useWinograd = true; // Flag whether to use Winograd to speed up 3x3 convolution.
bool useWinograd = false; // Flag whether to use Winograd to speed up 3x3 convolution.
static Ptr<BaseConvolutionLayer> create(const LayerParams& params);
};

@@ -1458,6 +1458,9 @@ CV__DNN_INLINE_NS_BEGIN
/// @sa Net::setPreferableTarget
CV_WRAP Model& setPreferableTarget(dnn::Target targetId);
/// @sa Net::enableWinograd
CV_WRAP Model& enableWinograd(bool useWinograd);
CV_DEPRECATED_EXTERNAL
operator Net&() const { return getNetwork_(); }
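(Usage sketch for the new wrapper; per the @sa reference it forwards to Net::enableWinograd. The model path below is hypothetical:)

#include <opencv2/dnn.hpp>

void loadWithoutWinograd()
{
    cv::dnn::Model model("some_model.onnx");  // hypothetical model file
    model.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV);
    model.enableWinograd(false);              // opt out of the Winograd 3x3 convolution path
}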

@@ -8,16 +8,26 @@ namespace cv {
namespace dnn {
CV_CPU_OPTIMIZATION_NAMESPACE_BEGIN
void convBlock(int np, const float* a, const float* b, float* c, int ldc, bool init_c, int width, const int convMR, const int convNR);
void convBlock_F32(int np, const float* a, const float* b, float* c, int ldc, bool init_c, int width, const int convMR, const int convNR);
#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_AVX
// FP16 branch.
void convBlock_F16(int np, const char * _a, const char * _b, char * _c, int ldc, bool init_c, int width,
const int convMR_fp16, const int convNR_fp16);
void convBlockMR1_F16(int np, const char* _a, const char* _b, float *c, const float _bias, bool init_c,
const float minval, const float maxval, bool ifMinMaxAct, const int width, const int convNR_FP16);
#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY)
#if CV_AVX
#if !CV_FMA3 // AVX workaround
#undef _mm256_fmadd_ps
#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(c, _mm256_mul_ps(a, b))
#endif
void convBlock(int np, const float* a, const float* b, float* c, int ldc, bool init_c, int width, const int convMR, const int convNR)
void convBlock_F32(int np, const float* a, const float* b, float* c, int ldc, bool init_c, int width, const int convMR, const int convNR)
{
CV_Assert(convMR == 4 && convNR == 24);
__m256 c00 = _mm256_set1_ps(0.f), c01 = c00, c02 = c00;
@@ -121,16 +131,11 @@ void convBlock(int np, const float* a, const float* b, float* c, int ldc, bool i
_mm256_zeroupper();
}
#endif // CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY
CV_CPU_OPTIMIZATION_NAMESPACE_END
#endif
// NEON code workaround.
namespace opt_NEON
{
#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_NEON
#if CV_NEON
void convBlock(int np, const float* a, const float* b, float* c, int ldc, bool init_c, int width, const int convMR, const int convNR)
void convBlock_F32(int np, const float* a, const float* b, float* c, int ldc, bool init_c, int width, const int convMR, const int convNR)
{
#if CV_NEON_AARCH64
if (convMR == 4 && convNR == 28) // AARCH64
@@ -298,104 +303,104 @@ void convBlock(int np, const float* a, const float* b, float* c, int ldc, bool i
}
else
#endif
if (convMR == 4 && convNR == 12) // ARMv7
{
float32x4_t c0 = vdupq_n_f32(0.f), c1 = c0, c2 = c0;
float32x4_t c3 = vdupq_n_f32(0.f), c4 = c3, c5 = c3;
float32x4_t c6 = vdupq_n_f32(0.f), c7 = c6, c8 = c6;
float32x4_t c9 = vdupq_n_f32(0.f), c10 = c9, c11 = c9;
float32x2_t a0 = vdup_n_f32(0.0f), a1 = a0;
float32x4_t b0 = vdupq_n_f32(0.0f), b1 = vdupq_n_f32(0.0f), b2 = vdupq_n_f32(0.0f);
if (width > 8)
if (convMR == 4 && convNR == 12) // ARMv7
{
for (int p = 0; p < np; p++, a += convMR, b += convNR)
{
a0 = vld1_f32(a), a1 = vld1_f32(a+2);
b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8);
c0 = vmlaq_lane_f32(c0, b0, a0, 0);
c1 = vmlaq_lane_f32(c1, b1, a0, 0);
c2 = vmlaq_lane_f32(c2, b2, a0, 0);
float32x4_t c0 = vdupq_n_f32(0.f), c1 = c0, c2 = c0;
float32x4_t c3 = vdupq_n_f32(0.f), c4 = c3, c5 = c3;
float32x4_t c6 = vdupq_n_f32(0.f), c7 = c6, c8 = c6;
float32x4_t c9 = vdupq_n_f32(0.f), c10 = c9, c11 = c9;
c3 = vmlaq_lane_f32(c3, b0, a0, 1);
c4 = vmlaq_lane_f32(c4, b1, a0, 1);
c5 = vmlaq_lane_f32(c5, b2, a0, 1);
float32x2_t a0 = vdup_n_f32(0.0f), a1 = a0;
float32x4_t b0 = vdupq_n_f32(0.0f), b1 = vdupq_n_f32(0.0f), b2 = vdupq_n_f32(0.0f);
c6 = vmlaq_lane_f32(c6, b0, a1, 0);
c7 = vmlaq_lane_f32(c7, b1, a1, 0);
c8 = vmlaq_lane_f32(c8, b2, a1, 0);
c9 = vmlaq_lane_f32(c9 , b0, a1, 1);
c10 = vmlaq_lane_f32(c10, b1, a1, 1);
c11 = vmlaq_lane_f32(c11, b2, a1, 1);
if (width > 8)
{
for (int p = 0; p < np; p++, a += convMR, b += convNR)
{
a0 = vld1_f32(a), a1 = vld1_f32(a+2);
b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8);
c0 = vmlaq_lane_f32(c0, b0, a0, 0);
c1 = vmlaq_lane_f32(c1, b1, a0, 0);
c2 = vmlaq_lane_f32(c2, b2, a0, 0);
c3 = vmlaq_lane_f32(c3, b0, a0, 1);
c4 = vmlaq_lane_f32(c4, b1, a0, 1);
c5 = vmlaq_lane_f32(c5, b2, a0, 1);
c6 = vmlaq_lane_f32(c6, b0, a1, 0);
c7 = vmlaq_lane_f32(c7, b1, a1, 0);
c8 = vmlaq_lane_f32(c8, b2, a1, 0);
c9 = vmlaq_lane_f32(c9 , b0, a1, 1);
c10 = vmlaq_lane_f32(c10, b1, a1, 1);
c11 = vmlaq_lane_f32(c11, b2, a1, 1);
}
}
}
else if (width > 4)
{
for (int p = 0; p < np; p++, a += convMR, b += convNR)
else if (width > 4)
{
a0 = vld1_f32(a), a1 = vld1_f32(a+2);
b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4);
for (int p = 0; p < np; p++, a += convMR, b += convNR)
{
a0 = vld1_f32(a), a1 = vld1_f32(a+2);
b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4);
c0 = vmlaq_lane_f32(c0, b0, a0, 0);
c1 = vmlaq_lane_f32(c1, b1, a0, 0);
c0 = vmlaq_lane_f32(c0, b0, a0, 0);
c1 = vmlaq_lane_f32(c1, b1, a0, 0);
c3 = vmlaq_lane_f32(c3, b0, a0, 1);
c4 = vmlaq_lane_f32(c4, b1, a0, 1);
c3 = vmlaq_lane_f32(c3, b0, a0, 1);
c4 = vmlaq_lane_f32(c4, b1, a0, 1);
c6 = vmlaq_lane_f32(c6, b0, a1, 0);
c7 = vmlaq_lane_f32(c7, b1, a1, 0);
c6 = vmlaq_lane_f32(c6, b0, a1, 0);
c7 = vmlaq_lane_f32(c7, b1, a1, 0);
c9 = vmlaq_lane_f32(c9 , b0, a1, 1);
c10 = vmlaq_lane_f32(c10, b1, a1, 1);
c9 = vmlaq_lane_f32(c9 , b0, a1, 1);
c10 = vmlaq_lane_f32(c10, b1, a1, 1);
}
}
}
else
{
for (int p = 0; p < np; p++, a += convMR, b += convNR)
else
{
a0 = vld1_f32(a), a1 = vld1_f32(a+2);
b0 = vld1q_f32(b);
c0 = vmlaq_lane_f32(c0, b0, a0, 0);
c3 = vmlaq_lane_f32(c3, b0, a0, 1);
c6 = vmlaq_lane_f32(c6, b0, a1, 0);
c9 = vmlaq_lane_f32(c9 , b0, a1, 1);
for (int p = 0; p < np; p++, a += convMR, b += convNR)
{
a0 = vld1_f32(a), a1 = vld1_f32(a+2);
b0 = vld1q_f32(b);
c0 = vmlaq_lane_f32(c0, b0, a0, 0);
c3 = vmlaq_lane_f32(c3, b0, a0, 1);
c6 = vmlaq_lane_f32(c6, b0, a1, 0);
c9 = vmlaq_lane_f32(c9 , b0, a1, 1);
}
}
}
if (!init_c)
{
c0 = vaddq_f32(c0, vld1q_f32(c));
c1 = vaddq_f32(c1, vld1q_f32(c + 4));
c2 = vaddq_f32(c2, vld1q_f32(c + 8));
if (!init_c)
{
c0 = vaddq_f32(c0, vld1q_f32(c));
c1 = vaddq_f32(c1, vld1q_f32(c + 4));
c2 = vaddq_f32(c2, vld1q_f32(c + 8));
c3 = vaddq_f32(c3, vld1q_f32(c + ldc));
c4 = vaddq_f32(c4, vld1q_f32(c + ldc + 4));
c5 = vaddq_f32(c5, vld1q_f32(c + ldc + 8));
c3 = vaddq_f32(c3, vld1q_f32(c + ldc));
c4 = vaddq_f32(c4, vld1q_f32(c + ldc + 4));
c5 = vaddq_f32(c5, vld1q_f32(c + ldc + 8));
c6 = vaddq_f32(c6, vld1q_f32(c + ldc * 2));
c7 = vaddq_f32(c7, vld1q_f32(c + ldc * 2 + 4));
c8 = vaddq_f32(c8, vld1q_f32(c + ldc * 2 + 8));
c6 = vaddq_f32(c6, vld1q_f32(c + ldc * 2));
c7 = vaddq_f32(c7, vld1q_f32(c + ldc * 2 + 4));
c8 = vaddq_f32(c8, vld1q_f32(c + ldc * 2 + 8));
c9 = vaddq_f32(c9 , vld1q_f32(c + ldc * 3));
c10 = vaddq_f32(c10, vld1q_f32(c + ldc * 3 + 4));
c11 = vaddq_f32(c11, vld1q_f32(c + ldc * 3 + 8));
}
c9 = vaddq_f32(c9 , vld1q_f32(c + ldc * 3));
c10 = vaddq_f32(c10, vld1q_f32(c + ldc * 3 + 4));
c11 = vaddq_f32(c11, vld1q_f32(c + ldc * 3 + 8));
}
vst1q_f32(c, c0), vst1q_f32(c+4, c1), vst1q_f32(c+8, c2);
vst1q_f32(c + ldc, c3), vst1q_f32(c + ldc + 4, c4), vst1q_f32(c + ldc + 8, c5);
vst1q_f32(c + ldc*2, c6), vst1q_f32(c + ldc*2 + 4, c7), vst1q_f32(c + ldc*2 + 8, c8);
vst1q_f32(c + ldc*3, c9), vst1q_f32(c + ldc*3 + 4, c10), vst1q_f32(c + ldc*3 + 8, c11);
}
else
CV_Error(Error::StsNotImplemented, "Unsupported convMR and/or convNR in opt_NEON::convBlock");
vst1q_f32(c, c0), vst1q_f32(c+4, c1), vst1q_f32(c+8, c2);
vst1q_f32(c + ldc, c3), vst1q_f32(c + ldc + 4, c4), vst1q_f32(c + ldc + 8, c5);
vst1q_f32(c + ldc*2, c6), vst1q_f32(c + ldc*2 + 4, c7), vst1q_f32(c + ldc*2 + 8, c8);
vst1q_f32(c + ldc*3, c9), vst1q_f32(c + ldc*3 + 4, c10), vst1q_f32(c + ldc*3 + 8, c11);
}
else
CV_Error(Error::StsNotImplemented, "Unsupported convMR and/or convNR in opt_NEON::convBlock");
}
void convBlockMR1_F32(int np, const float * a, const float * b, float *c, const float bias, bool init_c,
const float minval, const float maxval, bool ifMinMaxAct, const int width, const int convNR)
const float minval, const float maxval, bool ifMinMaxAct, const int width, const int convNR)
{
CV_Assert(convNR == 28);
float32x4_t c0 = vdupq_n_f32(bias), c1 = c0, c2 = c0;
@@ -482,22 +487,17 @@ void convBlockMR1_F32(int np, const float * a, const float * b, float *c, const
vst1q_f32(c + 20, c5);
vst1q_f32(c + 24, c6);
}
#endif
#if CV_NEON_AARCH64 && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
// Fix conflict between float16_t in arm_neon.h and float16_t in cvdef.h.
typedef __fp16 float16_t;
#if defined(CV_NEON_AARCH64) && CV_NEON_AARCH64 && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
#ifndef __ARM_FEATURE_FMA // Work around without FMA support.
#define vfmaq_f16(a, b, c) (a + b * c)
#endif
void convBlock_FP16(int np, const char * _a, const char * _b, char * _c, int ldc, bool init_c, int width,
void convBlock_F16(int np, const char * _a, const char * _b, char * _c, int ldc, bool init_c, int width,
const int convMR_fp16, const int convNR_fp16)
{
#if 1
typedef __fp16 float16_t;
const float16_t* a = (const float16_t*)_a;
const float16_t* b = (const float16_t*)_b;
float16_t* c = (float16_t*)_c;
CV_Assert(convMR_fp16 == 8 && convNR_fp16 == 24);
float16x8_t c00 = vdupq_n_f16(0), c01 = c00, c02 = c00;
@@ -603,8 +603,8 @@ void convBlock_FP16(int np, const char * _a, const char * _b, char * _c, int ldc
if (!init_c)
{
#undef _FX_UPDATE_CBUF_ROW
#define _FX_UPDATE_CBUF_ROW(row) \
#undef _FX_UPDATE_CBUF_ROW
#define _FX_UPDATE_CBUF_ROW(row) \
c##row##0 = c##row##0 + vld1q_f16(c + row*ldc); \
c##row##1 = c##row##1 + vld1q_f16(c + row*ldc + 8); \
c##row##2 = c##row##2 + vld1q_f16(c + row*ldc + 16)
@@ -619,8 +619,8 @@ void convBlock_FP16(int np, const char * _a, const char * _b, char * _c, int ldc
_FX_UPDATE_CBUF_ROW(7);
}
#undef _FX_STORE_CBUF_ROW
#define _FX_STORE_CBUF_ROW(row) \
#undef _FX_STORE_CBUF_ROW
#define _FX_STORE_CBUF_ROW(row) \
vst1q_f16(c + row*ldc, c##row##0); \
vst1q_f16(c + row*ldc + 8, c##row##1); \
vst1q_f16(c + row*ldc + 16, c##row##2)
@@ -633,46 +633,12 @@ void convBlock_FP16(int np, const char * _a, const char * _b, char * _c, int ldc
_FX_STORE_CBUF_ROW(5);
_FX_STORE_CBUF_ROW(6);
_FX_STORE_CBUF_ROW(7);
#else
// reference only.
const float16_t* a = (const float16_t*)_a;
const float16_t* b = (const float16_t*)_b;
float16_t* c = (float16_t*)_c;
float cbuf[convMR_fp16*convNR_fp16];
memset(cbuf, 0, sizeof(cbuf));
for( int p = 0; p < np; p++ )
{
for( int i = 0; i < convMR_fp16; i++ )
{
float ai = float(a[convMR_fp16*p + i]);
for( int j = 0; j < convNR_fp16; j++ )
cbuf[i*convNR_fp16+j] += float(b[convNR_fp16*p + j]) * ai;
}
}
if (!init_c)
{
for(int i = 0; i < convMR_fp16; i++)
{
for(int j = 0; j < convNR_fp16; j++)
c[i*ldc + j] = float16_t(float(c[i*ldc + j]) + cbuf[i*convNR_fp16 + j]);
}
}
else
{
for(int i = 0; i < convMR_fp16; i++)
{
for(int j = 0; j < convNR_fp16; j++)
c[i*ldc + j] = (float16_t)(cbuf[i*convNR_fp16 + j]);
}
}
#endif
}
void convBlockMR1_FP16(int np, const char* _a, const char* _b, float *c, const float _bias, bool init_c,
const float minval, const float maxval, bool ifMinMaxAct, const int width, const int convNR_FP16)
void convBlockMR1_F16(int np, const char* _a, const char* _b, float *c, const float _bias, bool init_c,
const float minval, const float maxval, bool ifMinMaxAct, const int width, const int convNR_FP16)
{
typedef __fp16 float16_t;
CV_Assert(convNR_FP16 == 24); // CONV_NR_FP16 = 24
const float16_t* a = (const float16_t*)_a;
const float16_t* b = (const float16_t*)_b;
@@ -685,7 +651,7 @@ void convBlockMR1_FP16(int np, const char* _a, const char* _b, float *c, const f
{
for (int p = 0; p < np; p++, a++, b += convNR_FP16)
{
float16x8_t a0= vdupq_n_f16(a[0]);
float16x8_t a0 = vdupq_n_f16(a[0]);
float16x8_t b0 = vld1q_f16(b), b1 = vld1q_f16(b + 8), b2 = vld1q_f16(b + 16);
c0 = vfmaq_f16(c0, a0, b0);
@@ -754,6 +720,7 @@ void convBlockMR1_FP16(int np, const char* _a, const char* _b, float *c, const f
}
#endif
#endif
}
#endif // CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY
CV_CPU_OPTIMIZATION_NAMESPACE_END
}} // namespace cv::dnn
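(The reference branch deleted above remains the clearest statement of what convBlock_F16 computes; as a scalar sketch, with float standing in for __fp16 so it compiles anywhere:)

#include <vector>

// C[MR x NR] = (init_c ? 0 : C) + A^T * B, where A is laid out as np
// steps of MR activations and B as np steps of NR values; the real
// kernel uses fp16 storage, plain float here for portability.
static void convBlockRef(int np, const float* a, const float* b, float* c,
                         int ldc, bool init_c, int MR, int NR)
{
    std::vector<float> cbuf(MR * NR, 0.f);
    for (int p = 0; p < np; p++)
        for (int i = 0; i < MR; i++)
            for (int j = 0; j < NR; j++)
                cbuf[i*NR + j] += a[MR*p + i] * b[NR*p + j];
    for (int i = 0; i < MR; i++)
        for (int j = 0; j < NR; j++)
            c[i*ldc + j] = init_c ? cbuf[i*NR + j]
                                  : c[i*ldc + j] + cbuf[i*NR + j];
}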

@@ -92,7 +92,7 @@ void runDepthwise(InputArray _input, OutputArray _output, const Ptr<FastConv>& c
ofstab[k] = dy * Wi + dx;
}
const float *weights0 = conv->weightsBufPtr, *bias = conv->biasBuf.data();
const float *weights0 = conv->getWeights(), *bias = conv->biasBuf.data();
const float* relu = reluslope.data();
CV_Assert(ksize > 1 || (pad_left == 0 && pad_right == 0 && pad_top == 0 && pad_bottom == 0));

@@ -20,15 +20,15 @@ namespace cv { namespace dnn {
#if CV_NEON || CV_SIMD128 || CV_TRY_AVX2
enum { VEC_ALIGN = 32, DFT_TYPE = CV_32F }; // Memory alignment.
void winofunc_accum_f32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
void winofunc_accum_F32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
const int winoIblock, const int winoKblock, const int winoAtomF32, const int winoNatomF32);
/*Input transform*/
void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
void winofunc_BtXB_8x8_F32(const float* inptr, int inpstep,
float* outptr, int Cg, const int winoIblock, const int winoAtomF32);
/*Output transform*/
void winofunc_AtXA_8x8_f32(const float* inptr, int inpstep, float* bpptr, int bpstep, float* outptr, int outstep,
void winofunc_AtXA_8x8_F32(const float* inptr, int inpstep, float* bpptr, int bpstep, float* outptr, int outstep,
float bias, float minval, float maxval, bool ifMinMaxAct);
int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _output, const Ptr<FastConv>& conv,
@@ -67,6 +67,28 @@ int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _outpu
#endif
const int CONV_WINO_NATOMS_F32 = CONV_WINO_AREA / CONV_WINO_ATOM_F32; // for AVX2 it is 8; otherwise it is 16.
int CONV_WINO_ATOM = CONV_WINO_ATOM_F32;
int CONV_WINO_NATOMS = CONV_WINO_NATOMS_F32;
#ifdef CONV_ARM_FP16
// FP16
const int CONV_WINO_ATOM_F16 = CONV_WINO_ATOM_F32 * 2;
const int CONV_WINO_NATOMS_F16 = CONV_WINO_AREA / CONV_WINO_ATOM_F16;
#endif
int esz = sizeof(float);
#ifdef CONV_ARM_FP16
const bool useFP16 = conv->useFP16;
if (useFP16)
{
// Work in FP16.
CONV_WINO_ATOM = CONV_WINO_ATOM_F16;
CONV_WINO_NATOMS = CONV_WINO_NATOMS_F16;
esz = sizeof(float16_t);
}
#endif
int Kg_nblocks = (Kg + CONV_WINO_KBLOCK - 1)/CONV_WINO_KBLOCK;
const size_t inp_planesize = (size_t)Hi*Wi;
const size_t out_planesize = (size_t)H0*W0;
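(Concrete numbers for the atom bookkeeping on the NEON path: the transformed tile has 8*8 = 64 elements, the fp32 atom is one float32x4_t, and the fp16 atom packs twice as many lanes into one float16x8_t, so the per-tile atom count halves; a worked sketch, values per the NEON asserts below:)

// Winograd tile bookkeeping, NEON path (assumed values):
const int CONV_WINO_AREA       = 8 * 8;                                // 64-element tile
const int CONV_WINO_ATOM_F32   = 4;                                    // one float32x4_t
const int CONV_WINO_NATOMS_F32 = CONV_WINO_AREA / CONV_WINO_ATOM_F32;  // 16 atoms
const int CONV_WINO_ATOM_F16   = CONV_WINO_ATOM_F32 * 2;               // one float16x8_t -> 8
const int CONV_WINO_NATOMS_F16 = CONV_WINO_AREA / CONV_WINO_ATOM_F16;  // 8 atoms
// esz likewise drops from sizeof(float) = 4 to sizeof(__fp16) = 2,
// halving the transform scratch buffers allocated below.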
@@ -78,9 +100,9 @@ int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _outpu
size_t totalbufsize = (size_t)N*C*blocks_per_plane_aligned*CONV_WINO_AREA;
AutoBuffer<float> _buf;
_buf.allocate(totalbufsize + VEC_ALIGN);
float* wbuf_all = alignPtr(_buf.data(), VEC_ALIGN);
AutoBuffer<char> _buf;
_buf.allocate((totalbufsize + VEC_ALIGN) * esz);
char* wbuf_all = alignPtr(_buf.data(), VEC_ALIGN * esz);
float* inp = input.ptr<float>();
float* out = output.ptr<float>();
@@ -104,14 +126,15 @@ int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _outpu
int c = nc0 - n*C;
int g = c / Cg;
c -= g*Cg;
for (int block_id = 0; block_id < blocks_per_plane; block_id += CONV_WINO_IBLOCK)
{
for (int db = 0; db < CONV_WINO_IBLOCK; db++)
{
size_t inwofs = ((n*ngroups + g)*blocks_per_plane_aligned +
block_id)*Cg*CONV_WINO_AREA +
(c*CONV_WINO_IBLOCK + db)*CONV_WINO_ATOM_F32;
float* inwptr = (float*)wbuf_all + inwofs;
(c*CONV_WINO_IBLOCK + db) * CONV_WINO_ATOM;
char* inwptr = wbuf_all + inwofs * esz;
if (block_id + db < blocks_per_plane)
{
@@ -152,27 +175,40 @@ int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _outpu
inptr = inpbuf;
inpstep = CONV_WINO_SIZE;
}
#if CV_TRY_AVX2
if (conv->useAVX2)
opt_AVX2::winofunc_BtXB_8x8_f32(inptr, inpstep, inwptr, Cg, CONV_WINO_IBLOCK, CONV_WINO_ATOM_F32);
opt_AVX2::winofunc_BtXB_8x8_F32(inptr, inpstep, (float *)inwptr, Cg, CONV_WINO_IBLOCK, CONV_WINO_ATOM);
else
#endif
#if CV_TRY_AVX
if (conv->useAVX)
opt_AVX::winofunc_BtXB_8x8_f32(inptr, inpstep, inwptr, Cg, CONV_WINO_IBLOCK, CONV_WINO_ATOM_F32);
opt_AVX::winofunc_BtXB_8x8_F32(inptr, inpstep, (float *)inwptr, Cg, CONV_WINO_IBLOCK, CONV_WINO_ATOM);
else
#endif
#if CV_NEON && CV_NEON_AARCH64
if (conv->useNEON)
opt_NEON::winofunc_BtXB_8x8_f32(inptr, inpstep, inwptr, Cg, CONV_WINO_IBLOCK, CONV_WINO_ATOM_F32);
{
#ifdef CONV_ARM_FP16
if (useFP16)
{
opt_NEON_FP16::winofunc_BtXB_8x8_F16(inptr, inpstep, inwptr, Cg, CONV_WINO_IBLOCK,
CONV_WINO_ATOM);
}
else
#endif
opt_NEON::winofunc_BtXB_8x8_F32(inptr, inpstep, (float *)inwptr, Cg, CONV_WINO_IBLOCK,
CONV_WINO_ATOM);
}
else
#endif
winofunc_BtXB_8x8_f32(inptr, inpstep, inwptr, Cg, CONV_WINO_IBLOCK, CONV_WINO_ATOM_F32);
winofunc_BtXB_8x8_F32(inptr, inpstep, (float *)inwptr, Cg, CONV_WINO_IBLOCK, CONV_WINO_ATOM);
}
else
{
for (int i = 0; i < CONV_WINO_NATOMS_F32; i++, inwptr += CONV_WINO_IBLOCK*CONV_WINO_ATOM_F32)
memset(inwptr, 0, CONV_WINO_ATOM_F32*sizeof(inwptr[0]));
for (int i = 0; i < CONV_WINO_NATOMS; i++, inwptr += CONV_WINO_IBLOCK * CONV_WINO_ATOM * esz)
memset(inwptr, 0, CONV_WINO_ATOM * esz);
}
}
}
@@ -182,19 +218,37 @@ int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _outpu
// Phase 2. compute elemwise-weighted sums of transformed blocks,
// apply inverse Winograd transforms to the sums,
// add bias, apply activation function if any and store the results.
char* wptr0 = nullptr;
#ifdef CONV_ARM_FP16
if (useFP16)
{
CV_Assert(!conv->weightsWinoBuf_FP16.empty());
wptr0 = (char *)conv->getWeightsWinoFP16();
}
else
#endif
{
CV_Assert(!conv->weightsWinoBuf.empty());
wptr0 = (char *)conv->getWeightsWino();
}
parallel_for_(Range(0, ntasks), [&](const Range& r0) {
for (int task_id = r0.start; task_id < r0.end; task_id++)
{
size_t out_wbuf_size = CONV_WINO_AREA*CONV_WINO_KBLOCK*CONV_WINO_IBLOCK;
size_t out_wbuf_size = CONV_WINO_AREA * CONV_WINO_KBLOCK * CONV_WINO_IBLOCK;
size_t outbuf_size = CONV_WINO_AREA;
AutoBuffer<float> out_wbuf_, outbuf_;
out_wbuf_.allocate(out_wbuf_size + VEC_ALIGN);
float* out_wbuf = alignPtr(out_wbuf_.data(), VEC_ALIGN);
// For saving the accumulation output.
AutoBuffer<char> out_wbuf_;
out_wbuf_.allocate((out_wbuf_size + VEC_ALIGN) * esz);
char* out_wbuf = alignPtr(out_wbuf_.data(), VEC_ALIGN * esz);
memset(out_wbuf, 0, out_wbuf_size * esz);
// For saving the fuse_Add data.
AutoBuffer<float> outbuf_;
outbuf_.allocate(outbuf_size + VEC_ALIGN);
float* outbuf = alignPtr(outbuf_.data(), VEC_ALIGN);
memset(out_wbuf, 0, out_wbuf_size * sizeof(float));
memset(outbuf, 0, outbuf_size * sizeof(float));
memset(outbuf, 0, outbuf_size * sizeof(outbuf[0]));
int ngk0 = (int)(((int64_t)N*Kg_nblocks*ngroups)*task_id/ntasks);
int ngk1 = (int)(((int64_t)N*Kg_nblocks*ngroups)*(task_id+1)/ntasks);
@@ -214,30 +268,40 @@ int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _outpu
size_t inwofs = ((n*ngroups + g)*blocks_per_plane_aligned + block_id0)*Cg*CONV_WINO_AREA;
size_t wofs = (g*Kg_nblocks*CONV_WINO_KBLOCK + k0)*Cg*CONV_WINO_AREA;
float* inwptr = wbuf_all + inwofs;
const float* wptr = conv->weightsWinoBufPtr + wofs;
char* inwptr = wbuf_all + inwofs * esz;
char* wptr = wptr0 + wofs * esz;
#if CV_TRY_AVX2
if (conv->useAVX2)
opt_AVX2::winofunc_accum_f32(inwptr, wptr, out_wbuf, Cg, block_id1 - block_id0, CONV_WINO_IBLOCK,
CONV_WINO_KBLOCK, CONV_WINO_ATOM_F32, CONV_WINO_NATOMS_F32);
opt_AVX2::winofunc_accum_F32((float *)inwptr, (float *)wptr, (float *)out_wbuf, Cg, block_id1 - block_id0, CONV_WINO_IBLOCK,
CONV_WINO_KBLOCK, CONV_WINO_ATOM, CONV_WINO_NATOMS);
else
#endif
#if CV_TRY_AVX
if (conv->useAVX)
opt_AVX::winofunc_accum_f32(inwptr, wptr, out_wbuf, Cg, block_id1 - block_id0, CONV_WINO_IBLOCK,
CONV_WINO_KBLOCK, CONV_WINO_ATOM_F32, CONV_WINO_NATOMS_F32);
opt_AVX::winofunc_accum_F32((float *)inwptr, (float *)wptr, (float *)out_wbuf, Cg, block_id1 - block_id0, CONV_WINO_IBLOCK,
CONV_WINO_KBLOCK, CONV_WINO_ATOM, CONV_WINO_NATOMS);
else
#endif
#if CV_NEON && CV_NEON_AARCH64
if (conv->useNEON)
opt_NEON::winofunc_accum_f32(inwptr, wptr, out_wbuf, Cg, block_id1 - block_id0, CONV_WINO_IBLOCK,
CONV_WINO_KBLOCK, CONV_WINO_ATOM_F32, CONV_WINO_NATOMS_F32);
{
#ifdef CONV_ARM_FP16
if (useFP16)
{
opt_NEON_FP16::winofunc_accum_F16(inwptr, wptr, out_wbuf, Cg, block_id1 - block_id0, CONV_WINO_IBLOCK,
CONV_WINO_KBLOCK, CONV_WINO_ATOM, CONV_WINO_NATOMS);
}
else
#endif
opt_NEON::winofunc_accum_F32((float *)inwptr, (float *)wptr, (float *)out_wbuf, Cg, block_id1 - block_id0, CONV_WINO_IBLOCK,
CONV_WINO_KBLOCK, CONV_WINO_ATOM, CONV_WINO_NATOMS);
}
else
#endif
winofunc_accum_F32((float *)inwptr, (float *)wptr, (float *)out_wbuf, Cg, block_id1 - block_id0, CONV_WINO_IBLOCK,
CONV_WINO_KBLOCK, CONV_WINO_ATOM, CONV_WINO_NATOMS);
winofunc_accum_f32(inwptr, wptr, out_wbuf, Cg, block_id1 - block_id0, CONV_WINO_IBLOCK,
CONV_WINO_KBLOCK, CONV_WINO_ATOM_F32, CONV_WINO_NATOMS_F32);
for (int k = k0; k < k1; k++)
{
float biasv = conv->biasBuf[g*Kg + k];
@@ -274,31 +338,42 @@ int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _outpu
}
#if CV_TRY_AVX2
if (conv->useAVX2)
opt_AVX::winofunc_AtXA_8x8_f32(out_wbuf + ((k - k0)*CONV_WINO_IBLOCK + (block_id - block_id0))*CONV_WINO_AREA, CONV_WINO_SIZE,
opt_AVX::winofunc_AtXA_8x8_F32((float *)out_wbuf + ((k - k0)*CONV_WINO_IBLOCK + (block_id - block_id0))*CONV_WINO_AREA, CONV_WINO_SIZE,
bpptr, outstep, outptr, outstep, biasv, minval, maxval, ifMinMaxAct);
else
#endif
#if CV_TRY_AVX
if (conv->useAVX)
opt_AVX::winofunc_AtXA_8x8_f32(out_wbuf + ((k - k0)*CONV_WINO_IBLOCK + (block_id - block_id0))*CONV_WINO_AREA, CONV_WINO_SIZE,
opt_AVX::winofunc_AtXA_8x8_F32((float *)out_wbuf + ((k - k0)*CONV_WINO_IBLOCK + (block_id - block_id0))*CONV_WINO_AREA, CONV_WINO_SIZE,
bpptr, outstep, outptr, outstep, biasv, minval, maxval, ifMinMaxAct);
else
#endif
#if CV_NEON && CV_NEON_AARCH64
// NEON optimization is only for ARMv8 device, and for ARMv7 device, we use the Universal intrinsics.
if (conv->useNEON)
// NEON optimization is only for ARMv8 device, and for ARMv7 device, we use the Universal intrinsics.
opt_NEON::winofunc_AtXA_8x8_f32(out_wbuf + ((k - k0)*CONV_WINO_IBLOCK + (block_id - block_id0))*CONV_WINO_AREA, CONV_WINO_SIZE,
{
#ifdef CONV_ARM_FP16
if (useFP16)
{
opt_NEON_FP16::winofunc_AtXA_8x8_F16(out_wbuf + ((k - k0)*CONV_WINO_IBLOCK + (block_id - block_id0))*CONV_WINO_AREA * esz, CONV_WINO_SIZE,
bpptr, outstep, outptr, outstep, biasv, minval, maxval, ifMinMaxAct);
}
else
#endif
opt_NEON::winofunc_AtXA_8x8_F32((float *)out_wbuf + ((k - k0)*CONV_WINO_IBLOCK + (block_id - block_id0))*CONV_WINO_AREA, CONV_WINO_SIZE,
bpptr, outstep, outptr, outstep, biasv, minval, maxval, ifMinMaxAct);
}
else
#endif
winofunc_AtXA_8x8_f32(out_wbuf + ((k - k0)*CONV_WINO_IBLOCK + (block_id - block_id0))*CONV_WINO_AREA, CONV_WINO_SIZE,
winofunc_AtXA_8x8_F32((float *)out_wbuf + ((k - k0)*CONV_WINO_IBLOCK + (block_id - block_id0))*CONV_WINO_AREA, CONV_WINO_SIZE,
bpptr, outstep, outptr, outstep, biasv, minval, maxval, ifMinMaxAct);
if (partial)
{
if (activ)
activ->forwardSlice(outptr, outptr, CONV_WINO_SIZE*CONV_WINO_STEP, 0, g*Kg + k, g*Kg + k + 1);
for (int y = 0; y < dy1; y++)
memcpy(outptr0 + y*W0, outptr + y*CONV_WINO_SIZE,dx1*sizeof(outptr0[0]));
memcpy(outptr0 + y*W0, outptr + y*CONV_WINO_SIZE, dx1*sizeof(outptr0[0]));
}
}
}
@@ -314,7 +389,7 @@ int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _outpu
#if CV_SIMD128
void winofunc_accum_f32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
void winofunc_accum_F32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
const int winoIblock, const int winoKblock, const int winoAtomF32, const int winoNatomF32)
{
#if 1
@@ -411,7 +486,7 @@ void winofunc_accum_f32(const float* inwptr, const float* wptr, float* outbuf, i
}
/*Input transform*/
void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
void winofunc_BtXB_8x8_F32(const float* inptr, int inpstep,
float* outptr, int Cg, const int winoIblock, const int winoAtomF32)
{
CV_Assert(winoIblock == 3 && winoAtomF32 == 4);
@@ -585,7 +660,7 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
the Winograd-transformed weights should also be transposed.
init_conv() (see OpConv.fx) takes care of that.
*/
void winofunc_AtXA_8x8_f32(const float* inptr, int inpstep,
void winofunc_AtXA_8x8_F32(const float* inptr, int inpstep,
float* bpptr, int bpstep, float* outptr, int outstep,
float bias, float minval, float maxval, bool ifMinMaxAct)
{

@@ -0,0 +1,476 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "../../precomp.hpp"
#include "convolution.hpp"
#include "opencv2/core/hal/intrin.hpp"
namespace cv {
namespace dnn {
// NEON code workaround.
namespace opt_NEON
{
#if CV_NEON && CV_NEON_AARCH64
/* Accumulate */
void winofunc_accum_F32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
const int winoIblock, const int winoKblock, const int winoAtomF32, const int winoNatomF32)
{
CV_Assert(winoIblock == 6 && winoKblock == 4 && winoAtomF32 == 4);
if (iblock > 3)
{
for (int atom_id = 0; atom_id < winoNatomF32; atom_id++,
outbuf += winoAtomF32)
{
float32x4_t s00 = vdupq_n_f32(0.f), s01 = s00, s02 = s00, s03 = s00, s04 = s00, s05 = s00;
float32x4_t s10 = vdupq_n_f32(0.f), s11 = s00, s12 = s00, s13 = s00, s14 = s00, s15 = s00;
float32x4_t s20 = vdupq_n_f32(0.f), s21 = s00, s22 = s00, s23 = s00, s24 = s00, s25 = s00;
float32x4_t s30 = vdupq_n_f32(0.f), s31 = s00, s32 = s00, s33 = s00, s34 = s00, s35 = s00;
for (int c = 0; c < Cg; c++, inwptr += winoIblock*winoAtomF32,
wptr += winoKblock*winoAtomF32) {
float32x4_t w0 = vld1q_f32(wptr), w1 = vld1q_f32(wptr + 4);
float32x4_t w2 = vld1q_f32(wptr + 8), w3 = vld1q_f32(wptr + 12);
float32x4_t x0, x1;
x0 = vld1q_f32(inwptr);
x1 = vld1q_f32(inwptr + 4);
s00 = vfmaq_f32(s00, w0, x0);
s01 = vfmaq_f32(s01, w0, x1);
s10 = vfmaq_f32(s10, w1, x0);
s11 = vfmaq_f32(s11, w1, x1);
s20 = vfmaq_f32(s20, w2, x0);
s21 = vfmaq_f32(s21, w2, x1);
s30 = vfmaq_f32(s30, w3, x0);
s31 = vfmaq_f32(s31, w3, x1);
x0 = vld1q_f32(inwptr + 8);
x1 = vld1q_f32(inwptr + 12);
s02 = vfmaq_f32(s02, w0, x0);
s03 = vfmaq_f32(s03, w0, x1);
s12 = vfmaq_f32(s12, w1, x0);
s13 = vfmaq_f32(s13, w1, x1);
s22 = vfmaq_f32(s22, w2, x0);
s23 = vfmaq_f32(s23, w2, x1);
s32 = vfmaq_f32(s32, w3, x0);
s33 = vfmaq_f32(s33, w3, x1);
x0 = vld1q_f32(inwptr + 16);
x1 = vld1q_f32(inwptr + 20);
s04 = vfmaq_f32(s04, w0, x0);
s05 = vfmaq_f32(s05, w0, x1);
s14 = vfmaq_f32(s14, w1, x0);
s15 = vfmaq_f32(s15, w1, x1);
s24 = vfmaq_f32(s24, w2, x0);
s25 = vfmaq_f32(s25, w2, x1);
s34 = vfmaq_f32(s34, w3, x0);
s35 = vfmaq_f32(s35, w3, x1);
}
vst1q_f32(outbuf, s00);
vst1q_f32(outbuf + 1*64, s01);
vst1q_f32(outbuf + 2*64, s02);
vst1q_f32(outbuf + 3*64, s03);
vst1q_f32(outbuf + 4*64, s04);
vst1q_f32(outbuf + 5*64, s05);
vst1q_f32(outbuf + 6*64, s10);
vst1q_f32(outbuf + 7*64, s11);
vst1q_f32(outbuf + 8*64, s12);
vst1q_f32(outbuf + 9*64, s13);
vst1q_f32(outbuf + 10*64, s14);
vst1q_f32(outbuf + 11*64, s15);
vst1q_f32(outbuf + 12*64, s20);
vst1q_f32(outbuf + 13*64, s21);
vst1q_f32(outbuf + 14*64, s22);
vst1q_f32(outbuf + 15*64, s23);
vst1q_f32(outbuf + 16*64, s24);
vst1q_f32(outbuf + 17*64, s25);
vst1q_f32(outbuf + 18*64, s30);
vst1q_f32(outbuf + 19*64, s31);
vst1q_f32(outbuf + 20*64, s32);
vst1q_f32(outbuf + 21*64, s33);
vst1q_f32(outbuf + 22*64, s34);
vst1q_f32(outbuf + 23*64, s35);
}
}
else
{
for (int atom_id = 0; atom_id < winoNatomF32; atom_id++,
outbuf += winoAtomF32)
{
float32x4_t s00 = vdupq_n_f32(0.f), s01 = s00, s02 = s00;
float32x4_t s10 = vdupq_n_f32(0.f), s11 = s00, s12 = s00;
float32x4_t s20 = vdupq_n_f32(0.f), s21 = s00, s22 = s00;
float32x4_t s30 = vdupq_n_f32(0.f), s31 = s00, s32 = s00;
for (int c = 0; c < Cg; c++, inwptr += winoIblock*winoAtomF32,
wptr += winoKblock*winoAtomF32) {
float32x4_t w0 = vld1q_f32(wptr), w1 = vld1q_f32(wptr + 4);
float32x4_t w2 = vld1q_f32(wptr + 8), w3 = vld1q_f32(wptr + 12);
float32x4_t x0, x1, x2;
x0 = vld1q_f32(inwptr);
x1 = vld1q_f32(inwptr + 4);
x2 = vld1q_f32(inwptr + 8);
s00 = vfmaq_f32(s00, w0, x0);
s01 = vfmaq_f32(s01, w0, x1);
s02 = vfmaq_f32(s02, w0, x2);
s10 = vfmaq_f32(s10, w1, x0);
s11 = vfmaq_f32(s11, w1, x1);
s12 = vfmaq_f32(s12, w1, x2);
s20 = vfmaq_f32(s20, w2, x0);
s21 = vfmaq_f32(s21, w2, x1);
s22 = vfmaq_f32(s22, w2, x2);
s30 = vfmaq_f32(s30, w3, x0);
s31 = vfmaq_f32(s31, w3, x1);
s32 = vfmaq_f32(s32, w3, x2);
}
vst1q_f32(outbuf, s00);
vst1q_f32(outbuf + 1*64, s01);
vst1q_f32(outbuf + 2*64, s02);
vst1q_f32(outbuf + 6*64, s10);
vst1q_f32(outbuf + 7*64, s11);
vst1q_f32(outbuf + 8*64, s12);
vst1q_f32(outbuf + 12*64, s20);
vst1q_f32(outbuf + 13*64, s21);
vst1q_f32(outbuf + 14*64, s22);
vst1q_f32(outbuf + 18*64, s30);
vst1q_f32(outbuf + 19*64, s31);
vst1q_f32(outbuf + 20*64, s32);
}
}
}
#undef T4x4
#define T4x4(a, b, c, d, tr0, tr1) \
tr0 = vtrnq_f32(a, b); \
tr1 = vtrnq_f32(c, d); \
a = vcombine_f32(vget_low_f32(tr0.val[0]), vget_low_f32(tr1.val[0])); \
b = vcombine_f32(vget_low_f32(tr0.val[1]), vget_low_f32(tr1.val[1])); \
c = vcombine_f32(vget_high_f32(tr0.val[0]), vget_high_f32(tr1.val[0])); \
d = vcombine_f32(vget_high_f32(tr0.val[1]), vget_high_f32(tr1.val[1]))
/*Input transform*/
void winofunc_BtXB_8x8_F32(const float* inptr, int inpstep,
float* outptr, int Cg, const int winoIblock, const int winoAtomF32)
{
float32x4_t x00 = vld1q_f32(inptr), x01 = vld1q_f32(inptr + 4);
float32x4_t x10 = vld1q_f32(inptr + inpstep), x11 = vld1q_f32(inptr + inpstep + 4);
float32x4_t x20 = vld1q_f32(inptr + inpstep*2), x21 = vld1q_f32(inptr + inpstep*2 + 4);
float32x4_t x30 = vld1q_f32(inptr + inpstep*3), x31 = vld1q_f32(inptr + inpstep*3 + 4);
float32x4_t x40 = vld1q_f32(inptr + inpstep*4), x41 = vld1q_f32(inptr + inpstep*4 + 4);
float32x4_t x50 = vld1q_f32(inptr + inpstep*5), x51 = vld1q_f32(inptr + inpstep*5 + 4);
float32x4_t x60 = vld1q_f32(inptr + inpstep*6), x61 = vld1q_f32(inptr + inpstep*6 + 4);
float32x4_t x70 = vld1q_f32(inptr + inpstep*7), x71 = vld1q_f32(inptr + inpstep*7 + 4);
float32x4_t z00, z01, z10, z11, z20, z21, z30, z31, z40, z41, z50, z51, z60, z61, z70, z71;
{
/* Y[0] = [1.f, 0.f, -5.25f, 0.f, 5.25f, 0.f, -1.f, 0.f]*X */
/* Y[7] = [0.f, -1.f, 0.f, 5.25f, 0.f, -5.25f, 0.f, 1.f]*X */
float32x4_t q5_25 = vdupq_n_f32(5.25f), t00, t01, t10, t11;
t00 = vsubq_f32(x40, x20);
t01 = vsubq_f32(x41, x21);
t10 = vsubq_f32(x30, x50);
t11 = vsubq_f32(x31, x51);
float32x4_t y00 = vfmaq_f32(vsubq_f32(x00, x60), t00, q5_25);
float32x4_t y01 = vfmaq_f32(vsubq_f32(x01, x61), t01, q5_25);
float32x4_t y70 = vfmaq_f32(vsubq_f32(x70, x10), t10, q5_25);
float32x4_t y71 = vfmaq_f32(vsubq_f32(x71, x11), t11, q5_25);
/* Y[1] = [0.f, 1.f, 1.f, -4.25f, -4.25f, 1.f, 1.f, 0.f]*X */
/* Y[2] = [0.f, -1.f, 1.f, 4.25f, -4.25f, -1.f, 1.f, 0.f]*X */
float32x4_t qm4_25 = vdupq_n_f32(-4.25f);
t00 = vfmaq_f32(vaddq_f32(x10, x50), x30, qm4_25);
t01 = vfmaq_f32(vaddq_f32(x11, x51), x31, qm4_25);
t10 = vfmaq_f32(vaddq_f32(x20, x60), x40, qm4_25);
t11 = vfmaq_f32(vaddq_f32(x21, x61), x41, qm4_25);
float32x4_t y10 = vaddq_f32(t00, t10), y11 = vaddq_f32(t01, t11);
float32x4_t y20 = vsubq_f32(t10, t00), y21 = vsubq_f32(t11, t01);
/* Y[3] = [0.f, 0.5f, 0.25f, -2.5f, -1.25f, 2.f, 1.f, 0.f]*X */
/* Y[4] = [0.f, -0.5f, 0.25f, 2.5f, -1.25f, -2.f, 1.f, 0.f]*X */
float32x4_t q0_5 = vdupq_n_f32(0.5f), q0_25 = vdupq_n_f32(0.25f);
float32x4_t qm2_5 = vdupq_n_f32(-2.5f), qm1_25 = vdupq_n_f32(-1.25f);
t00 = vfmaq_f32(vaddq_f32(x50, x50), x10, q0_5);
t01 = vfmaq_f32(vaddq_f32(x51, x51), x11, q0_5);
t10 = vfmaq_f32(x60, x20, q0_25);
t11 = vfmaq_f32(x61, x21, q0_25);
t00 = vfmaq_f32(t00, x30, qm2_5);
t01 = vfmaq_f32(t01, x31, qm2_5);
t10 = vfmaq_f32(t10, x40, qm1_25);
t11 = vfmaq_f32(t11, x41, qm1_25);
float32x4_t y30 = vaddq_f32(t00, t10), y31 = vaddq_f32(t01, t11);
float32x4_t y40 = vsubq_f32(t10, t00), y41 = vsubq_f32(t11, t01);
/* Y[5] = [0.f, 2.f, 4.f, -2.5f, -5.f, 0.5f, 1.f, 0.f]*X */
/* Y[6] = [0.f, -2.f, 4.f, 2.5f, -5.f, -0.5f, 1.f, 0.f]*X */
float32x4_t q4 = vdupq_n_f32(4.f), qm5 = vdupq_n_f32(-5.f);
t00 = vfmaq_f32(vaddq_f32(x10, x10), x50, q0_5);
t01 = vfmaq_f32(vaddq_f32(x11, x11), x51, q0_5);
t10 = vfmaq_f32(x60, x20, q4);
t11 = vfmaq_f32(x61, x21, q4);
t00 = vfmaq_f32(t00, x30, qm2_5);
t01 = vfmaq_f32(t01, x31, qm2_5);
t10 = vfmaq_f32(t10, x40, qm5);
t11 = vfmaq_f32(t11, x41, qm5);
float32x4_t y50 = vaddq_f32(t00, t10), y51 = vaddq_f32(t01, t11);
float32x4_t y60 = vsubq_f32(t10, t00), y61 = vsubq_f32(t11, t01);
/* transpose 8x8 matrix in-place with some renumeration of the elements: */
/* Y: */
/* y00 y01 */
/* y10 y11 */
/* ... */
/* y70 y71 */
/* Y': */
/* y00 y40 */
/* y10 y50 */
/* y20 y60 */
/* y30 y70 */
/* y01 y41 */
/* y11 y51 */
/* y21 y61 */
/* y31 y71 */
/* in other words, y40 <-> y01, y50 <-> y11, y60 <-> y21, y70 <-> y31 */
float32x4x2_t tr0, tr1;
T4x4(y00, y10, y20, y30, tr0, tr1);
T4x4(y01, y11, y21, y31, tr0, tr1);
T4x4(y40, y50, y60, y70, tr0, tr1);
T4x4(y41, y51, y61, y71, tr0, tr1);
/* Z[0] = [1.f, 0.f, -5.25f, 0.f, 5.25f, 0.f, -1.f, 0.f]*Y */
/* Z[7] = [0.f, -1.f, 0.f, 5.25f, 0.f, -5.25f, 0.f, 1.f]*Y */
t00 = vsubq_f32(y01, y20);
t01 = vsubq_f32(y41, y60);
t10 = vsubq_f32(y30, y11);
t11 = vsubq_f32(y70, y51);
z00 = vfmaq_f32(vsubq_f32(y00, y21), t00, q5_25);
z01 = vfmaq_f32(vsubq_f32(y40, y61), t01, q5_25);
z70 = vfmaq_f32(vsubq_f32(y31, y10), t10, q5_25);
z71 = vfmaq_f32(vsubq_f32(y71, y50), t11, q5_25);
/* Z[1] = [0.f, 1.f, 1.f, -4.25f, -4.25f, 1.f, 1.f, 0.f]*Y */
/* Z[2] = [0.f, -1.f, 1.f, 4.25f, -4.25f, -1.f, 1.f, 0.f]*Y */
t00 = vfmaq_f32(vaddq_f32(y10, y11), y30, qm4_25);
t01 = vfmaq_f32(vaddq_f32(y50, y51), y70, qm4_25);
t10 = vfmaq_f32(vaddq_f32(y20, y21), y01, qm4_25);
t11 = vfmaq_f32(vaddq_f32(y60, y61), y41, qm4_25);
z10 = vaddq_f32(t00, t10); z11 = vaddq_f32(t01, t11);
z20 = vsubq_f32(t10, t00); z21 = vsubq_f32(t11, t01);
/* Z[3] = [0.f, 0.5f, 0.25f, -2.5f, -1.25f, 2.f, 1.f, 0.f]*Y */
/* Z[4] = [0.f, -0.5f, 0.25f, 2.5f, -1.25f, -2.f, 1.f, 0.f]*Y */
t00 = vfmaq_f32(vaddq_f32(y11, y11), y10, q0_5);
t01 = vfmaq_f32(vaddq_f32(y51, y51), y50, q0_5);
t10 = vfmaq_f32(y21, y20, q0_25);
t11 = vfmaq_f32(y61, y60, q0_25);
t00 = vfmaq_f32(t00, y30, qm2_5);
t01 = vfmaq_f32(t01, y70, qm2_5);
t10 = vfmaq_f32(t10, y01, qm1_25);
t11 = vfmaq_f32(t11, y41, qm1_25);
z30 = vaddq_f32(t00, t10); z31 = vaddq_f32(t01, t11);
z40 = vsubq_f32(t10, t00); z41 = vsubq_f32(t11, t01);
/* Z[5] = [0.f, 2.f, 4.f, -2.5f, -5.f, 0.5f, 1.f, 0.f]*Y */
/* Z[6] = [0.f, -2.f, 4.f, 2.5f, -5.f, -0.5f, 1.f, 0.f]*Y */
t00 = vfmaq_f32(vaddq_f32(y10, y10), y11, q0_5);
t01 = vfmaq_f32(vaddq_f32(y50, y50), y51, q0_5);
t10 = vfmaq_f32(y21, y20, q4);
t11 = vfmaq_f32(y61, y60, q4);
t00 = vfmaq_f32(t00, y30, qm2_5);
t01 = vfmaq_f32(t01, y70, qm2_5);
t10 = vfmaq_f32(t10, y01, qm5);
t11 = vfmaq_f32(t11, y41, qm5);
z50 = vaddq_f32(t00, t10); z51 = vaddq_f32(t01, t11);
z60 = vsubq_f32(t10, t00); z61 = vsubq_f32(t11, t01);
}
const int outstep = winoIblock*winoAtomF32*Cg;
vst1q_f32(outptr, z00);
vst1q_f32(outptr + outstep, z01);
vst1q_f32(outptr + outstep*2, z10);
vst1q_f32(outptr + outstep*3, z11);
vst1q_f32(outptr + outstep*4, z20);
vst1q_f32(outptr + outstep*5, z21);
vst1q_f32(outptr + outstep*6, z30);
vst1q_f32(outptr + outstep*7, z31);
vst1q_f32(outptr + outstep*8, z40);
vst1q_f32(outptr + outstep*9, z41);
vst1q_f32(outptr + outstep*10, z50);
vst1q_f32(outptr + outstep*11, z51);
vst1q_f32(outptr + outstep*12, z60);
vst1q_f32(outptr + outstep*13, z61);
vst1q_f32(outptr + outstep*14, z70);
vst1q_f32(outptr + outstep*15, z71);
}
/*Output transform*/
void winofunc_AtXA_8x8_F32(const float* inptr, int inpstep,
float* bpptr, int bpstep, float* outptr, int outstep,
float bias, float minval, float maxval, bool ifMinMaxAct)
{
float32x4_t x00 = vld1q_f32(inptr), x01 = vld1q_f32(inptr + 4);
float32x4_t x10 = vld1q_f32(inptr + inpstep), x11 = vld1q_f32(inptr + inpstep + 4);
float32x4_t x20 = vld1q_f32(inptr + inpstep*2), x21 = vld1q_f32(inptr + inpstep*2 + 4);
float32x4_t x30 = vld1q_f32(inptr + inpstep*3), x31 = vld1q_f32(inptr + inpstep*3 + 4);
float32x4_t x40 = vld1q_f32(inptr + inpstep*4), x41 = vld1q_f32(inptr + inpstep*4 + 4);
float32x4_t x50 = vld1q_f32(inptr + inpstep*5), x51 = vld1q_f32(inptr + inpstep*5 + 4);
float32x4_t x60 = vld1q_f32(inptr + inpstep*6), x61 = vld1q_f32(inptr + inpstep*6 + 4);
float32x4_t x70 = vld1q_f32(inptr + inpstep*7), x71 = vld1q_f32(inptr + inpstep*7 + 4);
float32x4_t z00, z01, z10, z11, z20, z21, z30, z31, z40, z41, z50, z51;
{
float32x4_t s12_0, s12_1, s34_0, s34_1, s56_0, s56_1;
s12_0 = vaddq_f32(x10, x20); s12_1 = vaddq_f32(x11, x21);
s34_0 = vaddq_f32(x30, x40); s34_1 = vaddq_f32(x31, x41);
s56_0 = vaddq_f32(x50, x60); s56_1 = vaddq_f32(x51, x61);
float32x4_t y00 = vaddq_f32(vaddq_f32(vaddq_f32(x00, s12_0), s34_0), s56_0);
float32x4_t y01 = vaddq_f32(vaddq_f32(vaddq_f32(x01, s12_1), s34_1), s56_1);
float32x4_t y20 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 4.0f), s56_0, 0.25f);
float32x4_t y21 = vfmaq_n_f32(vfmaq_n_f32(s12_1, s34_1, 4.0f), s56_1, 0.25f);
float32x4_t y40 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 16.0f), s56_0, 1.f/16);
float32x4_t y41 = vfmaq_n_f32(vfmaq_n_f32(s12_1, s34_1, 16.0f), s56_1, 1.f/16);
s12_0 = vsubq_f32(x10, x20); s12_1 = vsubq_f32(x11, x21);
s34_0 = vsubq_f32(x30, x40); s34_1 = vsubq_f32(x31, x41);
s56_0 = vsubq_f32(x50, x60); s56_1 = vsubq_f32(x51, x61);
float32x4_t y50 = vfmaq_n_f32(vfmaq_n_f32(vaddq_f32(x70, s12_0),
s34_0, 32.f), s56_0, 1.f/32);
float32x4_t y51 = vfmaq_n_f32(vfmaq_n_f32(vaddq_f32(x71, s12_1),
s34_1, 32.f), s56_1, 1.f/32);
float32x4_t y10 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 2.0f), s56_0, 0.5f);
float32x4_t y11 = vfmaq_n_f32(vfmaq_n_f32(s12_1, s34_1, 2.0f), s56_1, 0.5f);
float32x4_t y30 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 8.0f), s56_0, 0.125f);
float32x4_t y31 = vfmaq_n_f32(vfmaq_n_f32(s12_1, s34_1, 8.0f), s56_1, 0.125f);
float32x4_t y60 = vdupq_n_f32(0.f), y61 = y60, y70 = y60, y71 = y60;
/* transpose 8x8 matrix in-place with some renumeration of the elements: */
/* Y: */
/* y00 y01 */
/* y10 y11 */
/* ... */
/* y50 y51 */
/* 0 0 */
/* 0 0 */
/* Y': */
/* y00 y40 */
/* y10 y50 */
/* y20 y60 */
/* y30 y70 */
/* y01 y41 */
/* y11 y51 */
/* y21 y61 */
/* y31 y71 */
/* in other words, y40 <-> y01, y50 <-> y11, y60 <-> y21, y70 <-> y31 */
float32x4x2_t tr0, tr1;
T4x4(y00, y10, y20, y30, tr0, tr1);
T4x4(y01, y11, y21, y31, tr0, tr1);
T4x4(y40, y50, y60, y70, tr0, tr1);
T4x4(y41, y51, y61, y71, tr0, tr1);
s12_0 = vaddq_f32(y10, y20); s12_1 = vaddq_f32(y50, y60);
s34_0 = vaddq_f32(y30, y01); s34_1 = vaddq_f32(y70, y41);
s56_0 = vaddq_f32(y11, y21); s56_1 = vaddq_f32(y51, y61);
z00 = vaddq_f32(vaddq_f32(vaddq_f32(y00, s12_0), s34_0), s56_0);
z01 = vaddq_f32(vaddq_f32(vaddq_f32(y40, s12_1), s34_1), s56_1);
z20 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 4.0f), s56_0, 0.25f);
z21 = vfmaq_n_f32(vfmaq_n_f32(s12_1, s34_1, 4.0f), s56_1, 0.25f);
z40 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 16.0f), s56_0, 1.f/16);
z41 = vfmaq_n_f32(vfmaq_n_f32(s12_1, s34_1, 16.0f), s56_1, 1.f/16);
s12_0 = vsubq_f32(y10, y20); s12_1 = vsubq_f32(y50, y60);
s34_0 = vsubq_f32(y30, y01); s34_1 = vsubq_f32(y70, y41);
s56_0 = vsubq_f32(y11, y21); s56_1 = vsubq_f32(y51, y61);
z50 = vfmaq_n_f32(vfmaq_n_f32(vaddq_f32(y31, s12_0),
s34_0, 32.f), s56_0, 1.f/32);
z51 = vfmaq_n_f32(vfmaq_n_f32(vaddq_f32(y71, s12_1),
s34_1, 32.f), s56_1, 1.f/32);
z10 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 2.0f), s56_0, 0.5f);
z11 = vfmaq_n_f32(vfmaq_n_f32(s12_1, s34_1, 2.0f), s56_1, 0.5f);
z30 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 8.0f), s56_0, 0.125f);
z31 = vfmaq_n_f32(vfmaq_n_f32(s12_1, s34_1, 8.0f), s56_1, 0.125f);
float32x4_t vbias = vdupq_n_f32(bias);
z00 = vaddq_f32(z00, vbias);
z01 = vaddq_f32(z01, vbias);
z10 = vaddq_f32(z10, vbias);
z11 = vaddq_f32(z11, vbias);
z20 = vaddq_f32(z20, vbias);
z21 = vaddq_f32(z21, vbias);
z30 = vaddq_f32(z30, vbias);
z31 = vaddq_f32(z31, vbias);
z40 = vaddq_f32(z40, vbias);
z41 = vaddq_f32(z41, vbias);
z50 = vaddq_f32(z50, vbias);
z51 = vaddq_f32(z51, vbias);
}
if (bpptr)
{
float32x2_t zhalf = vdup_n_f32(0.f);
z00 = vaddq_f32(z00, vld1q_f32(bpptr));
z01 = vaddq_f32(z01, vcombine_f32(vld1_f32(bpptr + 4), zhalf));
z10 = vaddq_f32(z10, vld1q_f32(bpptr + bpstep));
z11 = vaddq_f32(z11, vcombine_f32(vld1_f32(bpptr + bpstep + 4), zhalf));
z20 = vaddq_f32(z20, vld1q_f32(bpptr + bpstep*2));
z21 = vaddq_f32(z21, vcombine_f32(vld1_f32(bpptr + bpstep*2 + 4), zhalf));
z30 = vaddq_f32(z30, vld1q_f32(bpptr + bpstep*3));
z31 = vaddq_f32(z31, vcombine_f32(vld1_f32(bpptr + bpstep*3 + 4), zhalf));
z40 = vaddq_f32(z40, vld1q_f32(bpptr + bpstep*4));
z41 = vaddq_f32(z41, vcombine_f32(vld1_f32(bpptr + bpstep*4 + 4), zhalf));
z50 = vaddq_f32(z50, vld1q_f32(bpptr + bpstep*5));
z51 = vaddq_f32(z51, vcombine_f32(vld1_f32(bpptr + bpstep*5 + 4), zhalf));
}
if (ifMinMaxAct)
{
float32x4_t vmax = vdupq_n_f32(maxval);
float32x4_t vmin = vdupq_n_f32(minval);
z00 = vminq_f32(vmaxq_f32(z00, vmin), vmax);
z01 = vminq_f32(vmaxq_f32(z01, vmin), vmax);
z10 = vminq_f32(vmaxq_f32(z10, vmin), vmax);
z11 = vminq_f32(vmaxq_f32(z11, vmin), vmax);
z20 = vminq_f32(vmaxq_f32(z20, vmin), vmax);
z21 = vminq_f32(vmaxq_f32(z21, vmin), vmax);
z30 = vminq_f32(vmaxq_f32(z30, vmin), vmax);
z31 = vminq_f32(vmaxq_f32(z31, vmin), vmax);
z40 = vminq_f32(vmaxq_f32(z40, vmin), vmax);
z41 = vminq_f32(vmaxq_f32(z41, vmin), vmax);
z50 = vminq_f32(vmaxq_f32(z50, vmin), vmax);
z51 = vminq_f32(vmaxq_f32(z51, vmin), vmax);
}
vst1q_f32(outptr, z00);
vst1_f32(outptr + 4, vget_low_f32(z01));
vst1q_f32(outptr + outstep, z10);
vst1_f32(outptr + outstep + 4, vget_low_f32(z11));
vst1q_f32(outptr + outstep*2, z20);
vst1_f32(outptr + outstep*2 + 4, vget_low_f32(z21));
vst1q_f32(outptr + outstep*3, z30);
vst1_f32(outptr + outstep*3 + 4, vget_low_f32(z31));
vst1q_f32(outptr + outstep*4, z40);
vst1_f32(outptr + outstep*4 + 4, vget_low_f32(z41));
vst1q_f32(outptr + outstep*5, z50);
vst1_f32(outptr + outstep*5 + 4, vget_low_f32(z51));
}
#endif
}
}} // namespace
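(For reference, the bracketed row vectors in the comments above assemble into the F(6x6, 3x3) input-transform matrix: each Y[i] is one row of B^T applied to the 8x8 tile X, and the second pass after the transpose yields Z = B^T X B. Assembled from the comments:)

B^T = \begin{pmatrix}
 1 &  0    & -5.25 &  0    &  5.25 &  0    & -1 & 0 \\
 0 &  1    &  1    & -4.25 & -4.25 &  1    &  1 & 0 \\
 0 & -1    &  1    &  4.25 & -4.25 & -1    &  1 & 0 \\
 0 &  0.5  &  0.25 & -2.5  & -1.25 &  2    &  1 & 0 \\
 0 & -0.5  &  0.25 &  2.5  & -1.25 & -2    &  1 & 0 \\
 0 &  2    &  4    & -2.5  & -5    &  0.5  &  1 & 0 \\
 0 & -2    &  4    &  2.5  & -5    & -0.5  &  1 & 0 \\
 0 & -1    &  0    &  5.25 &  0    & -5.25 &  0 & 1
\end{pmatrix}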

@@ -9,26 +9,37 @@ namespace dnn {
CV_CPU_OPTIMIZATION_NAMESPACE_BEGIN
/* Accumulate */
void winofunc_accum_f32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
void winofunc_accum_F32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
const int winoIblock, const int winoKblock, const int winoAtomF32, const int winoNatomF32);
/*Input transform*/
void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
void winofunc_BtXB_8x8_F32(const float* inptr, int inpstep,
float* outptr, int Cg, const int winoIblock, const int winoAtomF32);
/*Output transform*/
void winofunc_AtXA_8x8_f32(const float* inptr, int inpstep,
void winofunc_AtXA_8x8_F32(const float* inptr, int inpstep,
float* bpptr, int bpstep, float* outptr, int outstep,
float bias, float minval, float maxval, bool ifMinMaxAct);
#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_AVX
// FP16 branch; only ARMv8 supports it.
void winofunc_accum_F16(const char* _inwptr, const char* _wptr, char* _outbuf, int Cg, int iblock,
const int winoIblock, const int winoKblock, const int winoAtomF16, const int winoNatomF16);
void winofunc_BtXB_8x8_F16(const float * inptr, int inpstep,
char * _outptr, int Cg, const int winoIblock, const int winoAtomF16);
void winofunc_AtXA_8x8_F16(const char* inptr, int inpstep,
float * bpptr, int bpstep, float* outptr, int outstep,
float bias, float minval, float maxval, bool ifMinMaxAct);
#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY)
#if CV_AVX
#if !CV_FMA3 // AVX workaround
#undef _mm256_fmadd_ps
#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(c, _mm256_mul_ps(a, b))
#endif
void winofunc_accum_f32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
void winofunc_accum_F32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
const int winoIblock, const int winoKblock, const int winoAtomF32, const int winoNatomF32)
{
CV_Assert(winoIblock == 6 && winoKblock == 4 && winoAtomF32 == 8);
@@ -187,7 +198,7 @@ void transpose8_ps(__m256 &row0, __m256 &row1, __m256 &row2, __m256 &row3, __m25
}
/*Input transform*/
void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
void winofunc_BtXB_8x8_F32(const float* inptr, int inpstep,
float* outptr, int Cg, const int winoIblock, const int winoAtomF32)
{
__m256 x00 = _mm256_loadu_ps(inptr);
@@ -311,7 +322,7 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
0.f, 1.f, 1.f, 16.f, 16.f, 1.f/16, 1.f/16, 0.f,
0.f, 1.f, -1.f, 32.f, -32.f, 1.f/32, -1.f/32, 1.f]
*/
void winofunc_AtXA_8x8_f32(const float* inptr, int inpstep,
void winofunc_AtXA_8x8_F32(const float* inptr, int inpstep,
float* bpptr, int bpstep, float* outptr, int outstep,
float bias, float minval, float maxval, bool ifMinMaxAct)
{
@@ -405,166 +416,183 @@ void winofunc_AtXA_8x8_f32(const float* inptr, int inpstep,
STORE6_ELE_FROM_16(outptr + outstep * 5, z50, lowM, highM);
_mm256_zeroupper();
}
#endif // CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY
CV_CPU_OPTIMIZATION_NAMESPACE_END
#endif // CV_AVX
// NEON code workaround.
namespace opt_NEON
{
// FP16: currently only ARMv8 may support it
#if defined(CV_NEON_AARCH64) && CV_NEON_AARCH64 && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_NEON && CV_NEON_AARCH64
/* Accumulate */
void winofunc_accum_f32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
const int winoIblock, const int winoKblock, const int winoAtomF32, const int winoNatomF32);
#undef T4x4
#define T4x4(a, b, c, d, tr0, tr1) \
tr0 = vtrnq_f32(a, b); \
tr1 = vtrnq_f32(c, d); \
a = vcombine_f32(vget_low_f32(tr0.val[0]), vget_low_f32(tr1.val[0])); \
b = vcombine_f32(vget_low_f32(tr0.val[1]), vget_low_f32(tr1.val[1])); \
c = vcombine_f32(vget_high_f32(tr0.val[0]), vget_high_f32(tr1.val[0])); \
d = vcombine_f32(vget_high_f32(tr0.val[1]), vget_high_f32(tr1.val[1]))
/*Input transform*/
void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
float* outptr, int Cg, const int winoIblock, const int winoAtomF32);
/* Accumulate */
void winofunc_accum_F16(const char* _inwptr, const char* _wptr, char* _outbuf, int Cg, int iblock,
const int winoIblock, const int winoKblock, const int winoAtomF16, const int winoNatomF16)
{
typedef __fp16 float16_t;
const float16_t* inwptr = (const float16_t*)_inwptr;
const float16_t* wptr = (const float16_t*)_wptr;
float16_t* outbuf = (float16_t*)_outbuf;
/*Output transform*/
void winofunc_AtXA_8x8_f32(const float* inptr, int inpstep,
float* bpptr, int bpstep, float* outptr, int outstep,
float bias, float minval, float maxval, bool ifMinMaxAct);
CV_Assert(winoIblock == 6 && winoKblock == 4 && winoAtomF16 == 8);
void winofunc_accum_f32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
const int winoIblock, const int winoKblock, const int winoAtomF32, const int winoNatomF32)
{
CV_Assert(winoIblock == 6 && winoKblock == 4 && winoAtomF32 == 4);
if (iblock > 3)
{
for (int atom_id = 0; atom_id < winoNatomF32; atom_id++,
outbuf += winoAtomF32)
for (int atom_id = 0; atom_id < winoNatomF16; atom_id++, outbuf += winoAtomF16)
{
float32x4_t s00 = vdupq_n_f32(0.f), s01 = s00, s02 = s00, s03 = s00, s04 = s00, s05 = s00;
float32x4_t s10 = vdupq_n_f32(0.f), s11 = s00, s12 = s00, s13 = s00, s14 = s00, s15 = s00;
float32x4_t s20 = vdupq_n_f32(0.f), s21 = s00, s22 = s00, s23 = s00, s24 = s00, s25 = s00;
float32x4_t s30 = vdupq_n_f32(0.f), s31 = s00, s32 = s00, s33 = s00, s34 = s00, s35 = s00;
for (int c = 0; c < Cg; c++, inwptr += winoIblock*winoAtomF32,
wptr += winoKblock*winoAtomF32) {
float32x4_t w0 = vld1q_f32(wptr), w1 = vld1q_f32(wptr + 4);
float32x4_t w2 = vld1q_f32(wptr + 8), w3 = vld1q_f32(wptr + 12);
float32x4_t x0, x1;
x0 = vld1q_f32(inwptr);
x1 = vld1q_f32(inwptr + 4);
s00 = vfmaq_f32(s00, w0, x0);
s01 = vfmaq_f32(s01, w0, x1);
s10 = vfmaq_f32(s10, w1, x0);
s11 = vfmaq_f32(s11, w1, x1);
s20 = vfmaq_f32(s20, w2, x0);
s21 = vfmaq_f32(s21, w2, x1);
s30 = vfmaq_f32(s30, w3, x0);
s31 = vfmaq_f32(s31, w3, x1);
x0 = vld1q_f32(inwptr + 8);
x1 = vld1q_f32(inwptr + 12);
s02 = vfmaq_f32(s02, w0, x0);
s03 = vfmaq_f32(s03, w0, x1);
s12 = vfmaq_f32(s12, w1, x0);
s13 = vfmaq_f32(s13, w1, x1);
s22 = vfmaq_f32(s22, w2, x0);
s23 = vfmaq_f32(s23, w2, x1);
s32 = vfmaq_f32(s32, w3, x0);
s33 = vfmaq_f32(s33, w3, x1);
x0 = vld1q_f32(inwptr + 16);
x1 = vld1q_f32(inwptr + 20);
s04 = vfmaq_f32(s04, w0, x0);
s05 = vfmaq_f32(s05, w0, x1);
s14 = vfmaq_f32(s14, w1, x0);
s15 = vfmaq_f32(s15, w1, x1);
s24 = vfmaq_f32(s24, w2, x0);
s25 = vfmaq_f32(s25, w2, x1);
s34 = vfmaq_f32(s34, w3, x0);
s35 = vfmaq_f32(s35, w3, x1);
float16x8_t s00 = vdupq_n_f16(0.f), s01 = s00, s02 = s00, s03 = s00, s04 = s00, s05 = s00;
float16x8_t s10 = vdupq_n_f16(0.f), s11 = s00, s12 = s00, s13 = s00, s14 = s00, s15 = s00;
float16x8_t s20 = vdupq_n_f16(0.f), s21 = s00, s22 = s00, s23 = s00, s24 = s00, s25 = s00;
float16x8_t s30 = vdupq_n_f16(0.f), s31 = s00, s32 = s00, s33 = s00, s34 = s00, s35 = s00;
for (int c = 0; c < Cg; c++, inwptr += winoIblock*winoAtomF16,
wptr += winoKblock*winoAtomF16)
{
float16x8_t w0 = vld1q_f16(wptr), w1 = vld1q_f16(wptr + 8);
float16x8_t w2 = vld1q_f16(wptr + 16), w3 = vld1q_f16(wptr + 24);
float16x8_t x0, x1, x2;
x0 = vld1q_f16(inwptr);
x1 = vld1q_f16(inwptr + 8);
x2 = vld1q_f16(inwptr + 16);
s00 = vfmaq_f16(s00, w0, x0);
s01 = vfmaq_f16(s01, w0, x1);
s02 = vfmaq_f16(s02, w0, x2);
s10 = vfmaq_f16(s10, w1, x0);
s11 = vfmaq_f16(s11, w1, x1);
s12 = vfmaq_f16(s12, w1, x2);
s20 = vfmaq_f16(s20, w2, x0);
s21 = vfmaq_f16(s21, w2, x1);
s22 = vfmaq_f16(s22, w2, x2);
s30 = vfmaq_f16(s30, w3, x0);
s31 = vfmaq_f16(s31, w3, x1);
s32 = vfmaq_f16(s32, w3, x2);
x0 = vld1q_f16(inwptr + 24);
x1 = vld1q_f16(inwptr + 32);
x2 = vld1q_f16(inwptr + 40);
s03 = vfmaq_f16(s03, w0, x0);
s04 = vfmaq_f16(s04, w0, x1);
s05 = vfmaq_f16(s05, w0, x2);
s13 = vfmaq_f16(s13, w1, x0);
s14 = vfmaq_f16(s14, w1, x1);
s15 = vfmaq_f16(s15, w1, x2);
s23 = vfmaq_f16(s23, w2, x0);
s24 = vfmaq_f16(s24, w2, x1);
s25 = vfmaq_f16(s25, w2, x2);
s33 = vfmaq_f16(s33, w3, x0);
s34 = vfmaq_f16(s34, w3, x1);
s35 = vfmaq_f16(s35, w3, x2);
}
vst1q_f16(outbuf, s00);
vst1q_f16(outbuf + 1*64, s01);
vst1q_f16(outbuf + 2*64, s02);
vst1q_f16(outbuf + 3*64, s03);
vst1q_f16(outbuf + 4*64, s04);
vst1q_f16(outbuf + 5*64, s05);
vst1q_f16(outbuf + 6*64, s10);
vst1q_f16(outbuf + 7*64, s11);
vst1q_f16(outbuf + 8*64, s12);
vst1q_f16(outbuf + 9*64, s13);
vst1q_f16(outbuf + 10*64, s14);
vst1q_f16(outbuf + 11*64, s15);
vst1q_f16(outbuf + 12*64, s20);
vst1q_f16(outbuf + 13*64, s21);
vst1q_f16(outbuf + 14*64, s22);
vst1q_f16(outbuf + 15*64, s23);
vst1q_f16(outbuf + 16*64, s24);
vst1q_f16(outbuf + 17*64, s25);
vst1q_f16(outbuf + 18*64, s30);
vst1q_f16(outbuf + 19*64, s31);
vst1q_f16(outbuf + 20*64, s32);
vst1q_f16(outbuf + 21*64, s33);
vst1q_f16(outbuf + 22*64, s34);
vst1q_f16(outbuf + 23*64, s35);
}
}
else
{
for (int atom_id = 0; atom_id < winoNatomF16; atom_id++,
outbuf += winoAtomF16)
{
float16x8_t s00 = vdupq_n_f16(0.f), s01 = s00, s02 = s00;
float16x8_t s10 = vdupq_n_f16(0.f), s11 = s00, s12 = s00;
float16x8_t s20 = vdupq_n_f16(0.f), s21 = s00, s22 = s00;
float16x8_t s30 = vdupq_n_f16(0.f), s31 = s00, s32 = s00;
for (int c = 0; c < Cg; c++, inwptr += winoIblock*winoAtomF16,
wptr += winoKblock*winoAtomF16)
{
float16x8_t w0 = vld1q_f16(wptr), w1 = vld1q_f16(wptr + 8);
float16x8_t w2 = vld1q_f16(wptr + 16), w3 = vld1q_f16(wptr + 24);
float16x8_t x0, x1, x2;
x0 = vld1q_f16(inwptr);
x1 = vld1q_f16(inwptr + 8);
x2 = vld1q_f16(inwptr + 16);
s00 = vfmaq_f16(s00, w0, x0);
s01 = vfmaq_f16(s01, w0, x1);
s02 = vfmaq_f16(s02, w0, x2);
s10 = vfmaq_f16(s10, w1, x0);
s11 = vfmaq_f16(s11, w1, x1);
s12 = vfmaq_f16(s12, w1, x2);
s20 = vfmaq_f16(s20, w2, x0);
s21 = vfmaq_f16(s21, w2, x1);
s22 = vfmaq_f16(s22, w2, x2);
s30 = vfmaq_f16(s30, w3, x0);
s31 = vfmaq_f16(s31, w3, x1);
s32 = vfmaq_f16(s32, w3, x2);
}
vst1q_f16(outbuf, s00);
vst1q_f16(outbuf + 1*64, s01);
vst1q_f16(outbuf + 2*64, s02);
vst1q_f16(outbuf + 6*64, s10);
vst1q_f16(outbuf + 7*64, s11);
vst1q_f16(outbuf + 8*64, s12);
vst1q_f16(outbuf + 12*64, s20);
vst1q_f16(outbuf + 13*64, s21);
vst1q_f16(outbuf + 14*64, s22);
vst1q_f16(outbuf + 18*64, s30);
vst1q_f16(outbuf + 19*64, s31);
vst1q_f16(outbuf + 20*64, s32);
}
}
}
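// For reference, a rough scalar model of the FP16 accumulation stage above
// (an illustrative sketch, not part of the library). The buffer layout is
// inferred from the pointer strides in the vector code: inwptr is laid out as
// [channel][tile][lane], wptr as [channel][k][lane], and outbuf as [k][tile]
// blocks of winoNatom*winoAtom halves, one winoAtom-wide atom written per
// block per outer iteration. Unlike the NEON kernel, which accumulates
// directly in FP16 registers, this sketch accumulates in float for clarity.
// Assumes an AArch64 compiler (__fp16) and <vector>.
static void winofunc_accum_ref(const __fp16* inwptr, const __fp16* wptr,
                               __fp16* outbuf, int Cg, int iblock,
                               int winoIblock, int winoKblock,
                               int winoAtom, int winoNatom)
{
    for (int atom = 0; atom < winoNatom; atom++, outbuf += winoAtom)
    {
        std::vector<float> acc(winoKblock*winoIblock*winoAtom, 0.f);
        for (int c = 0; c < Cg; c++, inwptr += winoIblock*winoAtom,
                                     wptr += winoKblock*winoAtom)
            for (int k = 0; k < winoKblock; k++)
                for (int i = 0; i < iblock; i++)
                    for (int lane = 0; lane < winoAtom; lane++)
                        acc[(k*winoIblock + i)*winoAtom + lane] +=
                            (float)wptr[k*winoAtom + lane] *
                            (float)inwptr[i*winoAtom + lane];
        for (int k = 0; k < winoKblock; k++)
            for (int i = 0; i < iblock; i++)
                for (int lane = 0; lane < winoAtom; lane++)
                    outbuf[(k*winoIblock + i)*winoNatom*winoAtom + lane] =
                        (__fp16)acc[(k*winoIblock + i)*winoAtom + lane];
    }
}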
#define T4x4(a, b, c, d, tr0, tr1) \
tr0 = vtrnq_f32(a, b); \
tr1 = vtrnq_f32(c, d); \
a = vcombine_f32(vget_low_f32(tr0.val[0]), vget_low_f32(tr1.val[0])); \
b = vcombine_f32(vget_low_f32(tr0.val[1]), vget_low_f32(tr1.val[1])); \
c = vcombine_f32(vget_high_f32(tr0.val[0]), vget_high_f32(tr1.val[0])); \
d = vcombine_f32(vget_high_f32(tr0.val[1]), vget_high_f32(tr1.val[1]))
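// Illustrative check of the T4x4 macro above (a hypothetical helper, not part
// of the library; assumes AArch64 NEON): vtrnq_f32 interleaves each row pair
// and vcombine_f32 stitches the halves, so four rows become the four columns
// of the 4x4 block.
static void t4x4_demo()
{
    float m[16];
    for (int i = 0; i < 16; i++) m[i] = (float)i; // rows {0..3},{4..7},{8..11},{12..15}
    float32x4_t a = vld1q_f32(m),     b = vld1q_f32(m + 4),
                c = vld1q_f32(m + 8), d = vld1q_f32(m + 12);
    float32x4x2_t tr0, tr1;
    T4x4(a, b, c, d, tr0, tr1);
    vst1q_f32(m, a);       // {0, 4,  8, 12}
    vst1q_f32(m + 4, b);   // {1, 5,  9, 13}
    vst1q_f32(m + 8, c);   // {2, 6, 10, 14}
    vst1q_f32(m + 12, d);  // {3, 7, 11, 15}
}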
// NOTE: Since we don't have full FP16 support yet, the current workaround is to pack the data and
// convert it to FP16 in the input transform stage, and convert it back to FP32 in the output transform stage.
void winofunc_BtXB_8x8_F16(const float * inptr, int inpstep,
char * _outptr, int Cg, const int winoIblock, const int winoAtomF16)
{
typedef __fp16 float16_t;
float16_t* outptr = (float16_t*)_outptr;
float32x4_t x00 = vld1q_f32(inptr), x01 = vld1q_f32(inptr + 4);
float32x4_t x10 = vld1q_f32(inptr + inpstep), x11 = vld1q_f32(inptr + inpstep + 4);
float32x4_t x20 = vld1q_f32(inptr + inpstep*2), x21 = vld1q_f32(inptr + inpstep*2 + 4);
@ -577,8 +605,8 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
float32x4_t z00, z01, z10, z11, z20, z21, z30, z31, z40, z41, z50, z51, z60, z61, z70, z71;
{
// Y[0] = [1.f, 0.f, -5.25f, 0.f, 5.25f, 0.f, -1.f, 0.f]*X
// Y[7] = [0.f, -1.f, 0.f, 5.25f, 0.f, -5.25f, 0.f, 1.f]*X
float32x4_t q5_25 = vdupq_n_f32(5.25f), t00, t01, t10, t11;
t00 = vsubq_f32(x40, x20);
t01 = vsubq_f32(x41, x21);
@ -589,8 +617,8 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
float32x4_t y70 = vfmaq_f32(vsubq_f32(x70, x10), t10, q5_25);
float32x4_t y71 = vfmaq_f32(vsubq_f32(x71, x11), t11, q5_25);
// Y[1] = [0.f, 1.f, 1.f, -4.25f, -4.25f, 1.f, 1.f, 0.f]*X
// Y[2] = [0.f, -1.f, 1.f, 4.25f, -4.25f, -1.f, 1.f, 0.f]*X
float32x4_t qm4_25 = vdupq_n_f32(-4.25f);
t00 = vfmaq_f32(vaddq_f32(x10, x50), x30, qm4_25);
t01 = vfmaq_f32(vaddq_f32(x11, x51), x31, qm4_25);
@ -600,8 +628,8 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
float32x4_t y10 = vaddq_f32(t00, t10), y11 = vaddq_f32(t01, t11);
float32x4_t y20 = vsubq_f32(t10, t00), y21 = vsubq_f32(t11, t01);
// Y[3] = [0.f, 0.5f, 0.25f, -2.5f, -1.25f, 2.f, 1.f, 0.f]*X
// Y[4] = [0.f, -0.5f, 0.25f, 2.5f, -1.25f, -2.f, 1.f, 0.f]*X
float32x4_t q0_5 = vdupq_n_f32(0.5f), q0_25 = vdupq_n_f32(0.25f);
float32x4_t qm2_5 = vdupq_n_f32(-2.5f), qm1_25 = vdupq_n_f32(-1.25f);
t00 = vfmaq_f32(vaddq_f32(x50, x50), x10, q0_5);
@ -616,8 +644,8 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
float32x4_t y30 = vaddq_f32(t00, t10), y31 = vaddq_f32(t01, t11);
float32x4_t y40 = vsubq_f32(t10, t00), y41 = vsubq_f32(t11, t01);
// Y[5] = [0.f, 2.f, 4.f, -2.5f, -5.f, 0.5f, 1.f, 0.f]*X
// Y[6] = [0.f, -2.f, 4.f, 2.5f, -5.f, -0.5f, 1.f, 0.f]*X
float32x4_t q4 = vdupq_n_f32(4.f), qm5 = vdupq_n_f32(-5.f);
t00 = vfmaq_f32(vaddq_f32(x10, x10), x50, q0_5);
t01 = vfmaq_f32(vaddq_f32(x11, x11), x51, q0_5);
@ -631,22 +659,22 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
float32x4_t y50 = vaddq_f32(t00, t10), y51 = vaddq_f32(t01, t11);
float32x4_t y60 = vsubq_f32(t10, t00), y61 = vsubq_f32(t11, t01);
// transpose 8x8 matrix in-place with some renumeration of the elements:
// Y:
// y00 y01
// y10 y11
// ...
// y70 y71
// Y':
// y00 y40
// y10 y50
// y20 y60
// y30 y70
// y01 y41
// y11 y51
// y21 y61
// y31 y71
// in other words, y40 <-> y01, y50 <-> y11, y60 <-> y21, y70 <-> y31
float32x4x2_t tr0, tr1;
T4x4(y00, y10, y20, y30, tr0, tr1);
@ -654,8 +682,8 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
T4x4(y40, y50, y60, y70, tr0, tr1);
T4x4(y41, y51, y61, y71, tr0, tr1);
// Z[0] = [1.f, 0.f, -5.25f, 0.f, 5.25f, 0.f, -1.f, 0.f]*Y
// Z[7] = [0.f, -1.f, 0.f, 5.25f, 0.f, -5.25f, 0.f, 1.f]*Y
t00 = vsubq_f32(y01, y20);
t01 = vsubq_f32(y41, y60);
t10 = vsubq_f32(y30, y11);
@ -665,8 +693,8 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
z70 = vfmaq_f32(vsubq_f32(y31, y10), t10, q5_25);
z71 = vfmaq_f32(vsubq_f32(y71, y50), t11, q5_25);
// Z[1] = [0.f, 1.f, 1.f, -4.25f, -4.25f, 1.f, 1.f, 0.f]*Y
// Z[2] = [0.f, -1.f, 1.f, 4.25f, -4.25f, -1.f, 1.f, 0.f]*Y
t00 = vfmaq_f32(vaddq_f32(y10, y11), y30, qm4_25);
t01 = vfmaq_f32(vaddq_f32(y50, y51), y70, qm4_25);
t10 = vfmaq_f32(vaddq_f32(y20, y21), y01, qm4_25);
@ -675,8 +703,8 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
z10 = vaddq_f32(t00, t10); z11 = vaddq_f32(t01, t11);
z20 = vsubq_f32(t10, t00); z21 = vsubq_f32(t11, t01);
// Z[3] = [0.f, 0.5f, 0.25f, -2.5f, -1.25f, 2.f, 1.f, 0.f]*Y
// Z[4] = [0.f, -0.5f, 0.25f, 2.5f, -1.25f, -2.f, 1.f, 0.f]*Y
t00 = vfmaq_f32(vaddq_f32(y11, y11), y10, q0_5);
t01 = vfmaq_f32(vaddq_f32(y51, y51), y50, q0_5);
t10 = vfmaq_f32(y21, y20, q0_25);
@ -689,8 +717,8 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
z30 = vaddq_f32(t00, t10); z31 = vaddq_f32(t01, t11);
z40 = vsubq_f32(t10, t00); z41 = vsubq_f32(t11, t01);
// Z[5] = [0.f, 2.f, 4.f, -2.5f, -5.f, 0.5f, 1.f, 0.f]*Y
// Z[6] = [0.f, -2.f, 4.f, 2.5f, -5.f, -0.5f, 1.f, 0.f]*Y
t00 = vfmaq_f32(vaddq_f32(y10, y10), y11, q0_5);
t01 = vfmaq_f32(vaddq_f32(y50, y50), y51, q0_5);
t10 = vfmaq_f32(y21, y20, q4);
@ -704,39 +732,42 @@ void winofunc_BtXB_8x8_f32(const float* inptr, int inpstep,
z60 = vsubq_f32(t10, t00); z61 = vsubq_f32(t11, t01);
}
const int outstep = winoIblock*winoAtomF16*Cg;
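    // outstep is the distance, in FP16 elements, between consecutive transformed rows
    // in the packed buffer: 6 tiles * 8 lanes * Cg channels.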
vst1_f16(outptr, vcvt_f16_f32(z00));
vst1_f16(outptr + 4, vcvt_f16_f32(z01));
vst1_f16(outptr + outstep, vcvt_f16_f32(z10));
vst1_f16(outptr + outstep + 4, vcvt_f16_f32(z11));
vst1_f16(outptr + outstep*2, vcvt_f16_f32(z20));
vst1_f16(outptr + outstep*2 + 4, vcvt_f16_f32(z21));
vst1_f16(outptr + outstep*3, vcvt_f16_f32(z30));
vst1_f16(outptr + outstep*3 + 4, vcvt_f16_f32(z31));
vst1_f16(outptr + outstep*4, vcvt_f16_f32(z40));
vst1_f16(outptr + outstep*4 + 4, vcvt_f16_f32(z41));
vst1_f16(outptr + outstep*5, vcvt_f16_f32(z50));
vst1_f16(outptr + outstep*5 + 4, vcvt_f16_f32(z51));
vst1_f16(outptr + outstep*6, vcvt_f16_f32(z60));
vst1_f16(outptr + outstep*6 + 4, vcvt_f16_f32(z61));
vst1_f16(outptr + outstep*7, vcvt_f16_f32(z70));
vst1_f16(outptr + outstep*7 + 4, vcvt_f16_f32(z71));
}
// Output transform
void winofunc_AtXA_8x8_F16(const char* _inptr, int inpstep,
float * bpptr, int bpstep, float* outptr, int outstep,
float bias, float minval, float maxval, bool ifMinMaxAct)
{
typedef __fp16 float16_t;
const float16_t* inptr = (const float16_t*)_inptr;
float32x4_t x00 = vcvt_f32_f16(vld1_f16(inptr)), x01 = vcvt_f32_f16(vld1_f16(inptr + 4));
float32x4_t x10 = vcvt_f32_f16(vld1_f16(inptr + inpstep)), x11 = vcvt_f32_f16(vld1_f16(inptr + inpstep + 4));
float32x4_t x20 = vcvt_f32_f16(vld1_f16(inptr + inpstep*2)), x21 = vcvt_f32_f16(vld1_f16(inptr + inpstep*2 + 4));
float32x4_t x30 = vcvt_f32_f16(vld1_f16(inptr + inpstep*3)), x31 = vcvt_f32_f16(vld1_f16(inptr + inpstep*3 + 4));
float32x4_t x40 = vcvt_f32_f16(vld1_f16(inptr + inpstep*4)), x41 = vcvt_f32_f16(vld1_f16(inptr + inpstep*4 + 4));
float32x4_t x50 = vcvt_f32_f16(vld1_f16(inptr + inpstep*5)), x51 = vcvt_f32_f16(vld1_f16(inptr + inpstep*5 + 4));
float32x4_t x60 = vcvt_f32_f16(vld1_f16(inptr + inpstep*6)), x61 = vcvt_f32_f16(vld1_f16(inptr + inpstep*6 + 4));
float32x4_t x70 = vcvt_f32_f16(vld1_f16(inptr + inpstep*7)), x71 = vcvt_f32_f16(vld1_f16(inptr + inpstep*7 + 4));
float32x4_t z00, z01, z10, z11, z20, z21, z30, z31, z40, z41, z50, z51;
{
@ -757,33 +788,33 @@ void winofunc_AtXA_8x8_f32(const float* inptr, int inpstep,
s56_0 = vsubq_f32(x50, x60); s56_1 = vsubq_f32(x51, x61);
    float32x4_t y50 = vfmaq_n_f32(vfmaq_n_f32(vaddq_f32(x70, s12_0),
                                  s34_0, 32.f), s56_0, 1.f/32);
    float32x4_t y51 = vfmaq_n_f32(vfmaq_n_f32(vaddq_f32(x71, s12_1),
                                  s34_1, 32.f), s56_1, 1.f/32);
float32x4_t y10 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 2.0f), s56_0, 0.5f);
float32x4_t y11 = vfmaq_n_f32(vfmaq_n_f32(s12_1, s34_1, 2.0f), s56_1, 0.5f);
float32x4_t y30 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 8.0f), s56_0, 0.125f);
float32x4_t y31 = vfmaq_n_f32(vfmaq_n_f32(s12_1, s34_1, 8.0f), s56_1, 0.125f);
float32x4_t y60 = vdupq_n_f32(0.f), y61 = y60, y70 = y60, y71 = y60;
// transpose 8x8 matrix in-place with some renumeration of the elements:
// Y:
// y00 y01
// y10 y11
// ...
// y50 y51
// 0 0
// 0 0
// Y':
// y00 y40
// y10 y50
// y20 y60
// y30 y70
// y01 y41
// y11 y51
// y21 y61
// y31 y71
// in other words, y40 <-> y01, y50 <-> y11, y60 <-> y21, y70 <-> y31
float32x4x2_t tr0, tr1;
T4x4(y00, y10, y20, y30, tr0, tr1);
@ -807,9 +838,9 @@ void winofunc_AtXA_8x8_f32(const float* inptr, int inpstep,
s56_0 = vsubq_f32(y11, y21); s56_1 = vsubq_f32(y51, y61);
    z50 = vfmaq_n_f32(vfmaq_n_f32(vaddq_f32(y31, s12_0),
                      s34_0, 32.f), s56_0, 1.f/32);
    z51 = vfmaq_n_f32(vfmaq_n_f32(vaddq_f32(y71, s12_1),
                      s34_1, 32.f), s56_1, 1.f/32);
z10 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 2.0f), s56_0, 0.5f);
z11 = vfmaq_n_f32(vfmaq_n_f32(s12_1, s34_1, 2.0f), s56_1, 0.5f);
z30 = vfmaq_n_f32(vfmaq_n_f32(s12_0, s34_0, 8.0f), s56_0, 0.125f);
@ -879,8 +910,8 @@ void winofunc_AtXA_8x8_f32(const float* inptr, int inpstep,
vst1q_f32(outptr + outstep*5, z50);
vst1_f32(outptr + outstep*5 + 4, vget_low_f32(z51));
}
#endif
CV_CPU_OPTIMIZATION_NAMESPACE_END
}} // namespace
@ -14,15 +14,76 @@
#include "conv_block.simd.hpp"
#include "layers/cpu_kernels/conv_block.simd_declarations.hpp" // defines CV_CPU_DISPATCH_MODES_ALL=AVX2,...,BASELINE based on CMakeLists.txt content
#include <opencv2/core/utils/logger.hpp>
namespace cv { namespace dnn {
enum { VEC_ALIGN = 32 }; // Memory alignment.
void convBlock_F32(int np, const float* a, const float* b, float* c, int ldc, bool init_c, const int outLen,
const int convMR, const int convNR);
void convBlockMR1_F32(int np, const float* a, const float* b, float *c, const float bias, bool init_c,
const float minval, const float maxval, bool ifMinMaxAct, const int outLen, const int convNR);
#ifdef CONV_ARM_FP16
// Fast conversion from float32 to float16.
static inline void _cvt32f16f(const float* src, float16_t* dst, int len)
{
int j = 0;
const int VECSZ = 4;
__fp16* dst_FP16 = (__fp16 *)dst;
if (len > VECSZ * 4)
{
const int VECSZ4 = 4 * VECSZ;
for( ; j + VECSZ4 < len; j += VECSZ4)
{
float32x4_t v0 = vld1q_f32(src + j);
float32x4_t v1 = vld1q_f32(src + j + 4);
float32x4_t v2 = vld1q_f32(src + j + 8);
float32x4_t v3 = vld1q_f32(src + j + 12);
vst1q_f16(dst_FP16 + j, vcombine_f16(vcvt_f16_f32(v0), vcvt_f16_f32(v1)));
vst1q_f16(dst_FP16 + j + 8, vcombine_f16(vcvt_f16_f32(v2), vcvt_f16_f32(v3)));
}
}
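    // Tail handling: back j up to len - VECSZ so the last vector overlaps elements
    // already converted; if len < VECSZ, break to the scalar loop below.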
for( ; j < len; j += VECSZ )
{
if( j > len - VECSZ )
{
if( j == 0 )
break;
j = len - VECSZ;
}
float16x4_t hv = vcvt_f16_f32(vld1q_f32(src + j));
vst1_f16(dst_FP16 + j, hv);
}
for( ; j < len; j++ )
dst[j] = float16_t(src[j]);
}
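// A minimal self-check sketch for the helper above (hypothetical, not part of
// the library): convert an odd-length buffer and verify the round trip. The
// inputs are chosen to be exactly representable in FP16; assumes an AArch64
// build where CONV_ARM_FP16 is defined and <vector> is available.
static void cvt32f16f_selftest()
{
    std::vector<float> src(37);               // odd length exercises the tail path
    for (size_t i = 0; i < src.size(); i++)
        src[i] = 0.125f * (float)i - 2.f;     // multiples of 1/8 fit in FP16 exactly
    std::vector<float16_t> dst(src.size());
    _cvt32f16f(src.data(), dst.data(), (int)src.size());
    for (size_t i = 0; i < src.size(); i++)
        CV_Assert((float)dst[i] == src[i]);
}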
#endif
float* FastConv::getWeights()
{
return alignPtr(weightsBuf.data(), VEC_ALIGN);
}
float* FastConv::getWeightsWino()
{
return alignPtr(weightsWinoBuf.data(), VEC_ALIGN);
}
float16_t* FastConv::getWeightsFP16()
{
return alignPtr(weightsBuf_FP16.data(), VEC_ALIGN);
}
float16_t* FastConv::getWeightsWinoFP16()
{
return alignPtr(weightsWinoBuf_FP16.data(), VEC_ALIGN);
}
Ptr<FastConv> initFastConv(
InputArray _weightsMat,
float* srcBias,
@ -119,9 +180,16 @@ Ptr<FastConv> initFastConv(
conv->useFP16 = false;
#ifdef CONV_ARM_FP16
if (_useFP16 && (conv->conv_type == CONV_TYPE_GENERIC || conv->conv_type == CONV_TYPE_DEPTHWISE_REMAIN
|| conv->conv_type == CONV_TYPE_WINOGRAD3X3))
conv->useFP16 = true;
// Runtime FP16 check.
if (conv->useFP16 && !checkHardwareSupport(CPU_NEON_FP16))
{
conv->useFP16 = false;
CV_LOG_ONCE_WARNING(NULL, "DNN: the CPU does not support the FP16 instruction set; falling back to FP32.");
}
#endif
float *srcWeights = (float *)weightsMat.data;
@ -141,31 +209,25 @@ Ptr<FastConv> initFastConv(
if (conv->useFP16)
{
conv->weightsBuf_FP16.resize(nweights + VEC_ALIGN);
auto weightsPtr_FP16 = conv->getWeightsFP16();
memset(reinterpret_cast<short*>(weightsPtr_FP16), 0, nweights * sizeof(weightsPtr_FP16[0]));
parallel_for_(Range(0, C), [&](const Range& r0){
for(int c = r0.start; c < r0.end; c++)
_cvt32f16f(srcWeights + c*wstep, weightsPtr_FP16 + c*padded_ksize, ksize);
});
}
else
#endif
{
conv->weightsBuf.resize(nweights + VEC_ALIGN);
auto weightsPtr = conv->getWeights();
memset(weightsPtr, 0, nweights*sizeof(weightsPtr[0]));
parallel_for_(Range(0, C), [&](const Range& r0) {
for(int c = r0.start; c < r0.end; c++)
memcpy(weightsPtr + c*padded_ksize, srcWeights + c*wstep, ksize*sizeof(weightsPtr[0]));
});
}
}
else if(conv->conv_type == CONV_TYPE_WINOGRAD3X3) // winograd
@ -213,16 +275,14 @@ Ptr<FastConv> initFastConv(
if (conv->useFP16)
{
conv->weightsWinoBuf_FP16.resize(nweights + VEC_ALIGN);
wptrWino_FP16 = conv->getWeightsWinoFP16();
memset(reinterpret_cast<short*>(wptrWino_FP16), 0, nweights * sizeof(wptrWino_FP16[0]));
}
else
#endif
{
conv->weightsWinoBuf.resize(nweights + VEC_ALIGN);
wptrWino = conv->getWeightsWino();
memset(wptrWino, 0, nweights * sizeof(wptrWino[0]));
}
@ -272,7 +332,7 @@ Ptr<FastConv> initFastConv(
for (int i = 0; i < CONV_WINO_NATOMS_F16; i++,
wptr += Cg * CONV_WINO_KBLOCK * CONV_WINO_ATOM_F16)
{
CV_Assert(wptrWino_FP16 <= wptr && wptr + CONV_WINO_ATOM_F16 <= wptrWino_FP16 + nweights);
for (int j = 0; j < CONV_WINO_ATOM_F16; j++)
{
wptr[j] = (float16_t)kernelTm[i * CONV_WINO_ATOM_F16 + j];
@ -287,7 +347,7 @@ Ptr<FastConv> initFastConv(
for (int i = 0; i < CONV_WINO_NATOMS_F32; i++,
wptr += Cg * CONV_WINO_KBLOCK * CONV_WINO_ATOM_F32)
{
CV_Assert(wptrWino <= wptr && wptr + CONV_WINO_ATOM_F32 <= wptrWino + nweights);
memcpy(wptr, kernelTm + i * CONV_WINO_ATOM_F32, CONV_WINO_ATOM_F32*sizeof (wptr[0]));
}
}
@ -305,29 +365,26 @@ Ptr<FastConv> initFastConv(
int numStripsMR = (Kg + CONV_MR_FP32 - 1) / CONV_MR_FP32;
int Kg_aligned = numStripsMR * CONV_MR_FP32;
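// Kg rounded up to a multiple of CONV_MR_FP32; the padded tail strips are zero-filled during weight packing below.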
size_t nweights = ngroups*Kg_aligned*DkHkWkCg;
float* weightsPtr = nullptr;
#ifdef CONV_ARM_FP16
int numStripsMR_FP16 = (Kg + CONV_MR_FP16 - 1) / CONV_MR_FP16;
int Kg_aligned_FP16 = numStripsMR_FP16 * CONV_MR_FP16;
size_t nweights_FP16 = ngroups * Kg_aligned_FP16 * DkHkWkCg;
float16_t* weightsPtr_FP16 = nullptr;
if (conv->useFP16)
{
conv->weightsBuf_FP16.resize(nweights_FP16 + VEC_ALIGN);
weightsPtr_FP16 = conv->getWeightsFP16();
memset(reinterpret_cast<short*>(weightsPtr_FP16), 0, nweights_FP16*sizeof(weightsPtr_FP16[0]));
}
else
#endif
{
conv->weightsBuf.resize(nweights + VEC_ALIGN);
weightsPtr = conv->getWeights();
memset(weightsPtr, 0, nweights*sizeof(weightsPtr[0]));
}
// Pack the weights.
@ -343,7 +400,7 @@ Ptr<FastConv> initFastConv(
int startK = si * CONV_MR_FP16;
CV_Assert(startK < Kg_aligned_FP16);
float16_t* packed_wptr = weightsPtr_FP16 + DkHkWkCg * (startK + g * Kg_aligned_FP16);
int dk = Kg - startK < CONV_MR_FP16 ? Kg - startK : CONV_MR_FP16; // check if we need zero padding.
int k_idx = g*Kg + startK;
@ -373,7 +430,7 @@ Ptr<FastConv> initFastConv(
int startK = si * CONV_MR_FP32;
CV_Assert(startK < Kg_aligned);
float* packed_wptr = weightsPtr + DkHkWkCg * (startK + g * Kg_aligned);
int dk = Kg - startK < CONV_MR_FP32 ? Kg - startK : CONV_MR_FP32; // check if we need zero padding.
int k_idx = g*Kg + startK;
@ -410,7 +467,7 @@ Ptr<FastConv> initFastConv(
}
static inline void packData8(char*& inpbuf, float*& inptrIn, int& in_w, int& x0, int& s0, const int* ofstab,
const int stride_w, const int ksize, const int esz)
{
char * inpbufC = inpbuf + s0 * esz;
float* inptrInC = (float* )inptrIn;
@ -516,7 +573,7 @@ static inline void packData8(char*& inpbuf, float*& inptrIn, int& in_w, int& x0,
}
static inline void packData2(char *& inpbuf, float*& inptrIn, int& in_w, int& x0, int& s0, const int* ofstab,
const int stride_w, const int ksize, const int esz)
{
char* inpbufC = inpbuf + s0 * esz;
float* inptrInC = inptrIn;
@ -553,46 +610,6 @@ static inline void packData2(char *& inpbuf, float*& inptrIn, int& in_w, int& x0
in_w += stride_w;
}
static inline void packInputData(char* inpbuf_task, float* inp, const int* ofstab, const int* dhwTab, int zyx0, int zyx_limit,
int ksize, int stride_d, int stride_h, int stride_w, int pad_front, int pad_top, int pad_left,
int Dk, int Hk, int Wk, int dilation_d, int dilation_h, int dilation_w, int Di, int Hi, int Wi,
@ -1174,10 +1191,9 @@ void runFastConv(InputArray _input, OutputArray _output, const Ptr<FastConv>& co
else
activ = nullptr;
if (conv->conv_type == CONV_TYPE_WINOGRAD3X3) // winograd
{
CV_Assert((!conv->weightsWinoBuf.empty() || !conv->weightsWinoBuf_FP16.empty()) && input.dims == 4 && conv_dim == CONV_2D);
if (runWinograd63(input, fusedAddMat, output, conv, ntasks, minval, maxval, activ, ifMinMaxAct))
return;
}
@ -1437,13 +1453,13 @@ void runFastConv(InputArray _input, OutputArray _output, const Ptr<FastConv>& co
if (useFP16)
{
CV_Assert(!conv->weightsBuf_FP16.empty());
weights = (char *)conv->getWeightsFP16();
}
else
#endif
{
CV_Assert(!conv->weightsBuf.empty());
weights = (char *)conv->getWeights();
}
// Optional branch, used only for depth-wise convolution implemented via the generic convolution path.
// In this case, CONV_MR is 1 and CONV_NR remains the same.
@ -1477,7 +1493,7 @@ void runFastConv(InputArray _input, OutputArray _output, const Ptr<FastConv>& co
#ifdef CONV_ARM_FP16
if (useFP16)
{
opt_NEON_FP16::convBlockMR1_F16(DkHkWkCg, weights, inptr, cptr, biasVal, fusedAdd, minval, maxval, ifMinMaxAct, outLen, CONV_NR);
}
else
#endif
@ -1485,7 +1501,7 @@ void runFastConv(InputArray _input, OutputArray _output, const Ptr<FastConv>& co
}
else
#endif
convBlockMR1_F32(DkHkWkCg, (const float *)weights, (const float *)inptr, cptr, biasVal, fusedAdd, minval, maxval, ifMinMaxAct, outLen, CONV_NR);
if (ifBuffer)
{
@ -1526,12 +1542,12 @@ void runFastConv(InputArray _input, OutputArray _output, const Ptr<FastConv>& co
{
#if CV_TRY_AVX2
if (conv->useAVX2)
opt_AVX2::convBlock_F32(c1 - c0, (const float *)wptr, (const float *)inptr, cptr, ldc, c0 == 0, outLen, CONV_MR, CONV_NR);
else
#endif
#if CV_TRY_AVX
if (conv->useAVX)
opt_AVX::convBlock_F32(c1 - c0, (const float *)wptr, (const float *)inptr, cptr, ldc, c0 == 0, outLen, CONV_MR, CONV_NR);
else
#endif
#if CV_NEON
@ -1540,16 +1556,16 @@ void runFastConv(InputArray _input, OutputArray _output, const Ptr<FastConv>& co
#ifdef CONV_ARM_FP16
if (useFP16)
{
opt_NEON_FP16::convBlock_F16(c1 - c0, wptr, inptr, (char *)cptr_f16, ldc, c0 == 0, outLen, CONV_MR, CONV_NR);
}
else
#endif
opt_NEON::convBlock_F32(c1 - c0, (const float *)wptr, (const float *)inptr, cptr, ldc, c0 == 0, outLen, CONV_MR, CONV_NR);
}
else
#endif
// The possible outLen values are 24 or 1 to 8.
convBlock_F32(c1 - c0, (const float *)wptr, (const float *)inptr, cptr, ldc, c0 == 0, outLen, CONV_MR, CONV_NR);
}
}
}
@ -1838,7 +1854,7 @@ static inline void convBlockMR1x12(int np, const float* a, const float* b, float
}
#endif
void convBlockMR1_F32(int np, const float* a, const float* b, float *c, const float bias, bool init_c,
const float minval, const float maxval, bool ifMinMaxAct, const int outLen, const int convNR)
{
#if CV_SIMD128
@ -2088,7 +2104,7 @@ static inline void convBlockNoSIMD(int np, const float* a, const float* b, float
}
}
void convBlock_F32(int np, const float* a, const float* b, float* c, int ldc, bool init_c, const int outLen,
const int convMR, const int convNR)
{
// The possible outLen values are 24 or 1 to 8.
@ -14,7 +14,7 @@
#define CONV_NR_FP32 28
// FP16 is only supported on ARM64 builds with FP16 FMA available.
#if CV_FP16 // check FP16 FMA.
#define CONV_ARM_FP16 1
#endif
@ -22,7 +22,6 @@
// Currently, only ARM64 supports FP16.
#define CONV_MR_FP16 8
#define CONV_NR_FP16 24
#endif
#elif CV_NEON // 16 registers.
@ -58,17 +57,15 @@ struct FastConv
int pad_top, pad_bottom, pad_left, pad_right, pad_front, pad_behind;
std::vector<float> weightsBuf; // For generic Conv 2D
std::vector<float> weightsWinoBuf; // For Winograd F(6x6, 3x3).
std::vector<float> biasBuf;
float* getWeights();
float* getWeightsWino();
#if CV_NEON && CV_NEON_AARCH64 && CV_FP16
std::vector<float16_t> weightsBuf_FP16;
std::vector<float16_t> weightsWinoBuf_FP16;
#endif
float16_t* getWeightsFP16();
float16_t* getWeightsWinoFP16();
int conv_type;
int conv_dim; // Flag for conv1d, conv2d, or conv3d.
@ -115,6 +112,32 @@ void runDepthwise(InputArray _input, OutputArray _output, const Ptr<FastConv>& c
int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _output, const Ptr<FastConv>& conv, int ntasks,
float minval, float maxval, ActivationLayer* activ, bool ifMinMaxAct);
// NEON workaround: the following functions are used only internally.
namespace opt_NEON {
#if CV_NEON
void convBlock_F32(int np, const float* a, const float* b, float* c, int ldc, bool init_c, int width, const int convMR, const int convNR);
void convBlockMR1_F32(int np, const float* a, const float* b, float* c, const float bias, bool init_c,
const float minval, const float maxval, bool ifMinMaxAct, const int width, const int convNR);
#if CV_NEON_AARCH64
/* Accumulate */
void winofunc_accum_F32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock,
const int winoIblock, const int winoKblock, const int winoAtom, const int winoNatom);
/*Input transform*/
void winofunc_BtXB_8x8_F32(const float* inptr, int inpstep,
float* outptr, int Cg, const int winoIblock, const int winoAtom);
/*Output transform*/
void winofunc_AtXA_8x8_F32(const float* inptr, int inpstep,
float* bpptr, int bpstep, float* outptr, int outstep,
float bias, float minval, float maxval, bool ifMinMaxAct);
#endif // CV_NEON_AARCH64
#endif // CV_NEON
} // namespace opt_NEON.
} // namespace dnn
} // namespace cv
@ -37,6 +37,7 @@ public:
virtual void setPreferableBackend(Backend backendId) { net.setPreferableBackend(backendId); }
virtual void setPreferableTarget(Target targetId) { net.setPreferableTarget(targetId); }
virtual void enableWinograd(bool useWinograd) { net.enableWinograd(useWinograd); }
virtual
void initNet(const Net& network)
@ -151,6 +152,7 @@ Model& Model::setPreferableBackend(Backend backendId)
impl->setPreferableBackend(backendId);
return *this;
}
Model& Model::setPreferableTarget(Target targetId)
{
CV_DbgAssert(impl);
@ -158,6 +160,13 @@ Model& Model::setPreferableTarget(Target targetId)
return *this;
}
Model& Model::enableWinograd(bool useWinograd)
{
CV_DbgAssert(impl);
impl->enableWinograd(useWinograd);
return *this;
}
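// Usage sketch (illustrative; the model path is a placeholder): several DNN
// tests below pair this call with DNN_TARGET_CPU_FP16, where the FP16
// Winograd path can exceed per-model accuracy tolerances.
static void exampleDisableWinogradForCpuFp16()
{
    Model model("model.onnx");  // hypothetical model file
    model.setPreferableBackend(DNN_BACKEND_OPENCV);
    model.setPreferableTarget(DNN_TARGET_CPU_FP16);
    model.enableWinograd(false);
}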
Model& Model::setInputSize(const Size& size)
{
CV_DbgAssert(impl);
@ -49,7 +49,10 @@ public:
net.setInput(inp);
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
net.enableWinograd(useWinograd);
if (target == DNN_TARGET_CPU_FP16)
net.enableWinograd(false);
if (backend == DNN_BACKEND_HALIDE && !halideScheduler.empty())
{
halideScheduler = findDataFile(halideScheduler);
@ -545,7 +548,7 @@ TEST_P(DNNTestNetwork, FastNeuralStyle_eccv16)
else if (target == DNN_TARGET_CPU_FP16)
{
l1 = 0.4;
lInf = 22.;
}
else if (target == DNN_TARGET_VULKAN)
{
@ -62,6 +62,10 @@ public:
findDataFile("dnn/" + model, false));
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
if (target == DNN_TARGET_CPU_FP16)
net.enableWinograd(false);
Mat img = imread(findDataFile("dnn/dog416.png"));
resize(img, img, Size(800, 600));
Mat blob = blobFromImage(img, 1.0, Size(), Scalar(102.9801, 115.9465, 122.7717), false, false);
@ -219,6 +223,9 @@ TEST_P(Reproducibility_AlexNet, Accuracy)
net.setPreferableBackend(DNN_BACKEND_OPENCV);
net.setPreferableTarget(targetId);
if (targetId == DNN_TARGET_CPU_FP16)
net.enableWinograd(false);
Mat sample = imread(_tf("grace_hopper_227.png"));
ASSERT_TRUE(!sample.empty());
@ -383,6 +390,9 @@ TEST_P(Reproducibility_ResNet50, Accuracy)
net.setPreferableBackend(DNN_BACKEND_OPENCV);
net.setPreferableTarget(targetId);
if (targetId == DNN_TARGET_CPU_FP16)
net.enableWinograd(false);
float l1 = (targetId == DNN_TARGET_OPENCL_FP16 || targetId == DNN_TARGET_CPU_FP16) ? 3e-5 : 1e-5;
float lInf = (targetId == DNN_TARGET_OPENCL_FP16 || targetId == DNN_TARGET_CPU_FP16) ? 6e-3 : 1e-4;
@ -503,6 +513,10 @@ TEST_P(Test_Caffe_nets, Colorization)
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
// This model has bad accuracy when FP16 and Winograd are enabled at the same time.
if (target == DNN_TARGET_CPU_FP16)
net.enableWinograd(false);
net.getLayer(net.getLayerId("class8_ab"))->blobs.push_back(kernel);
net.getLayer(net.getLayerId("conv8_313_rh"))->blobs.push_back(Mat(1, 313, CV_32F, 2.606));
@ -568,10 +582,15 @@ TEST_P(Test_Caffe_nets, DenseNet_121)
{
l1 = 0.11; lInf = 0.5;
}
else if (target == DNN_TARGET_CUDA_FP16)
{
l1 = 0.04; lInf = 0.2;
}
else if (target == DNN_TARGET_CPU_FP16)
{
l1 = 0.06; lInf = 0.3;
}
normAssert(outs[0], ref, "", l1, lInf);
if (target != DNN_TARGET_MYRIAD || getInferenceEngineVPUType() != CV_DNN_INFERENCE_ENGINE_VPU_TYPE_MYRIAD_X)
expectNoFallbacksFromIE(model.getNetwork_());
@ -473,7 +473,8 @@ TEST_P(Test_Darknet_nets, TinyYoloVoc)
1, 6, 0.928758f, 0.651024f, 0.463539f, 0.823784f, 0.654998f); // a car
double scoreDiff = 8e-5, iouDiff = 3e-4;
bool useWinograd = true;
if (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD)
{
scoreDiff = 8e-3;
iouDiff = 0.018;
@ -483,18 +484,24 @@ TEST_P(Test_Darknet_nets, TinyYoloVoc)
scoreDiff = 0.008;
iouDiff = 0.02;
}
else if (target == DNN_TARGET_CPU_FP16)
{
useWinograd = false;
scoreDiff = 8e-3;
iouDiff = 0.018;
}
std::string config_file = "tiny-yolo-voc.cfg";
std::string weights_file = "tiny-yolo-voc.weights";
{
SCOPED_TRACE("batch size 1");
testDarknetModel(config_file, weights_file, ref.rowRange(0, 2), scoreDiff, iouDiff, 0.24, 0.4, useWinograd);
}
{
SCOPED_TRACE("batch size 2");
testDarknetModel(config_file, weights_file, ref, scoreDiff, iouDiff, 0.24, 0.4, useWinograd);
}
}
@ -890,12 +897,12 @@ TEST_P(Test_Darknet_nets, YOLOv4_tiny)
{
SCOPED_TRACE("batch size 1");
testDarknetModel(config_file, weights_file, ref.rowRange(0, N0), scoreDiff, iouDiff, confThreshold, 0.4, false);
}
{
SCOPED_TRACE("batch size 2");
testDarknetModel(config_file, weights_file, ref, scoreDiff, iouDiff, confThreshold, 0.4, false);
}
#if defined(INF_ENGINE_RELEASE)
@ -40,6 +40,8 @@ public:
model.setPreferableTarget(target);
model.setNmsAcrossClasses(nmsAcrossClasses);
if (target == DNN_TARGET_CPU_FP16)
model.enableWinograd(false);
std::vector<int> classIds;
std::vector<float> confidences;
@ -55,7 +55,7 @@ public:
void testONNXModels(const String& basename, const Extension ext = npy,
double l1 = 0, double lInf = 0, const bool useSoftmax = false,
bool checkNoFallbacks = true, int numInps = 1,
bool testShapes = true, bool useWinograd = true)
{
String onnxmodel = _tf("models/" + basename + ".onnx", required);
std::vector<Mat> inps(numInps);
@ -82,6 +82,7 @@ public:
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
net.enableWinograd(useWinograd);
std::vector<String> inputNames;
for (int i = 0; i < numInps; ++i)
@ -1929,7 +1930,9 @@ TEST_P(Test_ONNX_layers, ConvResizePool1d)
#endif
}
#endif
testONNXModels("conv_resize_pool_1d");
const double lInf = (target == DNN_TARGET_CPU_FP16) ? 0.024 : default_lInf;
testONNXModels("conv_resize_pool_1d", npy, default_l1, lInf);
}
TEST_P(Test_ONNX_layers, DepthWiseAdd)
@ -2130,6 +2133,7 @@ TEST_P(Test_ONNX_nets, Alexnet)
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
net.enableWinograd(false);
Mat inp = imread(_tf("../grace_hopper_227.png"));
Mat ref = blobFromNPY(_tf("../caffe_alexnet_prob.npy"));
@ -2202,6 +2206,9 @@ TEST_P(Test_ONNX_nets, Googlenet)
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
if (target == DNN_TARGET_CPU_FP16)
net.enableWinograd(false);
std::vector<Mat> images;
images.push_back( imread(_tf("../googlenet_0.png")) );
images.push_back( imread(_tf("../googlenet_1.png")) );
@ -2346,7 +2353,7 @@ TEST_P(Test_ONNX_nets, TinyYolov2)
}
#endif
testONNXModels("tiny_yolo2", pb, l1, lInf);
testONNXModels("tiny_yolo2", pb, l1, lInf, false, true, 1, true, false);
}
TEST_P(Test_ONNX_nets, CNN_MNIST)
@ -2391,6 +2398,7 @@ TEST_P(Test_ONNX_nets, LResNet100E_IR)
double l1 = default_l1, lInf = default_lInf;
// output range: [-3; 3]
bool useWinograd = true;
if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
{
l1 = 0.009;
@ -2406,7 +2414,14 @@ TEST_P(Test_ONNX_nets, LResNet100E_IR)
l1 = 0.009;
lInf = 0.04;
}
testONNXModels("LResNet100E_IR", pb, l1, lInf);
else if (target == DNN_TARGET_CPU_FP16)
{
useWinograd = false;
l1 = 0.009;
lInf = 0.035;
}
testONNXModels("LResNet100E_IR", pb, l1, lInf, false, true, 1, true, useWinograd);
}
TEST_P(Test_ONNX_nets, Emotion_ferplus)
@ -2421,7 +2436,7 @@ TEST_P(Test_ONNX_nets, Emotion_ferplus)
double l1 = default_l1;
double lInf = default_lInf;
bool useWinograd = true;
// Output values are in range [-2.011, 2.111]
if ((backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16) || (target == DNN_TARGET_CUDA_FP16))
l1 = 0.007;
@ -2434,6 +2449,11 @@ TEST_P(Test_ONNX_nets, Emotion_ferplus)
l1 = 2.4e-4;
lInf = 6e-4;
}
else if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_CPU_FP16)
{
useWinograd = false;
l1 = 0.007;
}
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_GE(2020040000)
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && target == DNN_TARGET_OPENCL_FP16)
{
@ -2441,7 +2461,7 @@ TEST_P(Test_ONNX_nets, Emotion_ferplus)
}
#endif
testONNXModels("emotion_ferplus", pb, l1, lInf);
testONNXModels("emotion_ferplus", pb, l1, lInf, false, true, 1, true, useWinograd);
}
TEST_P(Test_ONNX_nets, Inception_v2)
@ -974,6 +974,9 @@ TEST_P(Test_TensorFlow_nets, Inception_v2_SSD)
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
if (target == DNN_TARGET_CPU_FP16)
net.enableWinograd(false);
net.setInput(blob);
// Output has shape 1x1xNx7 where N - number of detections.
// An every detection is a vector of values [id, classId, confidence, left, top, right, bottom]
@ -1307,6 +1310,8 @@ TEST_P(Test_TensorFlow_nets, EAST_text_detection)
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
if (target == DNN_TARGET_CPU_FP16)
net.enableWinograd(false);
Mat img = imread(imgPath);
Mat inp = blobFromImage(img, 1.0, Size(), Scalar(123.68, 116.78, 103.94), true, false);
@ -1341,8 +1346,9 @@ TEST_P(Test_TensorFlow_nets, EAST_text_detection)
}
else if (target == DNN_TARGET_CPU_FP16)
{
lInf_scores = 0.17;
l1_geometry = 0.28;
lInf_geometry = 5.94;
}
else
{
@ -1810,6 +1816,8 @@ TEST_P(Test_TensorFlow_nets, Mask_RCNN)
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
if (target == DNN_TARGET_CPU_FP16)
net.enableWinograd(false);
net.setInput(blob);
@ -358,6 +358,8 @@ TEST_P(Test_Torch_nets, OpenFace_accuracy)
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
if (target == DNN_TARGET_CPU_FP16)
net.enableWinograd(false);
Mat sample = imread(findDataFile("cv/shared/lena.png"));
Mat sampleF32(sample.size(), CV_32FC3);
@ -542,6 +544,9 @@ TEST_P(Test_Torch_nets, FastNeuralStyle_accuracy)
Mat img = imread(findDataFile("dnn/googlenet_1.png"));
Mat inputBlob = blobFromImage(img, 1.0, Size(), Scalar(103.939, 116.779, 123.68), false);
if (target == DNN_TARGET_CPU_FP16)
net.enableWinograd(false);
net.setInput(inputBlob);
Mat out = net.forward();
@ -570,7 +575,7 @@ TEST_P(Test_Torch_nets, FastNeuralStyle_accuracy)
}
else if (target == DNN_TARGET_CPU_FP16)
{
normAssert(out, refBlob, "", 0.64, 25);
normAssert(out, refBlob, "", 0.7, 25);
}
else
normAssert(out, refBlob, "", 0.5, 1.16);