parent
a3d6994afa
commit
bf54a370e5
2 changed files with 97 additions and 1 deletions
@ -0,0 +1,96 @@ |
|||||||
|
#!/usr/bin/env python |
||||||
|
|
||||||
|
import numpy as np |
||||||
|
import cv2 as cv |
||||||
|
import os |
||||||
|
import sys |
||||||
|
import unittest |
||||||
|
|
||||||
|
from tests_common import NewOpenCVTests |
||||||
|
|
||||||
|
|
||||||
|
try: |
||||||
|
|
||||||
|
if sys.version_info[:2] < (3, 0): |
||||||
|
raise unittest.SkipTest('Python 2.x is not supported') |
||||||
|
|
||||||
|
|
||||||
|
class CounterState: |
||||||
|
def __init__(self): |
||||||
|
self.counter = 0 |
||||||
|
|
||||||
|
|
||||||
|
@cv.gapi.op('stateful_counter', |
||||||
|
in_types=[cv.GOpaque.Int], |
||||||
|
out_types=[cv.GOpaque.Int]) |
||||||
|
class GStatefulCounter: |
||||||
|
"""Accumulate state counter on every call""" |
||||||
|
|
||||||
|
@staticmethod |
||||||
|
def outMeta(desc): |
||||||
|
return cv.empty_gopaque_desc() |
||||||
|
|
||||||
|
|
||||||
|
@cv.gapi.kernel(GStatefulCounter) |
||||||
|
class GStatefulCounterImpl: |
||||||
|
"""Implementation for GStatefulCounter operation.""" |
||||||
|
|
||||||
|
@staticmethod |
||||||
|
def setup(desc): |
||||||
|
return CounterState() |
||||||
|
|
||||||
|
@staticmethod |
||||||
|
def run(value, state): |
||||||
|
state.counter += value |
||||||
|
return state.counter |
||||||
|
|
||||||
|
|
||||||
|
class gapi_sample_pipelines(NewOpenCVTests): |
||||||
|
def test_stateful_kernel_single_instance(self): |
||||||
|
g_in = cv.GOpaque.Int() |
||||||
|
g_out = GStatefulCounter.on(g_in) |
||||||
|
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out)) |
||||||
|
pkg = cv.gapi.kernels(GStatefulCounterImpl) |
||||||
|
|
||||||
|
nums = [i for i in range(10)] |
||||||
|
acc = 0 |
||||||
|
for v in nums: |
||||||
|
acc = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg)) |
||||||
|
|
||||||
|
self.assertEqual(sum(nums), acc) |
||||||
|
|
||||||
|
|
||||||
|
def test_stateful_kernel_multiple_instances(self): |
||||||
|
# NB: Every counter has his own independent state. |
||||||
|
g_in = cv.GOpaque.Int() |
||||||
|
g_out0 = GStatefulCounter.on(g_in) |
||||||
|
g_out1 = GStatefulCounter.on(g_in) |
||||||
|
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out0, g_out1)) |
||||||
|
pkg = cv.gapi.kernels(GStatefulCounterImpl) |
||||||
|
|
||||||
|
nums = [i for i in range(10)] |
||||||
|
acc0 = acc1 = 0 |
||||||
|
for v in nums: |
||||||
|
acc0, acc1 = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg)) |
||||||
|
|
||||||
|
ref = sum(nums) |
||||||
|
self.assertEqual(ref, acc0) |
||||||
|
self.assertEqual(ref, acc1) |
||||||
|
|
||||||
|
|
||||||
|
except unittest.SkipTest as e: |
||||||
|
|
||||||
|
message = str(e) |
||||||
|
|
||||||
|
class TestSkip(unittest.TestCase): |
||||||
|
def setUp(self): |
||||||
|
self.skipTest('Skip tests: ' + message) |
||||||
|
|
||||||
|
def test_skip(): |
||||||
|
pass |
||||||
|
|
||||||
|
pass |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
NewOpenCVTests.bootstrap() |
Loading…
Reference in new issue