Merge pull request #24066 from VadimLevin:dev/vlevin/python-typing-register-dnn-layer

Python typing refinement for dnn_registerLayer/dnn_unregisterLayer functions #24066

This patch introduces typings generation for `dnn_registerLayer`/`dnn_unregisterLayer` manually defined in [`cv2/modules/dnn/misc/python/pyopencv_dnn.hpp`](https://github.com/opencv/opencv/blob/4.x/modules/dnn/misc/python/pyopencv_dnn.hpp)

Updates:

- Add `LayerProtocol` to `cv2/dnn/__init__.pyi`:

    ```python
    class LayerProtocol(Protocol):
        def __init__(
            self, params: dict[str, DictValue],
            blobs: typing.Sequence[cv2.typing.MatLike]
        ) -> None: ...

        def getMemoryShapes(
            self, inputs: typing.Sequence[typing.Sequence[int]]
        ) -> typing.Sequence[typing.Sequence[int]]: ...

        def forward(
            self, inputs: typing.Sequence[cv2.typing.MatLike]
        ) -> typing.Sequence[cv2.typing.MatLike]: ...
    ```

- Add `dnn_registerLayer` function to `cv2/__init__.pyi`:

    ```python
    def dnn_registerLayer(layerTypeName: str,
                          layerClass: typing.Type[LayerProtocol]) -> None: ...
    ```

- Add `dnn_unregisterLayer` function to `cv2/__init__.pyi`:

    ```python
    def dnn_unregisterLayer(layerTypeName: str) -> None: ...
    ```
### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [ ] There is a reference to the original bug report and related work
- [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
Vadim Levin 2023-07-27 11:28:00 +03:00 committed by GitHub
parent 5c090b9eec
commit 0c5d74ec1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 205 additions and 77 deletions

View File

@ -7,15 +7,18 @@ from typing import cast, Sequence, Callable, Iterable
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode,
ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode,
AggregatedTypeNode, CallableTypeNode, AnyTypeNode,
TupleTypeNode, UnionTypeNode)
TupleTypeNode, UnionTypeNode, ProtocolClassNode,
DictTypeNode, ClassTypeNode)
from .ast_utils import (find_function_node, SymbolName,
for_each_function_overload)
from .types_conversion import create_type_node
def apply_manual_api_refinement(root: NamespaceNode) -> None:
refine_highgui_module(root)
refine_cuda_module(root)
export_matrix_type_constants(root)
refine_dnn_module(root)
# Export OpenCV exception class
builtin_exception = root.add_class("Exception")
builtin_exception.is_exported = False
@ -215,6 +218,86 @@ def refine_highgui_module(root: NamespaceNode) -> None:
)
def refine_dnn_module(root: NamespaceNode) -> None:
if "dnn" not in root.namespaces:
return
dnn_module = root.namespaces["dnn"]
"""
class LayerProtocol(Protocol):
def __init__(
self, params: dict[str, DictValue],
blobs: typing.Sequence[cv2.typing.MatLike]
) -> None: ...
def getMemoryShapes(
self, inputs: typing.Sequence[typing.Sequence[int]]
) -> typing.Sequence[typing.Sequence[int]]: ...
def forward(
self, inputs: typing.Sequence[cv2.typing.MatLike]
) -> typing.Sequence[cv2.typing.MatLike]: ...
"""
layer_proto = ProtocolClassNode("LayerProtocol", dnn_module)
layer_proto.add_function(
"__init__",
arguments=[
FunctionNode.Arg(
"params",
DictTypeNode(
"LayerParams", PrimitiveTypeNode.str_(),
create_type_node("cv::dnn::DictValue")
)
),
FunctionNode.Arg("blobs", create_type_node("vector<cv::Mat>"))
]
)
layer_proto.add_function(
"getMemoryShapes",
arguments=[
FunctionNode.Arg("inputs",
create_type_node("vector<vector<int>>"))
],
return_type=FunctionNode.RetType(
create_type_node("vector<vector<int>>")
)
)
layer_proto.add_function(
"forward",
arguments=[
FunctionNode.Arg("inputs", create_type_node("vector<cv::Mat>"))
],
return_type=FunctionNode.RetType(create_type_node("vector<cv::Mat>"))
)
"""
def dnn_registerLayer(layerTypeName: str,
layerClass: typing.Type[LayerProtocol]) -> None: ...
"""
root.add_function(
"dnn_registerLayer",
arguments=[
FunctionNode.Arg("layerTypeName", PrimitiveTypeNode.str_()),
FunctionNode.Arg(
"layerClass",
ClassTypeNode(ASTNodeTypeNode(
layer_proto.export_name, f"dnn.{layer_proto.export_name}"
))
)
]
)
"""
def dnn_unregisterLayer(layerTypeName: str) -> None: ...
"""
root.add_function(
"dnn_unregisterLayer",
arguments=[
FunctionNode.Arg("layerTypeName", PrimitiveTypeNode.str_())
]
)
def _trim_class_name_from_argument_types(
overloads: Iterable[FunctionNode.Overload],
class_name: str

View File

@ -1,5 +1,5 @@
from typing import (NamedTuple, Sequence, Tuple, Union, List,
Dict, Callable, Optional, Generator)
Dict, Callable, Optional, Generator, cast)
import keyword
from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
@ -204,9 +204,7 @@ def create_function_node_in_scope(scope: Union[NamespaceNode, ClassNode],
outlist = variant.py_outlist
for _, argno in outlist:
assert argno >= 0, \
"Logic Error! Outlist contains function return type: {}".format(
outlist
)
f"Logic Error! Outlist contains function return type: {outlist}"
ret_types.append(create_type_node(variant.args[argno].tp))
@ -379,7 +377,7 @@ def get_enclosing_namespace(
node.full_export_name, node.native_name
)
if class_node_callback:
class_node_callback(parent_node)
class_node_callback(cast(ClassNode, parent_node))
parent_node = parent_node.parent
return parent_node
@ -395,12 +393,14 @@ def get_enum_module_and_export_name(enum_node: EnumerationNode) -> Tuple[str, st
Returns:
Tuple[str, str]: a pair of enum export name and its full module name.
"""
enum_export_name = enum_node.export_name
def update_full_export_name(class_node: ClassNode) -> None:
nonlocal enum_export_name
enum_export_name = class_node.export_name + "_" + enum_export_name
enum_export_name = enum_node.export_name
namespace_node = get_enclosing_namespace(enum_node, update_full_export_name)
namespace_node = get_enclosing_namespace(enum_node,
update_full_export_name)
return enum_export_name, namespace_node.full_export_name

View File

@ -3,7 +3,7 @@ __all__ = ("generate_typing_stubs", )
from io import StringIO
from pathlib import Path
import re
from typing import (Type, Callable, NamedTuple, Union, Set, Dict,
from typing import (Callable, NamedTuple, Union, Set, Dict,
Collection, Tuple, List)
import warnings
@ -16,7 +16,8 @@ from .predefined_types import PREDEFINED_TYPES
from .api_refinement import apply_manual_api_refinement
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode,
FunctionNode, EnumerationNode, ConstantNode)
FunctionNode, EnumerationNode, ConstantNode,
ProtocolClassNode)
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
AggregatedTypeNode, ASTNodeTypeNode,
@ -112,11 +113,13 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:
# NOTE: Enumerations require special handling, because all enumeration
# constants are exposed as module attributes
has_enums = _generate_section_stub(
StubSection("# Enumerations", EnumerationNode), root, output_stream, 0
StubSection("# Enumerations", ASTNodeType.Enumeration), root,
output_stream, 0
)
# Collect all enums from class level and export them to module level
for class_node in root.classes.values():
if _generate_enums_from_classes_tree(class_node, output_stream, indent=0):
if _generate_enums_from_classes_tree(class_node, output_stream,
indent=0):
has_enums = True
# 2 empty lines between enum and classes definitions
if has_enums:
@ -134,14 +137,15 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:
class StubSection(NamedTuple):
name: str
node_type: Type[ASTNode]
node_type: ASTNodeType
STUB_SECTIONS = (
StubSection("# Constants", ConstantNode),
# StubSection("# Enumerations", EnumerationNode), # Skipped for now (special rules)
StubSection("# Classes", ClassNode),
StubSection("# Functions", FunctionNode)
StubSection("# Constants", ASTNodeType.Constant),
# Enumerations are skipped due to special handling rules
# StubSection("# Enumerations", ASTNodeType.Enumeration),
StubSection("# Classes", ASTNodeType.Class),
StubSection("# Functions", ASTNodeType.Function)
)
@ -250,9 +254,9 @@ def _generate_class_stub(class_node: ClassNode, output_stream: StringIO,
else:
bases.append(base.export_name)
inheritance_str = "({})".format(
', '.join(bases)
)
inheritance_str = f"({', '.join(bases)})"
elif isinstance(class_node, ProtocolClassNode):
inheritance_str = "(Protocol)"
else:
inheritance_str = ""
@ -547,7 +551,8 @@ def check_overload_presence(node: Union[NamespaceNode, ClassNode]) -> bool:
return True
return False
def _collect_required_imports(root: NamespaceNode) -> Set[str]:
def _collect_required_imports(root: NamespaceNode) -> Collection[str]:
"""Collects all imports required for classes and functions typing stubs
declarations.
@ -555,8 +560,8 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
root (NamespaceNode): Namespace node to collect imports for
Returns:
Set[str]: Collection of unique `import smth` statements required for
classes and function declarations of `root` node.
Collection[str]: Collection of unique `import smth` statements required
for classes and function declarations of `root` node.
"""
def _add_required_usage_imports(type_node: TypeNode, imports: Set[str]):
@ -569,6 +574,7 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
has_overload = check_overload_presence(root)
# if there is no module-level functions with overload, check its presence
# during class traversing, including their inner-classes
has_protocol = False
for cls in for_each_class(root):
if not has_overload and check_overload_presence(cls):
has_overload = True
@ -583,6 +589,8 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
required_imports.add(
"import " + base_namespace.full_export_name
)
if isinstance(cls, ProtocolClassNode):
has_protocol = True
if has_overload:
required_imports.add("import typing")
@ -599,7 +607,20 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
if root_import in required_imports:
required_imports.remove(root_import)
return required_imports
if has_protocol:
required_imports.add("import sys")
ordered_required_imports = sorted(required_imports)
# Protocol import always goes as last import statement
if has_protocol:
ordered_required_imports.append(
"""if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol"""
)
return ordered_required_imports
def _populate_reexported_symbols(root: NamespaceNode) -> None:
@ -666,7 +687,7 @@ def _write_required_imports(required_imports: Collection[str],
output_stream (StringIO): Output stream for import statements.
"""
for required_import in sorted(required_imports):
for required_import in required_imports:
output_stream.write(required_import)
output_stream.write("\n")
if len(required_imports):
@ -803,8 +824,8 @@ StubGenerator = Callable[[ASTNode, StringIO, int], None]
NODE_TYPE_TO_STUB_GENERATOR = {
ClassNode: _generate_class_stub,
ConstantNode: _generate_constant_stub,
EnumerationNode: _generate_enumeration_stub,
FunctionNode: _generate_function_stub
ASTNodeType.Class: _generate_class_stub,
ASTNodeType.Constant: _generate_constant_stub,
ASTNodeType.Enumeration: _generate_enumeration_stub,
ASTNodeType.Function: _generate_function_stub
}

View File

@ -1,6 +1,6 @@
from .node import ASTNode, ASTNodeType
from .namespace_node import NamespaceNode
from .class_node import ClassNode, ClassProperty
from .class_node import ClassNode, ClassProperty, ProtocolClassNode
from .function_node import FunctionNode
from .enumeration_node import EnumerationNode
from .constant_node import ConstantNode
@ -8,5 +8,5 @@ from .type_node import (
TypeNode, OptionalTypeNode, UnionTypeNode, NoneTypeNode, TupleTypeNode,
ASTNodeTypeNode, AliasTypeNode, SequenceTypeNode, AnyTypeNode,
AggregatedTypeNode, NDArrayTypeNode, AliasRefTypeNode, PrimitiveTypeNode,
CallableTypeNode,
CallableTypeNode, DictTypeNode, ClassTypeNode
)

View File

@ -63,8 +63,9 @@ class ClassNode(ASTNode):
return 1 + sum(base.weight for base in self.bases)
@property
def children_types(self) -> Tuple[Type[ASTNode], ...]:
return (ClassNode, FunctionNode, EnumerationNode, ConstantNode)
def children_types(self) -> Tuple[ASTNodeType, ...]:
return (ASTNodeType.Class, ASTNodeType.Function,
ASTNodeType.Enumeration, ASTNodeType.Constant)
@property
def node_type(self) -> ASTNodeType:
@ -72,19 +73,19 @@ class ClassNode(ASTNode):
@property
def classes(self) -> Dict[str, "ClassNode"]:
return self._children[ClassNode]
return self._children[ASTNodeType.Class]
@property
def functions(self) -> Dict[str, FunctionNode]:
return self._children[FunctionNode]
return self._children[ASTNodeType.Function]
@property
def enumerations(self) -> Dict[str, EnumerationNode]:
return self._children[EnumerationNode]
return self._children[ASTNodeType.Enumeration]
@property
def constants(self) -> Dict[str, ConstantNode]:
return self._children[ConstantNode]
return self._children[ASTNodeType.Constant]
def add_class(self, name: str,
bases: Sequence["weakref.ProxyType[ClassNode]"] = (),
@ -179,3 +180,11 @@ class ClassNode(ASTNode):
self.full_export_name, root.full_export_name, errors
)
)
class ProtocolClassNode(ClassNode):
def __init__(self, name: str, parent: Optional[ASTNode] = None,
export_name: Optional[str] = None,
properties: Sequence[ClassProperty] = ()) -> None:
super().__init__(name, parent, export_name, bases=(),
properties=properties)

View File

@ -1,4 +1,4 @@
from typing import Type, Optional, Tuple
from typing import Optional, Tuple
from .node import ASTNode, ASTNodeType
@ -14,7 +14,7 @@ class ConstantNode(ASTNode):
self._value_type = "int"
@property
def children_types(self) -> Tuple[Type[ASTNode], ...]:
def children_types(self) -> Tuple[ASTNodeType, ...]:
return ()
@property

View File

@ -18,8 +18,8 @@ class EnumerationNode(ASTNode):
self.is_scoped = is_scoped
@property
def children_types(self) -> Tuple[Type[ASTNode], ...]:
return (ConstantNode, )
def children_types(self) -> Tuple[ASTNodeType, ...]:
return (ASTNodeType.Constant, )
@property
def node_type(self) -> ASTNodeType:
@ -27,7 +27,7 @@ class EnumerationNode(ASTNode):
@property
def constants(self) -> Dict[str, ConstantNode]:
return self._children[ConstantNode]
return self._children[ASTNodeType.Constant]
def add_constant(self, name: str, value: str) -> ConstantNode:
return self._add_child(ConstantNode, name, value=value)

View File

@ -1,4 +1,4 @@
from typing import NamedTuple, Sequence, Type, Optional, Tuple, List
from typing import NamedTuple, Sequence, Optional, Tuple, List
from .node import ASTNode, ASTNodeType
from .type_node import TypeNode, NoneTypeNode, TypeResolutionError
@ -98,7 +98,7 @@ class FunctionNode(ASTNode):
return ASTNodeType.Function
@property
def children_types(self) -> Tuple[Type[ASTNode], ...]:
def children_types(self) -> Tuple[ASTNodeType, ...]:
return ()
def add_overload(self, arguments: Sequence["FunctionNode.Arg"] = (),

View File

@ -1,7 +1,7 @@
import itertools
import weakref
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple, Type
from typing import Dict, List, Optional, Sequence, Tuple
from .class_node import ClassNode, ClassProperty
from .constant_node import ConstantNode
@ -33,29 +33,29 @@ class NamespaceNode(ASTNode):
return ASTNodeType.Namespace
@property
def children_types(self) -> Tuple[Type[ASTNode], ...]:
return (NamespaceNode, ClassNode, FunctionNode,
EnumerationNode, ConstantNode)
def children_types(self) -> Tuple[ASTNodeType, ...]:
return (ASTNodeType.Namespace, ASTNodeType.Class, ASTNodeType.Function,
ASTNodeType.Enumeration, ASTNodeType.Constant)
@property
def namespaces(self) -> Dict[str, "NamespaceNode"]:
return self._children[NamespaceNode]
return self._children[ASTNodeType.Namespace]
@property
def classes(self) -> Dict[str, ClassNode]:
return self._children[ClassNode]
return self._children[ASTNodeType.Class]
@property
def functions(self) -> Dict[str, FunctionNode]:
return self._children[FunctionNode]
return self._children[ASTNodeType.Function]
@property
def enumerations(self) -> Dict[str, EnumerationNode]:
return self._children[EnumerationNode]
return self._children[ASTNodeType.Enumeration]
@property
def constants(self) -> Dict[str, ConstantNode]:
return self._children[ConstantNode]
return self._children[ASTNodeType.Constant]
def add_namespace(self, name: str) -> "NamespaceNode":
return self._add_child(NamespaceNode, name)

View File

@ -70,22 +70,22 @@ class ASTNode:
self._parent: Optional["ASTNode"] = None
self.parent = parent
self.is_exported = True
self._children: DefaultDict[NodeType, NameToNode] = defaultdict(dict)
self._children: DefaultDict[ASTNodeType, NameToNode] = defaultdict(dict)
def __str__(self) -> str:
return "{}('{}' exported as '{}')".format(
type(self).__name__.replace("Node", ""), self.name, self.export_name
self.node_type.name, self.name, self.export_name
)
def __repr__(self) -> str:
return str(self)
@abc.abstractproperty
def children_types(self) -> Tuple[Type["ASTNode"], ...]:
def children_types(self) -> Tuple[ASTNodeType, ...]:
"""Set of ASTNode types that are allowed to be children of this node
Returns:
Tuple[Type[ASTNode], ...]: Types of children nodes
Tuple[ASTNodeType, ...]: Types of children nodes
"""
pass
@ -99,6 +99,9 @@ class ASTNode:
"""
pass
def node_type_name(self) -> str:
return f"{self.node_type.name}::{self.name}"
@property
def name(self) -> str:
return self.__name
@ -126,11 +129,11 @@ class ASTNode:
"but got: {}".format(type(value))
if value is not None:
value.__check_child_before_add(type(self), self.name)
value.__check_child_before_add(self, self.name)
# Detach from previous parent
if self._parent is not None:
self._parent._children[type(self)].pop(self.name)
self._parent._children[self.node_type].pop(self.name)
if value is None:
self._parent = None
@ -138,28 +141,26 @@ class ASTNode:
# Set a weak reference to a new parent and add self to its children
self._parent = weakref.proxy(value)
value._children[type(self)][self.name] = self
value._children[self.node_type][self.name] = self
def __check_child_before_add(self, child_type: Type[ASTNodeSubtype],
def __check_child_before_add(self, child: ASTNodeSubtype,
name: str) -> None:
assert len(self.children_types) > 0, \
"Trying to add child node '{}::{}' to node '{}::{}' " \
"that can't have children nodes".format(child_type.__name__, name,
type(self).__name__,
self.name)
assert len(self.children_types) > 0, (
f"Trying to add child node '{child.node_type_name}' to node "
f"'{self.node_type_name}' that can't have children nodes"
)
assert child_type in self.children_types, \
"Trying to add child node '{}::{}' to node '{}::{}' " \
assert child.node_type in self.children_types, \
"Trying to add child node '{}' to node '{}' " \
"that supports only ({}) as its children types".format(
child_type.__name__, name, type(self).__name__, self.name,
",".join(t.__name__ for t in self.children_types)
child.node_type_name, self.node_type_name,
",".join(t.name for t in self.children_types)
)
if self._find_child(child_type, name) is not None:
if self._find_child(child.node_type, name) is not None:
raise ValueError(
"Node '{}::{}' already has a child '{}::{}'".format(
type(self).__name__, self.name, child_type.__name__, name
)
f"Node '{self.node_type_name}' already has a "
f"child '{child.node_type_name}'"
)
def _add_child(self, child_type: Type[ASTNodeSubtype], name: str,
@ -180,15 +181,14 @@ class ASTNode:
Returns:
ASTNodeSubtype: Created ASTNode
"""
self.__check_child_before_add(child_type, name)
return child_type(name, parent=self, **kwargs)
def _find_child(self, child_type: Type[ASTNodeSubtype],
def _find_child(self, child_type: ASTNodeType,
name: str) -> Optional[ASTNodeSubtype]:
"""Looks for child node with the given type and name.
Args:
child_type (Type[ASTNodeSubtype]): Type of the child node.
child_type (ASTNodeType): Type of the child node.
name (str): Name of the child node.
Returns:

View File

@ -839,6 +839,21 @@ class CallableTypeNode(AggregatedTypeNode):
yield from super().required_usage_imports
class ClassTypeNode(ContainerTypeNode):
"""Type node representing types themselves (refer to typing.Type)
"""
def __init__(self, value: TypeNode) -> None:
super().__init__(value.ctype_name, (value,))
@property
def type_format(self) -> str:
return "typing.Type[{}]"
@property
def types_separator(self) -> str:
return ", "
def _resolve_symbol(root: Optional[ASTNode], full_symbol_name: str) -> Optional[ASTNode]:
"""Searches for a symbol with the given full export name in the AST
starting from the `root`.