Merge pull request #24066 from VadimLevin:dev/vlevin/python-typing-register-dnn-layer

Python typing refinement for dnn_registerLayer/dnn_unregisterLayer functions #24066

This patch introduces typings generation for `dnn_registerLayer`/`dnn_unregisterLayer` manually defined in [`cv2/modules/dnn/misc/python/pyopencv_dnn.hpp`](https://github.com/opencv/opencv/blob/4.x/modules/dnn/misc/python/pyopencv_dnn.hpp)

Updates:

- Add `LayerProtocol` to `cv2/dnn/__init__.pyi`:

    ```python
    class LayerProtocol(Protocol):
        def __init__(
            self, params: dict[str, DictValue],
            blobs: typing.Sequence[cv2.typing.MatLike]
        ) -> None: ...

        def getMemoryShapes(
            self, inputs: typing.Sequence[typing.Sequence[int]]
        ) -> typing.Sequence[typing.Sequence[int]]: ...

        def forward(
            self, inputs: typing.Sequence[cv2.typing.MatLike]
        ) -> typing.Sequence[cv2.typing.MatLike]: ...
    ```

- Add `dnn_registerLayer` function to `cv2/__init__.pyi`:

    ```python
    def dnn_registerLayer(layerTypeName: str,
                          layerClass: typing.Type[LayerProtocol]) -> None: ...
    ```

- Add `dnn_unregisterLayer` function to `cv2/__init__.pyi`:

    ```python
    def dnn_unregisterLayer(layerTypeName: str) -> None: ...
    ```
### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [ ] There is a reference to the original bug report and related work
- [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
Vadim Levin 2023-07-27 11:28:00 +03:00 committed by GitHub
parent 5c090b9eec
commit 0c5d74ec1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 205 additions and 77 deletions

View File

@ -7,15 +7,18 @@ from typing import cast, Sequence, Callable, Iterable
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode, from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode,
ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode, ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode,
AggregatedTypeNode, CallableTypeNode, AnyTypeNode, AggregatedTypeNode, CallableTypeNode, AnyTypeNode,
TupleTypeNode, UnionTypeNode) TupleTypeNode, UnionTypeNode, ProtocolClassNode,
DictTypeNode, ClassTypeNode)
from .ast_utils import (find_function_node, SymbolName, from .ast_utils import (find_function_node, SymbolName,
for_each_function_overload) for_each_function_overload)
from .types_conversion import create_type_node
def apply_manual_api_refinement(root: NamespaceNode) -> None: def apply_manual_api_refinement(root: NamespaceNode) -> None:
refine_highgui_module(root) refine_highgui_module(root)
refine_cuda_module(root) refine_cuda_module(root)
export_matrix_type_constants(root) export_matrix_type_constants(root)
refine_dnn_module(root)
# Export OpenCV exception class # Export OpenCV exception class
builtin_exception = root.add_class("Exception") builtin_exception = root.add_class("Exception")
builtin_exception.is_exported = False builtin_exception.is_exported = False
@ -215,6 +218,86 @@ def refine_highgui_module(root: NamespaceNode) -> None:
) )
def refine_dnn_module(root: NamespaceNode) -> None:
if "dnn" not in root.namespaces:
return
dnn_module = root.namespaces["dnn"]
"""
class LayerProtocol(Protocol):
def __init__(
self, params: dict[str, DictValue],
blobs: typing.Sequence[cv2.typing.MatLike]
) -> None: ...
def getMemoryShapes(
self, inputs: typing.Sequence[typing.Sequence[int]]
) -> typing.Sequence[typing.Sequence[int]]: ...
def forward(
self, inputs: typing.Sequence[cv2.typing.MatLike]
) -> typing.Sequence[cv2.typing.MatLike]: ...
"""
layer_proto = ProtocolClassNode("LayerProtocol", dnn_module)
layer_proto.add_function(
"__init__",
arguments=[
FunctionNode.Arg(
"params",
DictTypeNode(
"LayerParams", PrimitiveTypeNode.str_(),
create_type_node("cv::dnn::DictValue")
)
),
FunctionNode.Arg("blobs", create_type_node("vector<cv::Mat>"))
]
)
layer_proto.add_function(
"getMemoryShapes",
arguments=[
FunctionNode.Arg("inputs",
create_type_node("vector<vector<int>>"))
],
return_type=FunctionNode.RetType(
create_type_node("vector<vector<int>>")
)
)
layer_proto.add_function(
"forward",
arguments=[
FunctionNode.Arg("inputs", create_type_node("vector<cv::Mat>"))
],
return_type=FunctionNode.RetType(create_type_node("vector<cv::Mat>"))
)
"""
def dnn_registerLayer(layerTypeName: str,
layerClass: typing.Type[LayerProtocol]) -> None: ...
"""
root.add_function(
"dnn_registerLayer",
arguments=[
FunctionNode.Arg("layerTypeName", PrimitiveTypeNode.str_()),
FunctionNode.Arg(
"layerClass",
ClassTypeNode(ASTNodeTypeNode(
layer_proto.export_name, f"dnn.{layer_proto.export_name}"
))
)
]
)
"""
def dnn_unregisterLayer(layerTypeName: str) -> None: ...
"""
root.add_function(
"dnn_unregisterLayer",
arguments=[
FunctionNode.Arg("layerTypeName", PrimitiveTypeNode.str_())
]
)
def _trim_class_name_from_argument_types( def _trim_class_name_from_argument_types(
overloads: Iterable[FunctionNode.Overload], overloads: Iterable[FunctionNode.Overload],
class_name: str class_name: str

View File

@ -1,5 +1,5 @@
from typing import (NamedTuple, Sequence, Tuple, Union, List, from typing import (NamedTuple, Sequence, Tuple, Union, List,
Dict, Callable, Optional, Generator) Dict, Callable, Optional, Generator, cast)
import keyword import keyword
from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode, from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
@ -204,9 +204,7 @@ def create_function_node_in_scope(scope: Union[NamespaceNode, ClassNode],
outlist = variant.py_outlist outlist = variant.py_outlist
for _, argno in outlist: for _, argno in outlist:
assert argno >= 0, \ assert argno >= 0, \
"Logic Error! Outlist contains function return type: {}".format( f"Logic Error! Outlist contains function return type: {outlist}"
outlist
)
ret_types.append(create_type_node(variant.args[argno].tp)) ret_types.append(create_type_node(variant.args[argno].tp))
@ -379,7 +377,7 @@ def get_enclosing_namespace(
node.full_export_name, node.native_name node.full_export_name, node.native_name
) )
if class_node_callback: if class_node_callback:
class_node_callback(parent_node) class_node_callback(cast(ClassNode, parent_node))
parent_node = parent_node.parent parent_node = parent_node.parent
return parent_node return parent_node
@ -395,12 +393,14 @@ def get_enum_module_and_export_name(enum_node: EnumerationNode) -> Tuple[str, st
Returns: Returns:
Tuple[str, str]: a pair of enum export name and its full module name. Tuple[str, str]: a pair of enum export name and its full module name.
""" """
enum_export_name = enum_node.export_name
def update_full_export_name(class_node: ClassNode) -> None: def update_full_export_name(class_node: ClassNode) -> None:
nonlocal enum_export_name nonlocal enum_export_name
enum_export_name = class_node.export_name + "_" + enum_export_name enum_export_name = class_node.export_name + "_" + enum_export_name
enum_export_name = enum_node.export_name namespace_node = get_enclosing_namespace(enum_node,
namespace_node = get_enclosing_namespace(enum_node, update_full_export_name) update_full_export_name)
return enum_export_name, namespace_node.full_export_name return enum_export_name, namespace_node.full_export_name

View File

@ -3,7 +3,7 @@ __all__ = ("generate_typing_stubs", )
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
import re import re
from typing import (Type, Callable, NamedTuple, Union, Set, Dict, from typing import (Callable, NamedTuple, Union, Set, Dict,
Collection, Tuple, List) Collection, Tuple, List)
import warnings import warnings
@ -16,7 +16,8 @@ from .predefined_types import PREDEFINED_TYPES
from .api_refinement import apply_manual_api_refinement from .api_refinement import apply_manual_api_refinement
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode,
FunctionNode, EnumerationNode, ConstantNode) FunctionNode, EnumerationNode, ConstantNode,
ProtocolClassNode)
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode, from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
AggregatedTypeNode, ASTNodeTypeNode, AggregatedTypeNode, ASTNodeTypeNode,
@ -112,11 +113,13 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:
# NOTE: Enumerations require special handling, because all enumeration # NOTE: Enumerations require special handling, because all enumeration
# constants are exposed as module attributes # constants are exposed as module attributes
has_enums = _generate_section_stub( has_enums = _generate_section_stub(
StubSection("# Enumerations", EnumerationNode), root, output_stream, 0 StubSection("# Enumerations", ASTNodeType.Enumeration), root,
output_stream, 0
) )
# Collect all enums from class level and export them to module level # Collect all enums from class level and export them to module level
for class_node in root.classes.values(): for class_node in root.classes.values():
if _generate_enums_from_classes_tree(class_node, output_stream, indent=0): if _generate_enums_from_classes_tree(class_node, output_stream,
indent=0):
has_enums = True has_enums = True
# 2 empty lines between enum and classes definitions # 2 empty lines between enum and classes definitions
if has_enums: if has_enums:
@ -134,14 +137,15 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:
class StubSection(NamedTuple): class StubSection(NamedTuple):
name: str name: str
node_type: Type[ASTNode] node_type: ASTNodeType
STUB_SECTIONS = ( STUB_SECTIONS = (
StubSection("# Constants", ConstantNode), StubSection("# Constants", ASTNodeType.Constant),
# StubSection("# Enumerations", EnumerationNode), # Skipped for now (special rules) # Enumerations are skipped due to special handling rules
StubSection("# Classes", ClassNode), # StubSection("# Enumerations", ASTNodeType.Enumeration),
StubSection("# Functions", FunctionNode) StubSection("# Classes", ASTNodeType.Class),
StubSection("# Functions", ASTNodeType.Function)
) )
@ -250,9 +254,9 @@ def _generate_class_stub(class_node: ClassNode, output_stream: StringIO,
else: else:
bases.append(base.export_name) bases.append(base.export_name)
inheritance_str = "({})".format( inheritance_str = f"({', '.join(bases)})"
', '.join(bases) elif isinstance(class_node, ProtocolClassNode):
) inheritance_str = "(Protocol)"
else: else:
inheritance_str = "" inheritance_str = ""
@ -547,7 +551,8 @@ def check_overload_presence(node: Union[NamespaceNode, ClassNode]) -> bool:
return True return True
return False return False
def _collect_required_imports(root: NamespaceNode) -> Set[str]:
def _collect_required_imports(root: NamespaceNode) -> Collection[str]:
"""Collects all imports required for classes and functions typing stubs """Collects all imports required for classes and functions typing stubs
declarations. declarations.
@ -555,8 +560,8 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
root (NamespaceNode): Namespace node to collect imports for root (NamespaceNode): Namespace node to collect imports for
Returns: Returns:
Set[str]: Collection of unique `import smth` statements required for Collection[str]: Collection of unique `import smth` statements required
classes and function declarations of `root` node. for classes and function declarations of `root` node.
""" """
def _add_required_usage_imports(type_node: TypeNode, imports: Set[str]): def _add_required_usage_imports(type_node: TypeNode, imports: Set[str]):
@ -569,6 +574,7 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
has_overload = check_overload_presence(root) has_overload = check_overload_presence(root)
# if there is no module-level functions with overload, check its presence # if there is no module-level functions with overload, check its presence
# during class traversing, including their inner-classes # during class traversing, including their inner-classes
has_protocol = False
for cls in for_each_class(root): for cls in for_each_class(root):
if not has_overload and check_overload_presence(cls): if not has_overload and check_overload_presence(cls):
has_overload = True has_overload = True
@ -583,6 +589,8 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
required_imports.add( required_imports.add(
"import " + base_namespace.full_export_name "import " + base_namespace.full_export_name
) )
if isinstance(cls, ProtocolClassNode):
has_protocol = True
if has_overload: if has_overload:
required_imports.add("import typing") required_imports.add("import typing")
@ -599,7 +607,20 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
if root_import in required_imports: if root_import in required_imports:
required_imports.remove(root_import) required_imports.remove(root_import)
return required_imports if has_protocol:
required_imports.add("import sys")
ordered_required_imports = sorted(required_imports)
# Protocol import always goes as last import statement
if has_protocol:
ordered_required_imports.append(
"""if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol"""
)
return ordered_required_imports
def _populate_reexported_symbols(root: NamespaceNode) -> None: def _populate_reexported_symbols(root: NamespaceNode) -> None:
@ -666,7 +687,7 @@ def _write_required_imports(required_imports: Collection[str],
output_stream (StringIO): Output stream for import statements. output_stream (StringIO): Output stream for import statements.
""" """
for required_import in sorted(required_imports): for required_import in required_imports:
output_stream.write(required_import) output_stream.write(required_import)
output_stream.write("\n") output_stream.write("\n")
if len(required_imports): if len(required_imports):
@ -803,8 +824,8 @@ StubGenerator = Callable[[ASTNode, StringIO, int], None]
NODE_TYPE_TO_STUB_GENERATOR = { NODE_TYPE_TO_STUB_GENERATOR = {
ClassNode: _generate_class_stub, ASTNodeType.Class: _generate_class_stub,
ConstantNode: _generate_constant_stub, ASTNodeType.Constant: _generate_constant_stub,
EnumerationNode: _generate_enumeration_stub, ASTNodeType.Enumeration: _generate_enumeration_stub,
FunctionNode: _generate_function_stub ASTNodeType.Function: _generate_function_stub
} }

View File

@ -1,6 +1,6 @@
from .node import ASTNode, ASTNodeType from .node import ASTNode, ASTNodeType
from .namespace_node import NamespaceNode from .namespace_node import NamespaceNode
from .class_node import ClassNode, ClassProperty from .class_node import ClassNode, ClassProperty, ProtocolClassNode
from .function_node import FunctionNode from .function_node import FunctionNode
from .enumeration_node import EnumerationNode from .enumeration_node import EnumerationNode
from .constant_node import ConstantNode from .constant_node import ConstantNode
@ -8,5 +8,5 @@ from .type_node import (
TypeNode, OptionalTypeNode, UnionTypeNode, NoneTypeNode, TupleTypeNode, TypeNode, OptionalTypeNode, UnionTypeNode, NoneTypeNode, TupleTypeNode,
ASTNodeTypeNode, AliasTypeNode, SequenceTypeNode, AnyTypeNode, ASTNodeTypeNode, AliasTypeNode, SequenceTypeNode, AnyTypeNode,
AggregatedTypeNode, NDArrayTypeNode, AliasRefTypeNode, PrimitiveTypeNode, AggregatedTypeNode, NDArrayTypeNode, AliasRefTypeNode, PrimitiveTypeNode,
CallableTypeNode, CallableTypeNode, DictTypeNode, ClassTypeNode
) )

View File

@ -63,8 +63,9 @@ class ClassNode(ASTNode):
return 1 + sum(base.weight for base in self.bases) return 1 + sum(base.weight for base in self.bases)
@property @property
def children_types(self) -> Tuple[Type[ASTNode], ...]: def children_types(self) -> Tuple[ASTNodeType, ...]:
return (ClassNode, FunctionNode, EnumerationNode, ConstantNode) return (ASTNodeType.Class, ASTNodeType.Function,
ASTNodeType.Enumeration, ASTNodeType.Constant)
@property @property
def node_type(self) -> ASTNodeType: def node_type(self) -> ASTNodeType:
@ -72,19 +73,19 @@ class ClassNode(ASTNode):
@property @property
def classes(self) -> Dict[str, "ClassNode"]: def classes(self) -> Dict[str, "ClassNode"]:
return self._children[ClassNode] return self._children[ASTNodeType.Class]
@property @property
def functions(self) -> Dict[str, FunctionNode]: def functions(self) -> Dict[str, FunctionNode]:
return self._children[FunctionNode] return self._children[ASTNodeType.Function]
@property @property
def enumerations(self) -> Dict[str, EnumerationNode]: def enumerations(self) -> Dict[str, EnumerationNode]:
return self._children[EnumerationNode] return self._children[ASTNodeType.Enumeration]
@property @property
def constants(self) -> Dict[str, ConstantNode]: def constants(self) -> Dict[str, ConstantNode]:
return self._children[ConstantNode] return self._children[ASTNodeType.Constant]
def add_class(self, name: str, def add_class(self, name: str,
bases: Sequence["weakref.ProxyType[ClassNode]"] = (), bases: Sequence["weakref.ProxyType[ClassNode]"] = (),
@ -179,3 +180,11 @@ class ClassNode(ASTNode):
self.full_export_name, root.full_export_name, errors self.full_export_name, root.full_export_name, errors
) )
) )
class ProtocolClassNode(ClassNode):
def __init__(self, name: str, parent: Optional[ASTNode] = None,
export_name: Optional[str] = None,
properties: Sequence[ClassProperty] = ()) -> None:
super().__init__(name, parent, export_name, bases=(),
properties=properties)

View File

@ -1,4 +1,4 @@
from typing import Type, Optional, Tuple from typing import Optional, Tuple
from .node import ASTNode, ASTNodeType from .node import ASTNode, ASTNodeType
@ -14,7 +14,7 @@ class ConstantNode(ASTNode):
self._value_type = "int" self._value_type = "int"
@property @property
def children_types(self) -> Tuple[Type[ASTNode], ...]: def children_types(self) -> Tuple[ASTNodeType, ...]:
return () return ()
@property @property

View File

@ -18,8 +18,8 @@ class EnumerationNode(ASTNode):
self.is_scoped = is_scoped self.is_scoped = is_scoped
@property @property
def children_types(self) -> Tuple[Type[ASTNode], ...]: def children_types(self) -> Tuple[ASTNodeType, ...]:
return (ConstantNode, ) return (ASTNodeType.Constant, )
@property @property
def node_type(self) -> ASTNodeType: def node_type(self) -> ASTNodeType:
@ -27,7 +27,7 @@ class EnumerationNode(ASTNode):
@property @property
def constants(self) -> Dict[str, ConstantNode]: def constants(self) -> Dict[str, ConstantNode]:
return self._children[ConstantNode] return self._children[ASTNodeType.Constant]
def add_constant(self, name: str, value: str) -> ConstantNode: def add_constant(self, name: str, value: str) -> ConstantNode:
return self._add_child(ConstantNode, name, value=value) return self._add_child(ConstantNode, name, value=value)

View File

@ -1,4 +1,4 @@
from typing import NamedTuple, Sequence, Type, Optional, Tuple, List from typing import NamedTuple, Sequence, Optional, Tuple, List
from .node import ASTNode, ASTNodeType from .node import ASTNode, ASTNodeType
from .type_node import TypeNode, NoneTypeNode, TypeResolutionError from .type_node import TypeNode, NoneTypeNode, TypeResolutionError
@ -98,7 +98,7 @@ class FunctionNode(ASTNode):
return ASTNodeType.Function return ASTNodeType.Function
@property @property
def children_types(self) -> Tuple[Type[ASTNode], ...]: def children_types(self) -> Tuple[ASTNodeType, ...]:
return () return ()
def add_overload(self, arguments: Sequence["FunctionNode.Arg"] = (), def add_overload(self, arguments: Sequence["FunctionNode.Arg"] = (),

View File

@ -1,7 +1,7 @@
import itertools import itertools
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple, Type from typing import Dict, List, Optional, Sequence, Tuple
from .class_node import ClassNode, ClassProperty from .class_node import ClassNode, ClassProperty
from .constant_node import ConstantNode from .constant_node import ConstantNode
@ -33,29 +33,29 @@ class NamespaceNode(ASTNode):
return ASTNodeType.Namespace return ASTNodeType.Namespace
@property @property
def children_types(self) -> Tuple[Type[ASTNode], ...]: def children_types(self) -> Tuple[ASTNodeType, ...]:
return (NamespaceNode, ClassNode, FunctionNode, return (ASTNodeType.Namespace, ASTNodeType.Class, ASTNodeType.Function,
EnumerationNode, ConstantNode) ASTNodeType.Enumeration, ASTNodeType.Constant)
@property @property
def namespaces(self) -> Dict[str, "NamespaceNode"]: def namespaces(self) -> Dict[str, "NamespaceNode"]:
return self._children[NamespaceNode] return self._children[ASTNodeType.Namespace]
@property @property
def classes(self) -> Dict[str, ClassNode]: def classes(self) -> Dict[str, ClassNode]:
return self._children[ClassNode] return self._children[ASTNodeType.Class]
@property @property
def functions(self) -> Dict[str, FunctionNode]: def functions(self) -> Dict[str, FunctionNode]:
return self._children[FunctionNode] return self._children[ASTNodeType.Function]
@property @property
def enumerations(self) -> Dict[str, EnumerationNode]: def enumerations(self) -> Dict[str, EnumerationNode]:
return self._children[EnumerationNode] return self._children[ASTNodeType.Enumeration]
@property @property
def constants(self) -> Dict[str, ConstantNode]: def constants(self) -> Dict[str, ConstantNode]:
return self._children[ConstantNode] return self._children[ASTNodeType.Constant]
def add_namespace(self, name: str) -> "NamespaceNode": def add_namespace(self, name: str) -> "NamespaceNode":
return self._add_child(NamespaceNode, name) return self._add_child(NamespaceNode, name)

View File

@ -70,22 +70,22 @@ class ASTNode:
self._parent: Optional["ASTNode"] = None self._parent: Optional["ASTNode"] = None
self.parent = parent self.parent = parent
self.is_exported = True self.is_exported = True
self._children: DefaultDict[NodeType, NameToNode] = defaultdict(dict) self._children: DefaultDict[ASTNodeType, NameToNode] = defaultdict(dict)
def __str__(self) -> str: def __str__(self) -> str:
return "{}('{}' exported as '{}')".format( return "{}('{}' exported as '{}')".format(
type(self).__name__.replace("Node", ""), self.name, self.export_name self.node_type.name, self.name, self.export_name
) )
def __repr__(self) -> str: def __repr__(self) -> str:
return str(self) return str(self)
@abc.abstractproperty @abc.abstractproperty
def children_types(self) -> Tuple[Type["ASTNode"], ...]: def children_types(self) -> Tuple[ASTNodeType, ...]:
"""Set of ASTNode types that are allowed to be children of this node """Set of ASTNode types that are allowed to be children of this node
Returns: Returns:
Tuple[Type[ASTNode], ...]: Types of children nodes Tuple[ASTNodeType, ...]: Types of children nodes
""" """
pass pass
@ -99,6 +99,9 @@ class ASTNode:
""" """
pass pass
def node_type_name(self) -> str:
return f"{self.node_type.name}::{self.name}"
@property @property
def name(self) -> str: def name(self) -> str:
return self.__name return self.__name
@ -126,11 +129,11 @@ class ASTNode:
"but got: {}".format(type(value)) "but got: {}".format(type(value))
if value is not None: if value is not None:
value.__check_child_before_add(type(self), self.name) value.__check_child_before_add(self, self.name)
# Detach from previous parent # Detach from previous parent
if self._parent is not None: if self._parent is not None:
self._parent._children[type(self)].pop(self.name) self._parent._children[self.node_type].pop(self.name)
if value is None: if value is None:
self._parent = None self._parent = None
@ -138,28 +141,26 @@ class ASTNode:
# Set a weak reference to a new parent and add self to its children # Set a weak reference to a new parent and add self to its children
self._parent = weakref.proxy(value) self._parent = weakref.proxy(value)
value._children[type(self)][self.name] = self value._children[self.node_type][self.name] = self
def __check_child_before_add(self, child_type: Type[ASTNodeSubtype], def __check_child_before_add(self, child: ASTNodeSubtype,
name: str) -> None: name: str) -> None:
assert len(self.children_types) > 0, \ assert len(self.children_types) > 0, (
"Trying to add child node '{}::{}' to node '{}::{}' " \ f"Trying to add child node '{child.node_type_name}' to node "
"that can't have children nodes".format(child_type.__name__, name, f"'{self.node_type_name}' that can't have children nodes"
type(self).__name__, )
self.name)
assert child_type in self.children_types, \ assert child.node_type in self.children_types, \
"Trying to add child node '{}::{}' to node '{}::{}' " \ "Trying to add child node '{}' to node '{}' " \
"that supports only ({}) as its children types".format( "that supports only ({}) as its children types".format(
child_type.__name__, name, type(self).__name__, self.name, child.node_type_name, self.node_type_name,
",".join(t.__name__ for t in self.children_types) ",".join(t.name for t in self.children_types)
) )
if self._find_child(child_type, name) is not None: if self._find_child(child.node_type, name) is not None:
raise ValueError( raise ValueError(
"Node '{}::{}' already has a child '{}::{}'".format( f"Node '{self.node_type_name}' already has a "
type(self).__name__, self.name, child_type.__name__, name f"child '{child.node_type_name}'"
)
) )
def _add_child(self, child_type: Type[ASTNodeSubtype], name: str, def _add_child(self, child_type: Type[ASTNodeSubtype], name: str,
@ -180,15 +181,14 @@ class ASTNode:
Returns: Returns:
ASTNodeSubtype: Created ASTNode ASTNodeSubtype: Created ASTNode
""" """
self.__check_child_before_add(child_type, name)
return child_type(name, parent=self, **kwargs) return child_type(name, parent=self, **kwargs)
def _find_child(self, child_type: Type[ASTNodeSubtype], def _find_child(self, child_type: ASTNodeType,
name: str) -> Optional[ASTNodeSubtype]: name: str) -> Optional[ASTNodeSubtype]:
"""Looks for child node with the given type and name. """Looks for child node with the given type and name.
Args: Args:
child_type (Type[ASTNodeSubtype]): Type of the child node. child_type (ASTNodeType): Type of the child node.
name (str): Name of the child node. name (str): Name of the child node.
Returns: Returns:

View File

@ -839,6 +839,21 @@ class CallableTypeNode(AggregatedTypeNode):
yield from super().required_usage_imports yield from super().required_usage_imports
class ClassTypeNode(ContainerTypeNode):
"""Type node representing types themselves (refer to typing.Type)
"""
def __init__(self, value: TypeNode) -> None:
super().__init__(value.ctype_name, (value,))
@property
def type_format(self) -> str:
return "typing.Type[{}]"
@property
def types_separator(self) -> str:
return ", "
def _resolve_symbol(root: Optional[ASTNode], full_symbol_name: str) -> Optional[ASTNode]: def _resolve_symbol(root: Optional[ASTNode], full_symbol_name: str) -> Optional[ASTNode]:
"""Searches for a symbol with the given full export name in the AST """Searches for a symbol with the given full export name in the AST
starting from the `root`. starting from the `root`.