mirror of
https://github.com/opencv/opencv.git
synced 2024-11-27 20:50:25 +08:00
Merge pull request #24066 from VadimLevin:dev/vlevin/python-typing-register-dnn-layer
Python typing refinement for dnn_registerLayer/dnn_unregisterLayer functions #24066 This patch introduces typings generation for `dnn_registerLayer`/`dnn_unregisterLayer` manually defined in [`cv2/modules/dnn/misc/python/pyopencv_dnn.hpp`](https://github.com/opencv/opencv/blob/4.x/modules/dnn/misc/python/pyopencv_dnn.hpp) Updates: - Add `LayerProtocol` to `cv2/dnn/__init__.pyi`: ```python class LayerProtocol(Protocol): def __init__( self, params: dict[str, DictValue], blobs: typing.Sequence[cv2.typing.MatLike] ) -> None: ... def getMemoryShapes( self, inputs: typing.Sequence[typing.Sequence[int]] ) -> typing.Sequence[typing.Sequence[int]]: ... def forward( self, inputs: typing.Sequence[cv2.typing.MatLike] ) -> typing.Sequence[cv2.typing.MatLike]: ... ``` - Add `dnn_registerLayer` function to `cv2/__init__.pyi`: ```python def dnn_registerLayer(layerTypeName: str, layerClass: typing.Type[LayerProtocol]) -> None: ... ``` - Add `dnn_unregisterLayer` function to `cv2/__init__.pyi`: ```python def dnn_unregisterLayer(layerTypeName: str) -> None: ... ``` ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [ ] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
5c090b9eec
commit
0c5d74ec1a
@ -7,15 +7,18 @@ from typing import cast, Sequence, Callable, Iterable
|
||||
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode,
|
||||
ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode,
|
||||
AggregatedTypeNode, CallableTypeNode, AnyTypeNode,
|
||||
TupleTypeNode, UnionTypeNode)
|
||||
TupleTypeNode, UnionTypeNode, ProtocolClassNode,
|
||||
DictTypeNode, ClassTypeNode)
|
||||
from .ast_utils import (find_function_node, SymbolName,
|
||||
for_each_function_overload)
|
||||
from .types_conversion import create_type_node
|
||||
|
||||
|
||||
def apply_manual_api_refinement(root: NamespaceNode) -> None:
|
||||
refine_highgui_module(root)
|
||||
refine_cuda_module(root)
|
||||
export_matrix_type_constants(root)
|
||||
refine_dnn_module(root)
|
||||
# Export OpenCV exception class
|
||||
builtin_exception = root.add_class("Exception")
|
||||
builtin_exception.is_exported = False
|
||||
@ -215,6 +218,86 @@ def refine_highgui_module(root: NamespaceNode) -> None:
|
||||
)
|
||||
|
||||
|
||||
def refine_dnn_module(root: NamespaceNode) -> None:
|
||||
if "dnn" not in root.namespaces:
|
||||
return
|
||||
dnn_module = root.namespaces["dnn"]
|
||||
|
||||
"""
|
||||
class LayerProtocol(Protocol):
|
||||
def __init__(
|
||||
self, params: dict[str, DictValue],
|
||||
blobs: typing.Sequence[cv2.typing.MatLike]
|
||||
) -> None: ...
|
||||
|
||||
def getMemoryShapes(
|
||||
self, inputs: typing.Sequence[typing.Sequence[int]]
|
||||
) -> typing.Sequence[typing.Sequence[int]]: ...
|
||||
|
||||
def forward(
|
||||
self, inputs: typing.Sequence[cv2.typing.MatLike]
|
||||
) -> typing.Sequence[cv2.typing.MatLike]: ...
|
||||
"""
|
||||
layer_proto = ProtocolClassNode("LayerProtocol", dnn_module)
|
||||
layer_proto.add_function(
|
||||
"__init__",
|
||||
arguments=[
|
||||
FunctionNode.Arg(
|
||||
"params",
|
||||
DictTypeNode(
|
||||
"LayerParams", PrimitiveTypeNode.str_(),
|
||||
create_type_node("cv::dnn::DictValue")
|
||||
)
|
||||
),
|
||||
FunctionNode.Arg("blobs", create_type_node("vector<cv::Mat>"))
|
||||
]
|
||||
)
|
||||
layer_proto.add_function(
|
||||
"getMemoryShapes",
|
||||
arguments=[
|
||||
FunctionNode.Arg("inputs",
|
||||
create_type_node("vector<vector<int>>"))
|
||||
],
|
||||
return_type=FunctionNode.RetType(
|
||||
create_type_node("vector<vector<int>>")
|
||||
)
|
||||
)
|
||||
layer_proto.add_function(
|
||||
"forward",
|
||||
arguments=[
|
||||
FunctionNode.Arg("inputs", create_type_node("vector<cv::Mat>"))
|
||||
],
|
||||
return_type=FunctionNode.RetType(create_type_node("vector<cv::Mat>"))
|
||||
)
|
||||
|
||||
"""
|
||||
def dnn_registerLayer(layerTypeName: str,
|
||||
layerClass: typing.Type[LayerProtocol]) -> None: ...
|
||||
"""
|
||||
root.add_function(
|
||||
"dnn_registerLayer",
|
||||
arguments=[
|
||||
FunctionNode.Arg("layerTypeName", PrimitiveTypeNode.str_()),
|
||||
FunctionNode.Arg(
|
||||
"layerClass",
|
||||
ClassTypeNode(ASTNodeTypeNode(
|
||||
layer_proto.export_name, f"dnn.{layer_proto.export_name}"
|
||||
))
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
"""
|
||||
def dnn_unregisterLayer(layerTypeName: str) -> None: ...
|
||||
"""
|
||||
root.add_function(
|
||||
"dnn_unregisterLayer",
|
||||
arguments=[
|
||||
FunctionNode.Arg("layerTypeName", PrimitiveTypeNode.str_())
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _trim_class_name_from_argument_types(
|
||||
overloads: Iterable[FunctionNode.Overload],
|
||||
class_name: str
|
||||
|
@ -1,5 +1,5 @@
|
||||
from typing import (NamedTuple, Sequence, Tuple, Union, List,
|
||||
Dict, Callable, Optional, Generator)
|
||||
Dict, Callable, Optional, Generator, cast)
|
||||
import keyword
|
||||
|
||||
from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
|
||||
@ -204,9 +204,7 @@ def create_function_node_in_scope(scope: Union[NamespaceNode, ClassNode],
|
||||
outlist = variant.py_outlist
|
||||
for _, argno in outlist:
|
||||
assert argno >= 0, \
|
||||
"Logic Error! Outlist contains function return type: {}".format(
|
||||
outlist
|
||||
)
|
||||
f"Logic Error! Outlist contains function return type: {outlist}"
|
||||
|
||||
ret_types.append(create_type_node(variant.args[argno].tp))
|
||||
|
||||
@ -379,7 +377,7 @@ def get_enclosing_namespace(
|
||||
node.full_export_name, node.native_name
|
||||
)
|
||||
if class_node_callback:
|
||||
class_node_callback(parent_node)
|
||||
class_node_callback(cast(ClassNode, parent_node))
|
||||
parent_node = parent_node.parent
|
||||
return parent_node
|
||||
|
||||
@ -395,12 +393,14 @@ def get_enum_module_and_export_name(enum_node: EnumerationNode) -> Tuple[str, st
|
||||
Returns:
|
||||
Tuple[str, str]: a pair of enum export name and its full module name.
|
||||
"""
|
||||
enum_export_name = enum_node.export_name
|
||||
|
||||
def update_full_export_name(class_node: ClassNode) -> None:
|
||||
nonlocal enum_export_name
|
||||
enum_export_name = class_node.export_name + "_" + enum_export_name
|
||||
|
||||
enum_export_name = enum_node.export_name
|
||||
namespace_node = get_enclosing_namespace(enum_node, update_full_export_name)
|
||||
namespace_node = get_enclosing_namespace(enum_node,
|
||||
update_full_export_name)
|
||||
return enum_export_name, namespace_node.full_export_name
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@ __all__ = ("generate_typing_stubs", )
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import (Type, Callable, NamedTuple, Union, Set, Dict,
|
||||
from typing import (Callable, NamedTuple, Union, Set, Dict,
|
||||
Collection, Tuple, List)
|
||||
import warnings
|
||||
|
||||
@ -16,7 +16,8 @@ from .predefined_types import PREDEFINED_TYPES
|
||||
from .api_refinement import apply_manual_api_refinement
|
||||
|
||||
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode,
|
||||
FunctionNode, EnumerationNode, ConstantNode)
|
||||
FunctionNode, EnumerationNode, ConstantNode,
|
||||
ProtocolClassNode)
|
||||
|
||||
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
|
||||
AggregatedTypeNode, ASTNodeTypeNode,
|
||||
@ -112,11 +113,13 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:
|
||||
# NOTE: Enumerations require special handling, because all enumeration
|
||||
# constants are exposed as module attributes
|
||||
has_enums = _generate_section_stub(
|
||||
StubSection("# Enumerations", EnumerationNode), root, output_stream, 0
|
||||
StubSection("# Enumerations", ASTNodeType.Enumeration), root,
|
||||
output_stream, 0
|
||||
)
|
||||
# Collect all enums from class level and export them to module level
|
||||
for class_node in root.classes.values():
|
||||
if _generate_enums_from_classes_tree(class_node, output_stream, indent=0):
|
||||
if _generate_enums_from_classes_tree(class_node, output_stream,
|
||||
indent=0):
|
||||
has_enums = True
|
||||
# 2 empty lines between enum and classes definitions
|
||||
if has_enums:
|
||||
@ -134,14 +137,15 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:
|
||||
|
||||
class StubSection(NamedTuple):
|
||||
name: str
|
||||
node_type: Type[ASTNode]
|
||||
node_type: ASTNodeType
|
||||
|
||||
|
||||
STUB_SECTIONS = (
|
||||
StubSection("# Constants", ConstantNode),
|
||||
# StubSection("# Enumerations", EnumerationNode), # Skipped for now (special rules)
|
||||
StubSection("# Classes", ClassNode),
|
||||
StubSection("# Functions", FunctionNode)
|
||||
StubSection("# Constants", ASTNodeType.Constant),
|
||||
# Enumerations are skipped due to special handling rules
|
||||
# StubSection("# Enumerations", ASTNodeType.Enumeration),
|
||||
StubSection("# Classes", ASTNodeType.Class),
|
||||
StubSection("# Functions", ASTNodeType.Function)
|
||||
)
|
||||
|
||||
|
||||
@ -250,9 +254,9 @@ def _generate_class_stub(class_node: ClassNode, output_stream: StringIO,
|
||||
else:
|
||||
bases.append(base.export_name)
|
||||
|
||||
inheritance_str = "({})".format(
|
||||
', '.join(bases)
|
||||
)
|
||||
inheritance_str = f"({', '.join(bases)})"
|
||||
elif isinstance(class_node, ProtocolClassNode):
|
||||
inheritance_str = "(Protocol)"
|
||||
else:
|
||||
inheritance_str = ""
|
||||
|
||||
@ -547,7 +551,8 @@ def check_overload_presence(node: Union[NamespaceNode, ClassNode]) -> bool:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
||||
|
||||
def _collect_required_imports(root: NamespaceNode) -> Collection[str]:
|
||||
"""Collects all imports required for classes and functions typing stubs
|
||||
declarations.
|
||||
|
||||
@ -555,8 +560,8 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
||||
root (NamespaceNode): Namespace node to collect imports for
|
||||
|
||||
Returns:
|
||||
Set[str]: Collection of unique `import smth` statements required for
|
||||
classes and function declarations of `root` node.
|
||||
Collection[str]: Collection of unique `import smth` statements required
|
||||
for classes and function declarations of `root` node.
|
||||
"""
|
||||
|
||||
def _add_required_usage_imports(type_node: TypeNode, imports: Set[str]):
|
||||
@ -569,6 +574,7 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
||||
has_overload = check_overload_presence(root)
|
||||
# if there is no module-level functions with overload, check its presence
|
||||
# during class traversing, including their inner-classes
|
||||
has_protocol = False
|
||||
for cls in for_each_class(root):
|
||||
if not has_overload and check_overload_presence(cls):
|
||||
has_overload = True
|
||||
@ -583,6 +589,8 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
||||
required_imports.add(
|
||||
"import " + base_namespace.full_export_name
|
||||
)
|
||||
if isinstance(cls, ProtocolClassNode):
|
||||
has_protocol = True
|
||||
|
||||
if has_overload:
|
||||
required_imports.add("import typing")
|
||||
@ -599,7 +607,20 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
||||
if root_import in required_imports:
|
||||
required_imports.remove(root_import)
|
||||
|
||||
return required_imports
|
||||
if has_protocol:
|
||||
required_imports.add("import sys")
|
||||
ordered_required_imports = sorted(required_imports)
|
||||
|
||||
# Protocol import always goes as last import statement
|
||||
if has_protocol:
|
||||
ordered_required_imports.append(
|
||||
"""if sys.version_info >= (3, 8):
|
||||
from typing import Protocol
|
||||
else:
|
||||
from typing_extensions import Protocol"""
|
||||
)
|
||||
|
||||
return ordered_required_imports
|
||||
|
||||
|
||||
def _populate_reexported_symbols(root: NamespaceNode) -> None:
|
||||
@ -666,7 +687,7 @@ def _write_required_imports(required_imports: Collection[str],
|
||||
output_stream (StringIO): Output stream for import statements.
|
||||
"""
|
||||
|
||||
for required_import in sorted(required_imports):
|
||||
for required_import in required_imports:
|
||||
output_stream.write(required_import)
|
||||
output_stream.write("\n")
|
||||
if len(required_imports):
|
||||
@ -803,8 +824,8 @@ StubGenerator = Callable[[ASTNode, StringIO, int], None]
|
||||
|
||||
|
||||
NODE_TYPE_TO_STUB_GENERATOR = {
|
||||
ClassNode: _generate_class_stub,
|
||||
ConstantNode: _generate_constant_stub,
|
||||
EnumerationNode: _generate_enumeration_stub,
|
||||
FunctionNode: _generate_function_stub
|
||||
ASTNodeType.Class: _generate_class_stub,
|
||||
ASTNodeType.Constant: _generate_constant_stub,
|
||||
ASTNodeType.Enumeration: _generate_enumeration_stub,
|
||||
ASTNodeType.Function: _generate_function_stub
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
from .node import ASTNode, ASTNodeType
|
||||
from .namespace_node import NamespaceNode
|
||||
from .class_node import ClassNode, ClassProperty
|
||||
from .class_node import ClassNode, ClassProperty, ProtocolClassNode
|
||||
from .function_node import FunctionNode
|
||||
from .enumeration_node import EnumerationNode
|
||||
from .constant_node import ConstantNode
|
||||
@ -8,5 +8,5 @@ from .type_node import (
|
||||
TypeNode, OptionalTypeNode, UnionTypeNode, NoneTypeNode, TupleTypeNode,
|
||||
ASTNodeTypeNode, AliasTypeNode, SequenceTypeNode, AnyTypeNode,
|
||||
AggregatedTypeNode, NDArrayTypeNode, AliasRefTypeNode, PrimitiveTypeNode,
|
||||
CallableTypeNode,
|
||||
CallableTypeNode, DictTypeNode, ClassTypeNode
|
||||
)
|
||||
|
@ -63,8 +63,9 @@ class ClassNode(ASTNode):
|
||||
return 1 + sum(base.weight for base in self.bases)
|
||||
|
||||
@property
|
||||
def children_types(self) -> Tuple[Type[ASTNode], ...]:
|
||||
return (ClassNode, FunctionNode, EnumerationNode, ConstantNode)
|
||||
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||
return (ASTNodeType.Class, ASTNodeType.Function,
|
||||
ASTNodeType.Enumeration, ASTNodeType.Constant)
|
||||
|
||||
@property
|
||||
def node_type(self) -> ASTNodeType:
|
||||
@ -72,19 +73,19 @@ class ClassNode(ASTNode):
|
||||
|
||||
@property
|
||||
def classes(self) -> Dict[str, "ClassNode"]:
|
||||
return self._children[ClassNode]
|
||||
return self._children[ASTNodeType.Class]
|
||||
|
||||
@property
|
||||
def functions(self) -> Dict[str, FunctionNode]:
|
||||
return self._children[FunctionNode]
|
||||
return self._children[ASTNodeType.Function]
|
||||
|
||||
@property
|
||||
def enumerations(self) -> Dict[str, EnumerationNode]:
|
||||
return self._children[EnumerationNode]
|
||||
return self._children[ASTNodeType.Enumeration]
|
||||
|
||||
@property
|
||||
def constants(self) -> Dict[str, ConstantNode]:
|
||||
return self._children[ConstantNode]
|
||||
return self._children[ASTNodeType.Constant]
|
||||
|
||||
def add_class(self, name: str,
|
||||
bases: Sequence["weakref.ProxyType[ClassNode]"] = (),
|
||||
@ -179,3 +180,11 @@ class ClassNode(ASTNode):
|
||||
self.full_export_name, root.full_export_name, errors
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ProtocolClassNode(ClassNode):
|
||||
def __init__(self, name: str, parent: Optional[ASTNode] = None,
|
||||
export_name: Optional[str] = None,
|
||||
properties: Sequence[ClassProperty] = ()) -> None:
|
||||
super().__init__(name, parent, export_name, bases=(),
|
||||
properties=properties)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Type, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from .node import ASTNode, ASTNodeType
|
||||
|
||||
@ -14,7 +14,7 @@ class ConstantNode(ASTNode):
|
||||
self._value_type = "int"
|
||||
|
||||
@property
|
||||
def children_types(self) -> Tuple[Type[ASTNode], ...]:
|
||||
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||
return ()
|
||||
|
||||
@property
|
||||
|
@ -18,8 +18,8 @@ class EnumerationNode(ASTNode):
|
||||
self.is_scoped = is_scoped
|
||||
|
||||
@property
|
||||
def children_types(self) -> Tuple[Type[ASTNode], ...]:
|
||||
return (ConstantNode, )
|
||||
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||
return (ASTNodeType.Constant, )
|
||||
|
||||
@property
|
||||
def node_type(self) -> ASTNodeType:
|
||||
@ -27,7 +27,7 @@ class EnumerationNode(ASTNode):
|
||||
|
||||
@property
|
||||
def constants(self) -> Dict[str, ConstantNode]:
|
||||
return self._children[ConstantNode]
|
||||
return self._children[ASTNodeType.Constant]
|
||||
|
||||
def add_constant(self, name: str, value: str) -> ConstantNode:
|
||||
return self._add_child(ConstantNode, name, value=value)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import NamedTuple, Sequence, Type, Optional, Tuple, List
|
||||
from typing import NamedTuple, Sequence, Optional, Tuple, List
|
||||
|
||||
from .node import ASTNode, ASTNodeType
|
||||
from .type_node import TypeNode, NoneTypeNode, TypeResolutionError
|
||||
@ -98,7 +98,7 @@ class FunctionNode(ASTNode):
|
||||
return ASTNodeType.Function
|
||||
|
||||
@property
|
||||
def children_types(self) -> Tuple[Type[ASTNode], ...]:
|
||||
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||
return ()
|
||||
|
||||
def add_overload(self, arguments: Sequence["FunctionNode.Arg"] = (),
|
||||
|
@ -1,7 +1,7 @@
|
||||
import itertools
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Type
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from .class_node import ClassNode, ClassProperty
|
||||
from .constant_node import ConstantNode
|
||||
@ -33,29 +33,29 @@ class NamespaceNode(ASTNode):
|
||||
return ASTNodeType.Namespace
|
||||
|
||||
@property
|
||||
def children_types(self) -> Tuple[Type[ASTNode], ...]:
|
||||
return (NamespaceNode, ClassNode, FunctionNode,
|
||||
EnumerationNode, ConstantNode)
|
||||
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||
return (ASTNodeType.Namespace, ASTNodeType.Class, ASTNodeType.Function,
|
||||
ASTNodeType.Enumeration, ASTNodeType.Constant)
|
||||
|
||||
@property
|
||||
def namespaces(self) -> Dict[str, "NamespaceNode"]:
|
||||
return self._children[NamespaceNode]
|
||||
return self._children[ASTNodeType.Namespace]
|
||||
|
||||
@property
|
||||
def classes(self) -> Dict[str, ClassNode]:
|
||||
return self._children[ClassNode]
|
||||
return self._children[ASTNodeType.Class]
|
||||
|
||||
@property
|
||||
def functions(self) -> Dict[str, FunctionNode]:
|
||||
return self._children[FunctionNode]
|
||||
return self._children[ASTNodeType.Function]
|
||||
|
||||
@property
|
||||
def enumerations(self) -> Dict[str, EnumerationNode]:
|
||||
return self._children[EnumerationNode]
|
||||
return self._children[ASTNodeType.Enumeration]
|
||||
|
||||
@property
|
||||
def constants(self) -> Dict[str, ConstantNode]:
|
||||
return self._children[ConstantNode]
|
||||
return self._children[ASTNodeType.Constant]
|
||||
|
||||
def add_namespace(self, name: str) -> "NamespaceNode":
|
||||
return self._add_child(NamespaceNode, name)
|
||||
|
@ -70,22 +70,22 @@ class ASTNode:
|
||||
self._parent: Optional["ASTNode"] = None
|
||||
self.parent = parent
|
||||
self.is_exported = True
|
||||
self._children: DefaultDict[NodeType, NameToNode] = defaultdict(dict)
|
||||
self._children: DefaultDict[ASTNodeType, NameToNode] = defaultdict(dict)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{}('{}' exported as '{}')".format(
|
||||
type(self).__name__.replace("Node", ""), self.name, self.export_name
|
||||
self.node_type.name, self.name, self.export_name
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self)
|
||||
|
||||
@abc.abstractproperty
|
||||
def children_types(self) -> Tuple[Type["ASTNode"], ...]:
|
||||
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||
"""Set of ASTNode types that are allowed to be children of this node
|
||||
|
||||
Returns:
|
||||
Tuple[Type[ASTNode], ...]: Types of children nodes
|
||||
Tuple[ASTNodeType, ...]: Types of children nodes
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -99,6 +99,9 @@ class ASTNode:
|
||||
"""
|
||||
pass
|
||||
|
||||
def node_type_name(self) -> str:
|
||||
return f"{self.node_type.name}::{self.name}"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.__name
|
||||
@ -126,11 +129,11 @@ class ASTNode:
|
||||
"but got: {}".format(type(value))
|
||||
|
||||
if value is not None:
|
||||
value.__check_child_before_add(type(self), self.name)
|
||||
value.__check_child_before_add(self, self.name)
|
||||
|
||||
# Detach from previous parent
|
||||
if self._parent is not None:
|
||||
self._parent._children[type(self)].pop(self.name)
|
||||
self._parent._children[self.node_type].pop(self.name)
|
||||
|
||||
if value is None:
|
||||
self._parent = None
|
||||
@ -138,28 +141,26 @@ class ASTNode:
|
||||
|
||||
# Set a weak reference to a new parent and add self to its children
|
||||
self._parent = weakref.proxy(value)
|
||||
value._children[type(self)][self.name] = self
|
||||
value._children[self.node_type][self.name] = self
|
||||
|
||||
def __check_child_before_add(self, child_type: Type[ASTNodeSubtype],
|
||||
def __check_child_before_add(self, child: ASTNodeSubtype,
|
||||
name: str) -> None:
|
||||
assert len(self.children_types) > 0, \
|
||||
"Trying to add child node '{}::{}' to node '{}::{}' " \
|
||||
"that can't have children nodes".format(child_type.__name__, name,
|
||||
type(self).__name__,
|
||||
self.name)
|
||||
assert len(self.children_types) > 0, (
|
||||
f"Trying to add child node '{child.node_type_name}' to node "
|
||||
f"'{self.node_type_name}' that can't have children nodes"
|
||||
)
|
||||
|
||||
assert child_type in self.children_types, \
|
||||
"Trying to add child node '{}::{}' to node '{}::{}' " \
|
||||
assert child.node_type in self.children_types, \
|
||||
"Trying to add child node '{}' to node '{}' " \
|
||||
"that supports only ({}) as its children types".format(
|
||||
child_type.__name__, name, type(self).__name__, self.name,
|
||||
",".join(t.__name__ for t in self.children_types)
|
||||
child.node_type_name, self.node_type_name,
|
||||
",".join(t.name for t in self.children_types)
|
||||
)
|
||||
|
||||
if self._find_child(child_type, name) is not None:
|
||||
if self._find_child(child.node_type, name) is not None:
|
||||
raise ValueError(
|
||||
"Node '{}::{}' already has a child '{}::{}'".format(
|
||||
type(self).__name__, self.name, child_type.__name__, name
|
||||
)
|
||||
f"Node '{self.node_type_name}' already has a "
|
||||
f"child '{child.node_type_name}'"
|
||||
)
|
||||
|
||||
def _add_child(self, child_type: Type[ASTNodeSubtype], name: str,
|
||||
@ -180,15 +181,14 @@ class ASTNode:
|
||||
Returns:
|
||||
ASTNodeSubtype: Created ASTNode
|
||||
"""
|
||||
self.__check_child_before_add(child_type, name)
|
||||
return child_type(name, parent=self, **kwargs)
|
||||
|
||||
def _find_child(self, child_type: Type[ASTNodeSubtype],
|
||||
def _find_child(self, child_type: ASTNodeType,
|
||||
name: str) -> Optional[ASTNodeSubtype]:
|
||||
"""Looks for child node with the given type and name.
|
||||
|
||||
Args:
|
||||
child_type (Type[ASTNodeSubtype]): Type of the child node.
|
||||
child_type (ASTNodeType): Type of the child node.
|
||||
name (str): Name of the child node.
|
||||
|
||||
Returns:
|
||||
|
@ -839,6 +839,21 @@ class CallableTypeNode(AggregatedTypeNode):
|
||||
yield from super().required_usage_imports
|
||||
|
||||
|
||||
class ClassTypeNode(ContainerTypeNode):
|
||||
"""Type node representing types themselves (refer to typing.Type)
|
||||
"""
|
||||
def __init__(self, value: TypeNode) -> None:
|
||||
super().__init__(value.ctype_name, (value,))
|
||||
|
||||
@property
|
||||
def type_format(self) -> str:
|
||||
return "typing.Type[{}]"
|
||||
|
||||
@property
|
||||
def types_separator(self) -> str:
|
||||
return ", "
|
||||
|
||||
|
||||
def _resolve_symbol(root: Optional[ASTNode], full_symbol_name: str) -> Optional[ASTNode]:
|
||||
"""Searches for a symbol with the given full export name in the AST
|
||||
starting from the `root`.
|
||||
|
Loading…
Reference in New Issue
Block a user