x86/tx_float: implement inverse MDCT AVX2 assembly

This commit implements an iMDCT in pure assembly.

This is capable of processing any mod-8 transforms, rather than just
power of two, but since power of two is all we have assembly for
currently, that's what's supported.
It would really benefit if we could somehow use the C code to decide
which function to jump into, but exposing function labels from assebly
into C is anything but easy.
The post-transform loop could probably be improved.

This was somewhat annoying to write, as we must support arbitrary
strides during runtime. There's a fast branch for stride == 4 bytes
and a slower one which uses vgatherdps.

Zen 3 benchmarks for stride == 4 for old (av_imdct_half) vs new (av_tx):

128pt:
   2811 decicycles in         av_tx (imdct),16775916 runs,   1300 skips
   3082 decicycles in         av_imdct_half,16776751 runs,    465 skips

256pt:
   4920 decicycles in         av_tx (imdct),16775820 runs,   1396 skips
   5378 decicycles in         av_imdct_half,16776411 runs,    805 skips

512pt:
   9668 decicycles in         av_tx (imdct),16775774 runs,   1442 skips
  10626 decicycles in         av_imdct_half,16775647 runs,   1569 skips

1024pt:
  19812 decicycles in         av_tx (imdct),16777144 runs,     72 skips
  23036 decicycles in         av_imdct_half,16777167 runs,     49 skips
pull/388/head
Lynne 2 years ago
parent 2425d5cd7e
commit 4537d9554d
No known key found for this signature in database
GPG Key ID: A2FEA5F03F034464
  1. 19
      libavutil/tx.c
  2. 8
      libavutil/tx_priv.h
  3. 185
      libavutil/x86/tx_float.asm
  4. 32
      libavutil/x86/tx_float_init.c

@ -206,23 +206,24 @@ static void parity_revtab_generator(int *revtab, int n, int inv, int offset,
1, 1, len >> 1, basis, dual_stride, inv_lookup);
}
int ff_tx_gen_split_radix_parity_revtab(AVTXContext *s, int invert_lookup,
int basis, int dual_stride)
int ff_tx_gen_split_radix_parity_revtab(AVTXContext *s, int len, int inv,
int inv_lookup, int basis, int dual_stride)
{
int len = s->len;
int inv = s->inv;
if (!(s->map = av_mallocz(len*sizeof(*s->map))))
return AVERROR(ENOMEM);
basis >>= 1;
if (len < basis)
return AVERROR(EINVAL);
if (!(s->map = av_mallocz((inv_lookup == -1 ? 2 : 1)*len*sizeof(*s->map))))
return AVERROR(ENOMEM);
av_assert0(!dual_stride || !(dual_stride & (dual_stride - 1)));
av_assert0(dual_stride <= basis);
parity_revtab_generator(s->map, len, inv, 0, 0, 0, len,
basis, dual_stride, invert_lookup);
basis, dual_stride, inv_lookup != 0);
if (inv_lookup == -1)
parity_revtab_generator(s->map + len, len, inv, 0, 0, 0, len,
basis, dual_stride, 0);
return 0;
}

@ -288,9 +288,13 @@ int ff_tx_gen_ptwo_inplace_revtab_idx(AVTXContext *s);
* functions in AVX mode.
*
* If length is smaller than basis/2 this function will not do anything.
*
* If inv_lookup is set to 1, it will flip the lookup from out[map[i]] = src[i]
* to out[i] = src[map[i]]. If set to -1, will generate 2 maps, the first one
* flipped, the second one regular.
*/
int ff_tx_gen_split_radix_parity_revtab(AVTXContext *s, int invert_lookup,
int basis, int dual_stride);
int ff_tx_gen_split_radix_parity_revtab(AVTXContext *s, int len, int inv,
int inv_lookup, int basis, int dual_stride);
/* Typed init function to initialize shared tables. Will initialize all tables
* for all factors of a length. */

@ -1355,3 +1355,188 @@ FFT_SPLIT_RADIX_FN avx2, 0
FFT_SPLIT_RADIX_FN avx2, 1
%endif
%endif
%macro IMDCT_FN 1
INIT_YMM %1
cglobal mdct_sr_inv_float, 4, 12, 16, 288, ctx, out, in, stride, len, lut, exp, t1, t2, t3, t4, t5
movsxd lenq, dword [ctxq + AVTXContext.len]
mov expq, [ctxq + AVTXContext.exp]
lea t1d, [lend - 1]
imul t1d, strided
PUSH outq ; backup original output
mov t5q, [ctxq + AVTXContext.fn] ; subtransform's jump point
PUSH ctxq ; backup original context
mov ctxq, [ctxq + AVTXContext.sub] ; load subtransform's context
mov lutq, [ctxq + AVTXContext.map] ; load subtransform's map
cmp strideq, 4
je .stride4
shl strideq, 1
movd xm4, strided
vpbroadcastd m4, xm4 ; stride splatted
movd xm5, t1d
vpbroadcastd m5, xm5 ; offset splatted
mov t2q, outq ; don't modify the original output
pcmpeqd m15, m15 ; set all bits to 1
.stridex_pre:
pmulld m2, m4, [lutq] ; multiply by stride
movaps m0, m15
psubd m3, m5, m2 ; subtract from offset
movaps m1, m15
vgatherdps m6, [inq + m2], m0 ; im
vgatherdps m7, [inq + m3], m1 ; re
movaps m8, [expq + 0*mmsize] ; tab 1
movaps m9, [expq + 1*mmsize] ; tab 2
unpcklps m0, m7, m6 ; re, im, re, im
unpckhps m1, m7, m6 ; re, im, re, im
vperm2f128 m2, m1, m0, 0x02 ; output order
vperm2f128 m3, m1, m0, 0x13 ; output order
movshdup m10, m8 ; tab 1 imim
movshdup m11, m9 ; tab 2 imim
movsldup m12, m8 ; tab 1 rere
movsldup m13, m9 ; tab 2 rere
mulps m10, m2 ; 1 reim * imim
mulps m11, m3 ; 2 reim * imim
shufps m10, m10, q2301
shufps m11, m11, q2301
fmaddsubps m10, m12, m2, m10
fmaddsubps m11, m13, m3, m11
mova [t2q + 0*mmsize], m10
mova [t2q + 1*mmsize], m11
add expq, mmsize*2
add lutq, mmsize
add t2q, mmsize*2
sub lenq, mmsize/2
jg .stridex_pre
jmp .transform
.stride4:
lea expq, [expq + lenq*4]
lea lutq, [lutq + lenq*2]
lea t1q, [inq + t1q]
lea t1q, [t1q + strideq - mmsize]
lea t2q, [lenq*2 - mmsize/2]
.stride4_pre:
movaps m4, [inq]
movaps m3, [t1q]
movsldup m1, m4 ; im im, im im
movshdup m0, m3 ; re re, re re
movshdup m4, m4 ; re re, re re (2)
movsldup m3, m3 ; im im, im im (2)
movaps m2, [expq] ; tab
movaps m5, [expq + 2*t2q] ; tab (2)
vpermpd m0, m0, q0123 ; flip
shufps m7, m2, m2, q2301
vpermpd m4, m4, q0123 ; flip (2)
shufps m8, m5, m5, q2301
mulps m1, m7 ; im im * tab.reim
mulps m3, m8 ; im im * tab.reim (2)
fmaddsubps m0, m0, m2, m1
fmaddsubps m4, m4, m5, m3
vextractf128 xm3, m0, 1
vextractf128 xm6, m4, 1
; scatter
movsxd strideq, dword [lutq + 0*4]
movsxd lenq, dword [lutq + 1*4]
movsxd t3q, dword [lutq + 2*4]
movsxd t4q, dword [lutq + 3*4]
movlps [outq + strideq*8], xm0
movhps [outq + lenq*8], xm0
movlps [outq + t3q*8], xm3
movhps [outq + t4q*8], xm3
movsxd strideq, dword [lutq + 0*4 + t2q]
movsxd lenq, dword [lutq + 1*4 + t2q]
movsxd t3q, dword [lutq + 2*4 + t2q]
movsxd t4q, dword [lutq + 3*4 + t2q]
movlps [outq + strideq*8], xm4
movhps [outq + lenq*8], xm4
movlps [outq + t3q*8], xm6
movhps [outq + t4q*8], xm6
add lutq, mmsize/2
add expq, mmsize
add inq, mmsize
sub t1q, mmsize
sub t2q, mmsize
jg .stride4_pre
.transform:
movsxd lenq, dword [ctxq + AVTXContext.len]
mov t2q, lenq ; target length (for ptwo transforms)
mov inq, outq ; in-place transform
call t5q ; call the FFT
POP ctxq ; restore original context
movsxd lenq, dword [ctxq + AVTXContext.len]
mov expq, [ctxq + AVTXContext.exp]
lea expq, [expq + lenq*4]
lea t1q, [lenq*2] ; high
lea t2q, [lenq*2 - mmsize] ; low
POP outq
.post:
movaps m2, [expq + t1q] ; tab h
movaps m3, [expq + t2q] ; tab l
movaps m0, [outq + t1q] ; in h
movaps m1, [outq + t2q] ; in l
movshdup m4, m2 ; tab h imim
movshdup m5, m3 ; tab l imim
movsldup m6, m2 ; tab h rere
movsldup m7, m3 ; tab l rere
shufps m2, m0, m0, q2301 ; in h imre
shufps m3, m1, m1, q2301 ; in l imre
mulps m6, m0
mulps m7, m1
fmaddsubps m4, m4, m2, m6
fmaddsubps m5, m5, m3, m7
vpermpd m3, m5, q0123 ; flip
vpermpd m2, m4, q0123 ; flip
blendps m1, m2, m5, 01010101b
blendps m0, m3, m4, 01010101b
movaps [outq + t2q], m1
movaps [outq + t1q], m0
add t1q, mmsize
sub t2q, mmsize
jge .post
RET
%endmacro
%if ARCH_X86_64
IMDCT_FN avx2
%endif

@ -43,6 +43,8 @@ TX_DECL_FN(fft_sr_ns, fma3)
TX_DECL_FN(fft_sr, avx2)
TX_DECL_FN(fft_sr_ns, avx2)
TX_DECL_FN(mdct_sr_inv, avx2)
TX_DECL_FN(fft8_asm, sse3)
TX_DECL_FN(fft8_asm, avx)
TX_DECL_FN(fft16_asm, avx)
@ -65,13 +67,38 @@ static av_cold int b ##basis## _i ##interleave(AVTXContext *s, \
if (cd->max_len == 2) \
return ff_tx_gen_ptwo_revtab(s, inv_lookup); \
else \
return ff_tx_gen_split_radix_parity_revtab(s, inv_lookup, \
return ff_tx_gen_split_radix_parity_revtab(s, len, inv, inv_lookup, \
basis, interleave); \
}
DECL_INIT_FN(8, 0)
DECL_INIT_FN(8, 2)
static av_cold int m_inv_init(AVTXContext *s, const FFTXCodelet *cd,
uint64_t flags, FFTXCodeletOptions *opts,
int len, int inv, const void *scale)
{
int ret;
FFTXCodeletOptions sub_opts = { .invert_lookup = -1 };
s->scale_d = *((SCALE_TYPE *)scale);
s->scale_f = s->scale_d;
flags &= ~FF_TX_OUT_OF_PLACE; /* We want the subtransform to be */
flags |= AV_TX_INPLACE; /* in-place */
flags |= FF_TX_PRESHUFFLE; /* This function handles the permute step */
flags |= FF_TX_ASM_CALL; /* We want an assembly function, not C */
if ((ret = ff_tx_init_subtx(s, TX_TYPE(FFT), flags, &sub_opts, len >> 1,
inv, scale)))
return ret;
if ((ret = ff_tx_mdct_gen_exp_float(s, s->sub->map)))
return ret;
return 0;
}
const FFTXCodelet * const ff_tx_codelet_list_float_x86[] = {
TX_DEF(fft2, FFT, 2, 2, 2, 0, 128, NULL, sse3, SSE3, AV_TX_INPLACE, 0),
TX_DEF(fft2, FFT, 2, 2, 2, 0, 192, b8_i0, sse3, SSE3, AV_TX_INPLACE | FF_TX_PRESHUFFLE, 0),
@ -121,6 +148,9 @@ const FFTXCodelet * const ff_tx_codelet_list_float_x86[] = {
AV_TX_INPLACE | FF_TX_PRESHUFFLE | FF_TX_ASM_CALL, AV_CPU_FLAG_AVXSLOW | AV_CPU_FLAG_SLOW_GATHER),
TX_DEF(fft_sr_ns, FFT, 64, 131072, 2, 0, 384, b8_i2, avx2, AVX2, AV_TX_INPLACE | FF_TX_PRESHUFFLE,
AV_CPU_FLAG_AVXSLOW | AV_CPU_FLAG_SLOW_GATHER),
TX_DEF(mdct_sr_inv, MDCT, 16, TX_LEN_UNLIMITED, 2, TX_FACTOR_ANY, 384, m_inv_init, avx2, AVX2,
FF_TX_INVERSE_ONLY, AV_CPU_FLAG_AVXSLOW | AV_CPU_FLAG_SLOW_GATHER),
#endif
#endif

Loading…
Cancel
Save