Add tests for stateful kernel functionality

This commit is contained in:
TolyaTalamanov 2022-09-04 20:15:53 +01:00
parent a3d6994afa
commit bf54a370e5
2 changed files with 97 additions and 1 deletions

View File

@ -432,7 +432,7 @@ try:
with self.assertRaises(Exception): create_op([cv.GMat, int], [cv.GMat]).on(cv.GMat())
def test_stateful_kernel(self):
def test_state_in_class(self):
@cv.gapi.op('custom.sum', in_types=[cv.GArray.Int], out_types=[cv.GOpaque.Int])
class GSum:
@staticmethod

View File

@ -0,0 +1,96 @@
#!/usr/bin/env python
import numpy as np
import cv2 as cv
import os
import sys
import unittest
from tests_common import NewOpenCVTests
try:
if sys.version_info[:2] < (3, 0):
raise unittest.SkipTest('Python 2.x is not supported')
class CounterState:
def __init__(self):
self.counter = 0
@cv.gapi.op('stateful_counter',
in_types=[cv.GOpaque.Int],
out_types=[cv.GOpaque.Int])
class GStatefulCounter:
"""Accumulate state counter on every call"""
@staticmethod
def outMeta(desc):
return cv.empty_gopaque_desc()
@cv.gapi.kernel(GStatefulCounter)
class GStatefulCounterImpl:
"""Implementation for GStatefulCounter operation."""
@staticmethod
def setup(desc):
return CounterState()
@staticmethod
def run(value, state):
state.counter += value
return state.counter
class gapi_sample_pipelines(NewOpenCVTests):
def test_stateful_kernel_single_instance(self):
g_in = cv.GOpaque.Int()
g_out = GStatefulCounter.on(g_in)
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out))
pkg = cv.gapi.kernels(GStatefulCounterImpl)
nums = [i for i in range(10)]
acc = 0
for v in nums:
acc = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg))
self.assertEqual(sum(nums), acc)
def test_stateful_kernel_multiple_instances(self):
# NB: Every counter has his own independent state.
g_in = cv.GOpaque.Int()
g_out0 = GStatefulCounter.on(g_in)
g_out1 = GStatefulCounter.on(g_in)
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out0, g_out1))
pkg = cv.gapi.kernels(GStatefulCounterImpl)
nums = [i for i in range(10)]
acc0 = acc1 = 0
for v in nums:
acc0, acc1 = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg))
ref = sum(nums)
self.assertEqual(ref, acc0)
self.assertEqual(ref, acc1)
except unittest.SkipTest as e:
message = str(e)
class TestSkip(unittest.TestCase):
def setUp(self):
self.skipTest('Skip tests: ' + message)
def test_skip():
pass
pass
if __name__ == '__main__':
NewOpenCVTests.bootstrap()