From a75840d19c6ee1f04ec05e039db1869a4aeadaee Mon Sep 17 00:00:00 2001 From: Arjan van de Ven Date: Wed, 31 Jan 2018 05:34:12 -0800 Subject: [PATCH] Merge pull request #10468 from fenrus75:avx512-2 * Add a 512 bit codepath to the AVX512 fastConv function this patch adds a 512 wide codepath to the fastConv() function for AVX512 use. The basic idea is to process the first N * 16 elements of the vector with avx512, and then run the rest of the vector using the traditional AVX2 codepath. * dnn: use unaligned AVX512 load (OpenCV aligns data on 32-byte boundary) * dnn: change "vecsize" condition for AVX512 * dnn: fix indentation --- modules/dnn/src/layers/layers_common.simd.hpp | 61 ++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/modules/dnn/src/layers/layers_common.simd.hpp b/modules/dnn/src/layers/layers_common.simd.hpp index 99d5538631..bee3e912e1 100644 --- a/modules/dnn/src/layers/layers_common.simd.hpp +++ b/modules/dnn/src/layers/layers_common.simd.hpp @@ -112,6 +112,7 @@ void fastConv( const float* weights, size_t wstep, const float* bias, int j = 0; for( ; j <= blockSize - 4; j += 4 ) { + int k = 0; const float* rptr = rowbuf + j*vecsize_aligned; __m256 vs00 = _mm256_setzero_ps(), vs01 = _mm256_setzero_ps(), @@ -121,7 +122,65 @@ void fastConv( const float* weights, size_t wstep, const float* bias, vs20 = _mm256_setzero_ps(), vs21 = _mm256_setzero_ps(), vs22 = _mm256_setzero_ps(), vs23 = _mm256_setzero_ps(); - for( int k = 0; k < vecsize; k += 8, rptr += 8 ) +#if CV_AVX512_SKX // AVX512VL is necessary to avoid register spilling + if (vecsize >= 32) + { + __m512 vs00_5 = _mm512_setzero_ps(), vs01_5 = _mm512_setzero_ps(), + vs02_5 = _mm512_setzero_ps(), vs03_5 = _mm512_setzero_ps(), + vs10_5 = _mm512_setzero_ps(), vs11_5 = _mm512_setzero_ps(), + vs12_5 = _mm512_setzero_ps(), vs13_5 = _mm512_setzero_ps(), + vs20_5 = _mm512_setzero_ps(), vs21_5 = _mm512_setzero_ps(), + vs22_5 = _mm512_setzero_ps(), vs23_5 = _mm512_setzero_ps(); + + for (; k <= vecsize - 16; k += 16, rptr += 16) + { + __m512 w0 = _mm512_loadu_ps(wptr0 + k); + __m512 w1 = _mm512_loadu_ps(wptr1 + k); + __m512 w2 = _mm512_loadu_ps(wptr2 + k); + __m512 r0 = _mm512_loadu_ps(rptr); + + vs00_5 = _mm512_fmadd_ps(w0, r0, vs00_5); + vs10_5 = _mm512_fmadd_ps(w1, r0, vs10_5); + vs20_5 = _mm512_fmadd_ps(w2, r0, vs20_5); + + r0 = _mm512_loadu_ps(rptr + vecsize_aligned); + vs01_5 = _mm512_fmadd_ps(w0, r0, vs01_5); + vs11_5 = _mm512_fmadd_ps(w1, r0, vs11_5); + vs21_5 = _mm512_fmadd_ps(w2, r0, vs21_5); + + r0 = _mm512_loadu_ps(rptr + vecsize_aligned*2); + vs02_5 = _mm512_fmadd_ps(w0, r0, vs02_5); + vs12_5 = _mm512_fmadd_ps(w1, r0, vs12_5); + vs22_5 = _mm512_fmadd_ps(w2, r0, vs22_5); + + r0 = _mm512_loadu_ps(rptr + vecsize_aligned*3); + vs03_5 = _mm512_fmadd_ps(w0, r0, vs03_5); + vs13_5 = _mm512_fmadd_ps(w1, r0, vs13_5); + vs23_5 = _mm512_fmadd_ps(w2, r0, vs23_5); + } + /* + * now fold the 512 bit accumulator vectors into 256 bit vectors so that the AVX2 code can finish + * the tail of the vector + */ + vs00 = _mm256_add_ps( _mm512_extractf32x8_ps(vs00_5, 0), _mm512_extractf32x8_ps(vs00_5, 1)); + vs10 = _mm256_add_ps( _mm512_extractf32x8_ps(vs10_5, 0), _mm512_extractf32x8_ps(vs10_5, 1)); + vs20 = _mm256_add_ps( _mm512_extractf32x8_ps(vs20_5, 0), _mm512_extractf32x8_ps(vs20_5, 1)); + + vs01 = _mm256_add_ps( _mm512_extractf32x8_ps(vs01_5, 0), _mm512_extractf32x8_ps(vs01_5, 1)); + vs11 = _mm256_add_ps( _mm512_extractf32x8_ps(vs11_5, 0), _mm512_extractf32x8_ps(vs11_5, 1)); + vs21 = _mm256_add_ps( _mm512_extractf32x8_ps(vs21_5, 0), _mm512_extractf32x8_ps(vs21_5, 1)); + + vs02 = _mm256_add_ps( _mm512_extractf32x8_ps(vs02_5, 0), _mm512_extractf32x8_ps(vs02_5, 1)); + vs12 = _mm256_add_ps( _mm512_extractf32x8_ps(vs12_5, 0), _mm512_extractf32x8_ps(vs12_5, 1)); + vs22 = _mm256_add_ps( _mm512_extractf32x8_ps(vs22_5, 0), _mm512_extractf32x8_ps(vs22_5, 1)); + + vs03 = _mm256_add_ps( _mm512_extractf32x8_ps(vs03_5, 0), _mm512_extractf32x8_ps(vs03_5, 1)); + vs13 = _mm256_add_ps( _mm512_extractf32x8_ps(vs13_5, 0), _mm512_extractf32x8_ps(vs13_5, 1)); + vs23 = _mm256_add_ps( _mm512_extractf32x8_ps(vs23_5, 0), _mm512_extractf32x8_ps(vs23_5, 1)); + } +#endif + + for (; k < vecsize; k += 8, rptr += 8 ) { __m256 w0 = _mm256_load_ps(wptr0 + k); __m256 w1 = _mm256_load_ps(wptr1 + k);