mirror of
https://github.com/opencv/opencv.git
synced 2024-11-24 03:00:14 +08:00
Add tests for stateful kernel functionality
This commit is contained in:
parent
a3d6994afa
commit
bf54a370e5
@ -432,7 +432,7 @@ try:
|
||||
with self.assertRaises(Exception): create_op([cv.GMat, int], [cv.GMat]).on(cv.GMat())
|
||||
|
||||
|
||||
def test_stateful_kernel(self):
|
||||
def test_state_in_class(self):
|
||||
@cv.gapi.op('custom.sum', in_types=[cv.GArray.Int], out_types=[cv.GOpaque.Int])
|
||||
class GSum:
|
||||
@staticmethod
|
||||
|
96
modules/gapi/misc/python/test/test_gapi_stateful_kernel.py
Normal file
96
modules/gapi/misc/python/test/test_gapi_stateful_kernel.py
Normal file
@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from tests_common import NewOpenCVTests
|
||||
|
||||
|
||||
try:
|
||||
|
||||
if sys.version_info[:2] < (3, 0):
|
||||
raise unittest.SkipTest('Python 2.x is not supported')
|
||||
|
||||
|
||||
class CounterState:
|
||||
def __init__(self):
|
||||
self.counter = 0
|
||||
|
||||
|
||||
@cv.gapi.op('stateful_counter',
|
||||
in_types=[cv.GOpaque.Int],
|
||||
out_types=[cv.GOpaque.Int])
|
||||
class GStatefulCounter:
|
||||
"""Accumulate state counter on every call"""
|
||||
|
||||
@staticmethod
|
||||
def outMeta(desc):
|
||||
return cv.empty_gopaque_desc()
|
||||
|
||||
|
||||
@cv.gapi.kernel(GStatefulCounter)
|
||||
class GStatefulCounterImpl:
|
||||
"""Implementation for GStatefulCounter operation."""
|
||||
|
||||
@staticmethod
|
||||
def setup(desc):
|
||||
return CounterState()
|
||||
|
||||
@staticmethod
|
||||
def run(value, state):
|
||||
state.counter += value
|
||||
return state.counter
|
||||
|
||||
|
||||
class gapi_sample_pipelines(NewOpenCVTests):
|
||||
def test_stateful_kernel_single_instance(self):
|
||||
g_in = cv.GOpaque.Int()
|
||||
g_out = GStatefulCounter.on(g_in)
|
||||
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out))
|
||||
pkg = cv.gapi.kernels(GStatefulCounterImpl)
|
||||
|
||||
nums = [i for i in range(10)]
|
||||
acc = 0
|
||||
for v in nums:
|
||||
acc = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg))
|
||||
|
||||
self.assertEqual(sum(nums), acc)
|
||||
|
||||
|
||||
def test_stateful_kernel_multiple_instances(self):
|
||||
# NB: Every counter has his own independent state.
|
||||
g_in = cv.GOpaque.Int()
|
||||
g_out0 = GStatefulCounter.on(g_in)
|
||||
g_out1 = GStatefulCounter.on(g_in)
|
||||
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out0, g_out1))
|
||||
pkg = cv.gapi.kernels(GStatefulCounterImpl)
|
||||
|
||||
nums = [i for i in range(10)]
|
||||
acc0 = acc1 = 0
|
||||
for v in nums:
|
||||
acc0, acc1 = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg))
|
||||
|
||||
ref = sum(nums)
|
||||
self.assertEqual(ref, acc0)
|
||||
self.assertEqual(ref, acc1)
|
||||
|
||||
|
||||
except unittest.SkipTest as e:
|
||||
|
||||
message = str(e)
|
||||
|
||||
class TestSkip(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.skipTest('Skip tests: ' + message)
|
||||
|
||||
def test_skip():
|
||||
pass
|
||||
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
NewOpenCVTests.bootstrap()
|
Loading…
Reference in New Issue
Block a user