From 17006196a6a51d8a34c41518cc1d17d703f61201 Mon Sep 17 00:00:00 2001 From: "Guo, Yejun" Date: Fri, 10 Apr 2020 22:32:02 +0800 Subject: [PATCH] dnn-layer-mathbinary-test: add unit test for add Signed-off-by: Guo, Yejun --- tests/dnn/dnn-layer-mathbinary-test.c | 55 ++++++++++++++++++--------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/tests/dnn/dnn-layer-mathbinary-test.c b/tests/dnn/dnn-layer-mathbinary-test.c index 1243784b07..fd8037fca0 100644 --- a/tests/dnn/dnn-layer-mathbinary-test.c +++ b/tests/dnn/dnn-layer-mathbinary-test.c @@ -22,10 +22,25 @@ #include #include #include "libavfilter/dnn/dnn_backend_native_layer_mathbinary.h" +#include "libavutil/avassert.h" #define EPSON 0.00001 -static int test_sub_broadcast_input0(void) +static float get_expected(float f1, float f2, DNNMathBinaryOperation op) +{ + switch (op) + { + case DMBO_SUB: + return f1 - f2; + case DMBO_ADD: + return f1 + f2; + default: + av_assert0(!"not supported yet"); + return 0.f; + } +} + +static int test_broadcast_input0(DNNMathBinaryOperation op) { DnnLayerMathBinaryParams params; DnnOperand operands[2]; @@ -35,7 +50,7 @@ static int test_sub_broadcast_input0(void) }; float *output; - params.bin_op = DMBO_SUB; + params.bin_op = op; params.input0_broadcast = 1; params.input1_broadcast = 0; params.v = 7.28; @@ -52,9 +67,10 @@ static int test_sub_broadcast_input0(void) output = operands[1].data; for (int i = 0; i < sizeof(input) / sizeof(float); i++) { - float expected_output = params.v - input[i]; + float expected_output = get_expected(params.v, input[i], op); if (fabs(output[i] - expected_output) > EPSON) { - printf("at index %d, output: %f, expected_output: %f\n", i, output[i], expected_output); + printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n", + op, i, output[i], expected_output, __FILE__, __LINE__); av_freep(&output); return 1; } @@ -64,7 +80,7 @@ static int test_sub_broadcast_input0(void) return 0; } -static int test_sub_broadcast_input1(void) +static int test_broadcast_input1(DNNMathBinaryOperation op) { DnnLayerMathBinaryParams params; DnnOperand operands[2]; @@ -74,7 +90,7 @@ static int test_sub_broadcast_input1(void) }; float *output; - params.bin_op = DMBO_SUB; + params.bin_op = op; params.input0_broadcast = 0; params.input1_broadcast = 1; params.v = 7.28; @@ -91,9 +107,10 @@ static int test_sub_broadcast_input1(void) output = operands[1].data; for (int i = 0; i < sizeof(input) / sizeof(float); i++) { - float expected_output = input[i] - params.v; + float expected_output = get_expected(input[i], params.v, op); if (fabs(output[i] - expected_output) > EPSON) { - printf("at index %d, output: %f, expected_output: %f\n", i, output[i], expected_output); + printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n", + op, i, output[i], expected_output, __FILE__, __LINE__); av_freep(&output); return 1; } @@ -103,7 +120,7 @@ static int test_sub_broadcast_input1(void) return 0; } -static int test_sub_no_broadcast(void) +static int test_no_broadcast(DNNMathBinaryOperation op) { DnnLayerMathBinaryParams params; DnnOperand operands[3]; @@ -116,7 +133,7 @@ static int test_sub_no_broadcast(void) }; float *output; - params.bin_op = DMBO_SUB; + params.bin_op = op; params.input0_broadcast = 0; params.input1_broadcast = 0; @@ -138,9 +155,10 @@ static int test_sub_no_broadcast(void) output = operands[2].data; for (int i = 0; i < sizeof(input0) / sizeof(float); i++) { - float expected_output = input0[i] - input1[i]; + float expected_output = get_expected(input0[i], input1[i], op); if (fabs(output[i] - expected_output) > EPSON) { - printf("at index %d, output: %f, expected_output: %f\n", i, output[i], expected_output); + printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n", + op, i, output[i], expected_output, __FILE__, __LINE__); av_freep(&output); return 1; } @@ -150,15 +168,15 @@ static int test_sub_no_broadcast(void) return 0; } -static int test_sub(void) +static int test(DNNMathBinaryOperation op) { - if (test_sub_broadcast_input0()) + if (test_broadcast_input0(op)) return 1; - if (test_sub_broadcast_input1()) + if (test_broadcast_input1(op)) return 1; - if (test_sub_no_broadcast()) + if (test_no_broadcast(op)) return 1; return 0; @@ -166,7 +184,10 @@ static int test_sub(void) int main(int argc, char **argv) { - if (test_sub()) + if (test(DMBO_SUB)) + return 1; + + if (test(DMBO_ADD)) return 1; return 0;