From bf54a370e56f24da87c5789d3f013c815f731cd9 Mon Sep 17 00:00:00 2001 From: TolyaTalamanov Date: Sun, 4 Sep 2022 20:15:53 +0100 Subject: [PATCH] Add tests for stateful kernel functionality --- .../python/test/test_gapi_sample_pipelines.py | 2 +- .../python/test/test_gapi_stateful_kernel.py | 96 +++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 modules/gapi/misc/python/test/test_gapi_stateful_kernel.py diff --git a/modules/gapi/misc/python/test/test_gapi_sample_pipelines.py b/modules/gapi/misc/python/test/test_gapi_sample_pipelines.py index 7763579ebf..8e15d457d9 100644 --- a/modules/gapi/misc/python/test/test_gapi_sample_pipelines.py +++ b/modules/gapi/misc/python/test/test_gapi_sample_pipelines.py @@ -432,7 +432,7 @@ try: with self.assertRaises(Exception): create_op([cv.GMat, int], [cv.GMat]).on(cv.GMat()) - def test_stateful_kernel(self): + def test_state_in_class(self): @cv.gapi.op('custom.sum', in_types=[cv.GArray.Int], out_types=[cv.GOpaque.Int]) class GSum: @staticmethod diff --git a/modules/gapi/misc/python/test/test_gapi_stateful_kernel.py b/modules/gapi/misc/python/test/test_gapi_stateful_kernel.py new file mode 100644 index 0000000000..f24ceccabf --- /dev/null +++ b/modules/gapi/misc/python/test/test_gapi_stateful_kernel.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python + +import numpy as np +import cv2 as cv +import os +import sys +import unittest + +from tests_common import NewOpenCVTests + + +try: + + if sys.version_info[:2] < (3, 0): + raise unittest.SkipTest('Python 2.x is not supported') + + + class CounterState: + def __init__(self): + self.counter = 0 + + + @cv.gapi.op('stateful_counter', + in_types=[cv.GOpaque.Int], + out_types=[cv.GOpaque.Int]) + class GStatefulCounter: + """Accumulate state counter on every call""" + + @staticmethod + def outMeta(desc): + return cv.empty_gopaque_desc() + + + @cv.gapi.kernel(GStatefulCounter) + class GStatefulCounterImpl: + """Implementation for GStatefulCounter operation.""" + + @staticmethod + def setup(desc): + return CounterState() + + @staticmethod + def run(value, state): + state.counter += value + return state.counter + + + class gapi_sample_pipelines(NewOpenCVTests): + def test_stateful_kernel_single_instance(self): + g_in = cv.GOpaque.Int() + g_out = GStatefulCounter.on(g_in) + comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out)) + pkg = cv.gapi.kernels(GStatefulCounterImpl) + + nums = [i for i in range(10)] + acc = 0 + for v in nums: + acc = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg)) + + self.assertEqual(sum(nums), acc) + + + def test_stateful_kernel_multiple_instances(self): + # NB: Every counter has his own independent state. + g_in = cv.GOpaque.Int() + g_out0 = GStatefulCounter.on(g_in) + g_out1 = GStatefulCounter.on(g_in) + comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out0, g_out1)) + pkg = cv.gapi.kernels(GStatefulCounterImpl) + + nums = [i for i in range(10)] + acc0 = acc1 = 0 + for v in nums: + acc0, acc1 = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg)) + + ref = sum(nums) + self.assertEqual(ref, acc0) + self.assertEqual(ref, acc1) + + +except unittest.SkipTest as e: + + message = str(e) + + class TestSkip(unittest.TestCase): + def setUp(self): + self.skipTest('Skip tests: ' + message) + + def test_skip(): + pass + + pass + + +if __name__ == '__main__': + NewOpenCVTests.bootstrap()