mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 03:33:28 +08:00
Merge pull request #24066 from VadimLevin:dev/vlevin/python-typing-register-dnn-layer
Python typing refinement for dnn_registerLayer/dnn_unregisterLayer functions #24066 This patch introduces typings generation for `dnn_registerLayer`/`dnn_unregisterLayer` manually defined in [`cv2/modules/dnn/misc/python/pyopencv_dnn.hpp`](https://github.com/opencv/opencv/blob/4.x/modules/dnn/misc/python/pyopencv_dnn.hpp) Updates: - Add `LayerProtocol` to `cv2/dnn/__init__.pyi`: ```python class LayerProtocol(Protocol): def __init__( self, params: dict[str, DictValue], blobs: typing.Sequence[cv2.typing.MatLike] ) -> None: ... def getMemoryShapes( self, inputs: typing.Sequence[typing.Sequence[int]] ) -> typing.Sequence[typing.Sequence[int]]: ... def forward( self, inputs: typing.Sequence[cv2.typing.MatLike] ) -> typing.Sequence[cv2.typing.MatLike]: ... ``` - Add `dnn_registerLayer` function to `cv2/__init__.pyi`: ```python def dnn_registerLayer(layerTypeName: str, layerClass: typing.Type[LayerProtocol]) -> None: ... ``` - Add `dnn_unregisterLayer` function to `cv2/__init__.pyi`: ```python def dnn_unregisterLayer(layerTypeName: str) -> None: ... ``` ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [ ] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
5c090b9eec
commit
0c5d74ec1a
@ -7,15 +7,18 @@ from typing import cast, Sequence, Callable, Iterable
|
|||||||
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode,
|
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode,
|
||||||
ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode,
|
ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode,
|
||||||
AggregatedTypeNode, CallableTypeNode, AnyTypeNode,
|
AggregatedTypeNode, CallableTypeNode, AnyTypeNode,
|
||||||
TupleTypeNode, UnionTypeNode)
|
TupleTypeNode, UnionTypeNode, ProtocolClassNode,
|
||||||
|
DictTypeNode, ClassTypeNode)
|
||||||
from .ast_utils import (find_function_node, SymbolName,
|
from .ast_utils import (find_function_node, SymbolName,
|
||||||
for_each_function_overload)
|
for_each_function_overload)
|
||||||
|
from .types_conversion import create_type_node
|
||||||
|
|
||||||
|
|
||||||
def apply_manual_api_refinement(root: NamespaceNode) -> None:
|
def apply_manual_api_refinement(root: NamespaceNode) -> None:
|
||||||
refine_highgui_module(root)
|
refine_highgui_module(root)
|
||||||
refine_cuda_module(root)
|
refine_cuda_module(root)
|
||||||
export_matrix_type_constants(root)
|
export_matrix_type_constants(root)
|
||||||
|
refine_dnn_module(root)
|
||||||
# Export OpenCV exception class
|
# Export OpenCV exception class
|
||||||
builtin_exception = root.add_class("Exception")
|
builtin_exception = root.add_class("Exception")
|
||||||
builtin_exception.is_exported = False
|
builtin_exception.is_exported = False
|
||||||
@ -215,6 +218,86 @@ def refine_highgui_module(root: NamespaceNode) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def refine_dnn_module(root: NamespaceNode) -> None:
|
||||||
|
if "dnn" not in root.namespaces:
|
||||||
|
return
|
||||||
|
dnn_module = root.namespaces["dnn"]
|
||||||
|
|
||||||
|
"""
|
||||||
|
class LayerProtocol(Protocol):
|
||||||
|
def __init__(
|
||||||
|
self, params: dict[str, DictValue],
|
||||||
|
blobs: typing.Sequence[cv2.typing.MatLike]
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
def getMemoryShapes(
|
||||||
|
self, inputs: typing.Sequence[typing.Sequence[int]]
|
||||||
|
) -> typing.Sequence[typing.Sequence[int]]: ...
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, inputs: typing.Sequence[cv2.typing.MatLike]
|
||||||
|
) -> typing.Sequence[cv2.typing.MatLike]: ...
|
||||||
|
"""
|
||||||
|
layer_proto = ProtocolClassNode("LayerProtocol", dnn_module)
|
||||||
|
layer_proto.add_function(
|
||||||
|
"__init__",
|
||||||
|
arguments=[
|
||||||
|
FunctionNode.Arg(
|
||||||
|
"params",
|
||||||
|
DictTypeNode(
|
||||||
|
"LayerParams", PrimitiveTypeNode.str_(),
|
||||||
|
create_type_node("cv::dnn::DictValue")
|
||||||
|
)
|
||||||
|
),
|
||||||
|
FunctionNode.Arg("blobs", create_type_node("vector<cv::Mat>"))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
layer_proto.add_function(
|
||||||
|
"getMemoryShapes",
|
||||||
|
arguments=[
|
||||||
|
FunctionNode.Arg("inputs",
|
||||||
|
create_type_node("vector<vector<int>>"))
|
||||||
|
],
|
||||||
|
return_type=FunctionNode.RetType(
|
||||||
|
create_type_node("vector<vector<int>>")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
layer_proto.add_function(
|
||||||
|
"forward",
|
||||||
|
arguments=[
|
||||||
|
FunctionNode.Arg("inputs", create_type_node("vector<cv::Mat>"))
|
||||||
|
],
|
||||||
|
return_type=FunctionNode.RetType(create_type_node("vector<cv::Mat>"))
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
def dnn_registerLayer(layerTypeName: str,
|
||||||
|
layerClass: typing.Type[LayerProtocol]) -> None: ...
|
||||||
|
"""
|
||||||
|
root.add_function(
|
||||||
|
"dnn_registerLayer",
|
||||||
|
arguments=[
|
||||||
|
FunctionNode.Arg("layerTypeName", PrimitiveTypeNode.str_()),
|
||||||
|
FunctionNode.Arg(
|
||||||
|
"layerClass",
|
||||||
|
ClassTypeNode(ASTNodeTypeNode(
|
||||||
|
layer_proto.export_name, f"dnn.{layer_proto.export_name}"
|
||||||
|
))
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
def dnn_unregisterLayer(layerTypeName: str) -> None: ...
|
||||||
|
"""
|
||||||
|
root.add_function(
|
||||||
|
"dnn_unregisterLayer",
|
||||||
|
arguments=[
|
||||||
|
FunctionNode.Arg("layerTypeName", PrimitiveTypeNode.str_())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _trim_class_name_from_argument_types(
|
def _trim_class_name_from_argument_types(
|
||||||
overloads: Iterable[FunctionNode.Overload],
|
overloads: Iterable[FunctionNode.Overload],
|
||||||
class_name: str
|
class_name: str
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from typing import (NamedTuple, Sequence, Tuple, Union, List,
|
from typing import (NamedTuple, Sequence, Tuple, Union, List,
|
||||||
Dict, Callable, Optional, Generator)
|
Dict, Callable, Optional, Generator, cast)
|
||||||
import keyword
|
import keyword
|
||||||
|
|
||||||
from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
|
from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
|
||||||
@ -204,9 +204,7 @@ def create_function_node_in_scope(scope: Union[NamespaceNode, ClassNode],
|
|||||||
outlist = variant.py_outlist
|
outlist = variant.py_outlist
|
||||||
for _, argno in outlist:
|
for _, argno in outlist:
|
||||||
assert argno >= 0, \
|
assert argno >= 0, \
|
||||||
"Logic Error! Outlist contains function return type: {}".format(
|
f"Logic Error! Outlist contains function return type: {outlist}"
|
||||||
outlist
|
|
||||||
)
|
|
||||||
|
|
||||||
ret_types.append(create_type_node(variant.args[argno].tp))
|
ret_types.append(create_type_node(variant.args[argno].tp))
|
||||||
|
|
||||||
@ -379,7 +377,7 @@ def get_enclosing_namespace(
|
|||||||
node.full_export_name, node.native_name
|
node.full_export_name, node.native_name
|
||||||
)
|
)
|
||||||
if class_node_callback:
|
if class_node_callback:
|
||||||
class_node_callback(parent_node)
|
class_node_callback(cast(ClassNode, parent_node))
|
||||||
parent_node = parent_node.parent
|
parent_node = parent_node.parent
|
||||||
return parent_node
|
return parent_node
|
||||||
|
|
||||||
@ -395,12 +393,14 @@ def get_enum_module_and_export_name(enum_node: EnumerationNode) -> Tuple[str, st
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple[str, str]: a pair of enum export name and its full module name.
|
Tuple[str, str]: a pair of enum export name and its full module name.
|
||||||
"""
|
"""
|
||||||
|
enum_export_name = enum_node.export_name
|
||||||
|
|
||||||
def update_full_export_name(class_node: ClassNode) -> None:
|
def update_full_export_name(class_node: ClassNode) -> None:
|
||||||
nonlocal enum_export_name
|
nonlocal enum_export_name
|
||||||
enum_export_name = class_node.export_name + "_" + enum_export_name
|
enum_export_name = class_node.export_name + "_" + enum_export_name
|
||||||
|
|
||||||
enum_export_name = enum_node.export_name
|
namespace_node = get_enclosing_namespace(enum_node,
|
||||||
namespace_node = get_enclosing_namespace(enum_node, update_full_export_name)
|
update_full_export_name)
|
||||||
return enum_export_name, namespace_node.full_export_name
|
return enum_export_name, namespace_node.full_export_name
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ __all__ = ("generate_typing_stubs", )
|
|||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
from typing import (Type, Callable, NamedTuple, Union, Set, Dict,
|
from typing import (Callable, NamedTuple, Union, Set, Dict,
|
||||||
Collection, Tuple, List)
|
Collection, Tuple, List)
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -16,7 +16,8 @@ from .predefined_types import PREDEFINED_TYPES
|
|||||||
from .api_refinement import apply_manual_api_refinement
|
from .api_refinement import apply_manual_api_refinement
|
||||||
|
|
||||||
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode,
|
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode,
|
||||||
FunctionNode, EnumerationNode, ConstantNode)
|
FunctionNode, EnumerationNode, ConstantNode,
|
||||||
|
ProtocolClassNode)
|
||||||
|
|
||||||
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
|
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
|
||||||
AggregatedTypeNode, ASTNodeTypeNode,
|
AggregatedTypeNode, ASTNodeTypeNode,
|
||||||
@ -112,11 +113,13 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:
|
|||||||
# NOTE: Enumerations require special handling, because all enumeration
|
# NOTE: Enumerations require special handling, because all enumeration
|
||||||
# constants are exposed as module attributes
|
# constants are exposed as module attributes
|
||||||
has_enums = _generate_section_stub(
|
has_enums = _generate_section_stub(
|
||||||
StubSection("# Enumerations", EnumerationNode), root, output_stream, 0
|
StubSection("# Enumerations", ASTNodeType.Enumeration), root,
|
||||||
|
output_stream, 0
|
||||||
)
|
)
|
||||||
# Collect all enums from class level and export them to module level
|
# Collect all enums from class level and export them to module level
|
||||||
for class_node in root.classes.values():
|
for class_node in root.classes.values():
|
||||||
if _generate_enums_from_classes_tree(class_node, output_stream, indent=0):
|
if _generate_enums_from_classes_tree(class_node, output_stream,
|
||||||
|
indent=0):
|
||||||
has_enums = True
|
has_enums = True
|
||||||
# 2 empty lines between enum and classes definitions
|
# 2 empty lines between enum and classes definitions
|
||||||
if has_enums:
|
if has_enums:
|
||||||
@ -134,14 +137,15 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:
|
|||||||
|
|
||||||
class StubSection(NamedTuple):
|
class StubSection(NamedTuple):
|
||||||
name: str
|
name: str
|
||||||
node_type: Type[ASTNode]
|
node_type: ASTNodeType
|
||||||
|
|
||||||
|
|
||||||
STUB_SECTIONS = (
|
STUB_SECTIONS = (
|
||||||
StubSection("# Constants", ConstantNode),
|
StubSection("# Constants", ASTNodeType.Constant),
|
||||||
# StubSection("# Enumerations", EnumerationNode), # Skipped for now (special rules)
|
# Enumerations are skipped due to special handling rules
|
||||||
StubSection("# Classes", ClassNode),
|
# StubSection("# Enumerations", ASTNodeType.Enumeration),
|
||||||
StubSection("# Functions", FunctionNode)
|
StubSection("# Classes", ASTNodeType.Class),
|
||||||
|
StubSection("# Functions", ASTNodeType.Function)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -250,9 +254,9 @@ def _generate_class_stub(class_node: ClassNode, output_stream: StringIO,
|
|||||||
else:
|
else:
|
||||||
bases.append(base.export_name)
|
bases.append(base.export_name)
|
||||||
|
|
||||||
inheritance_str = "({})".format(
|
inheritance_str = f"({', '.join(bases)})"
|
||||||
', '.join(bases)
|
elif isinstance(class_node, ProtocolClassNode):
|
||||||
)
|
inheritance_str = "(Protocol)"
|
||||||
else:
|
else:
|
||||||
inheritance_str = ""
|
inheritance_str = ""
|
||||||
|
|
||||||
@ -547,7 +551,8 @@ def check_overload_presence(node: Union[NamespaceNode, ClassNode]) -> bool:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
|
||||||
|
def _collect_required_imports(root: NamespaceNode) -> Collection[str]:
|
||||||
"""Collects all imports required for classes and functions typing stubs
|
"""Collects all imports required for classes and functions typing stubs
|
||||||
declarations.
|
declarations.
|
||||||
|
|
||||||
@ -555,8 +560,8 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
|||||||
root (NamespaceNode): Namespace node to collect imports for
|
root (NamespaceNode): Namespace node to collect imports for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Set[str]: Collection of unique `import smth` statements required for
|
Collection[str]: Collection of unique `import smth` statements required
|
||||||
classes and function declarations of `root` node.
|
for classes and function declarations of `root` node.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _add_required_usage_imports(type_node: TypeNode, imports: Set[str]):
|
def _add_required_usage_imports(type_node: TypeNode, imports: Set[str]):
|
||||||
@ -569,6 +574,7 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
|||||||
has_overload = check_overload_presence(root)
|
has_overload = check_overload_presence(root)
|
||||||
# if there is no module-level functions with overload, check its presence
|
# if there is no module-level functions with overload, check its presence
|
||||||
# during class traversing, including their inner-classes
|
# during class traversing, including their inner-classes
|
||||||
|
has_protocol = False
|
||||||
for cls in for_each_class(root):
|
for cls in for_each_class(root):
|
||||||
if not has_overload and check_overload_presence(cls):
|
if not has_overload and check_overload_presence(cls):
|
||||||
has_overload = True
|
has_overload = True
|
||||||
@ -583,6 +589,8 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
|||||||
required_imports.add(
|
required_imports.add(
|
||||||
"import " + base_namespace.full_export_name
|
"import " + base_namespace.full_export_name
|
||||||
)
|
)
|
||||||
|
if isinstance(cls, ProtocolClassNode):
|
||||||
|
has_protocol = True
|
||||||
|
|
||||||
if has_overload:
|
if has_overload:
|
||||||
required_imports.add("import typing")
|
required_imports.add("import typing")
|
||||||
@ -599,7 +607,20 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
|||||||
if root_import in required_imports:
|
if root_import in required_imports:
|
||||||
required_imports.remove(root_import)
|
required_imports.remove(root_import)
|
||||||
|
|
||||||
return required_imports
|
if has_protocol:
|
||||||
|
required_imports.add("import sys")
|
||||||
|
ordered_required_imports = sorted(required_imports)
|
||||||
|
|
||||||
|
# Protocol import always goes as last import statement
|
||||||
|
if has_protocol:
|
||||||
|
ordered_required_imports.append(
|
||||||
|
"""if sys.version_info >= (3, 8):
|
||||||
|
from typing import Protocol
|
||||||
|
else:
|
||||||
|
from typing_extensions import Protocol"""
|
||||||
|
)
|
||||||
|
|
||||||
|
return ordered_required_imports
|
||||||
|
|
||||||
|
|
||||||
def _populate_reexported_symbols(root: NamespaceNode) -> None:
|
def _populate_reexported_symbols(root: NamespaceNode) -> None:
|
||||||
@ -666,7 +687,7 @@ def _write_required_imports(required_imports: Collection[str],
|
|||||||
output_stream (StringIO): Output stream for import statements.
|
output_stream (StringIO): Output stream for import statements.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for required_import in sorted(required_imports):
|
for required_import in required_imports:
|
||||||
output_stream.write(required_import)
|
output_stream.write(required_import)
|
||||||
output_stream.write("\n")
|
output_stream.write("\n")
|
||||||
if len(required_imports):
|
if len(required_imports):
|
||||||
@ -803,8 +824,8 @@ StubGenerator = Callable[[ASTNode, StringIO, int], None]
|
|||||||
|
|
||||||
|
|
||||||
NODE_TYPE_TO_STUB_GENERATOR = {
|
NODE_TYPE_TO_STUB_GENERATOR = {
|
||||||
ClassNode: _generate_class_stub,
|
ASTNodeType.Class: _generate_class_stub,
|
||||||
ConstantNode: _generate_constant_stub,
|
ASTNodeType.Constant: _generate_constant_stub,
|
||||||
EnumerationNode: _generate_enumeration_stub,
|
ASTNodeType.Enumeration: _generate_enumeration_stub,
|
||||||
FunctionNode: _generate_function_stub
|
ASTNodeType.Function: _generate_function_stub
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from .node import ASTNode, ASTNodeType
|
from .node import ASTNode, ASTNodeType
|
||||||
from .namespace_node import NamespaceNode
|
from .namespace_node import NamespaceNode
|
||||||
from .class_node import ClassNode, ClassProperty
|
from .class_node import ClassNode, ClassProperty, ProtocolClassNode
|
||||||
from .function_node import FunctionNode
|
from .function_node import FunctionNode
|
||||||
from .enumeration_node import EnumerationNode
|
from .enumeration_node import EnumerationNode
|
||||||
from .constant_node import ConstantNode
|
from .constant_node import ConstantNode
|
||||||
@ -8,5 +8,5 @@ from .type_node import (
|
|||||||
TypeNode, OptionalTypeNode, UnionTypeNode, NoneTypeNode, TupleTypeNode,
|
TypeNode, OptionalTypeNode, UnionTypeNode, NoneTypeNode, TupleTypeNode,
|
||||||
ASTNodeTypeNode, AliasTypeNode, SequenceTypeNode, AnyTypeNode,
|
ASTNodeTypeNode, AliasTypeNode, SequenceTypeNode, AnyTypeNode,
|
||||||
AggregatedTypeNode, NDArrayTypeNode, AliasRefTypeNode, PrimitiveTypeNode,
|
AggregatedTypeNode, NDArrayTypeNode, AliasRefTypeNode, PrimitiveTypeNode,
|
||||||
CallableTypeNode,
|
CallableTypeNode, DictTypeNode, ClassTypeNode
|
||||||
)
|
)
|
||||||
|
@ -63,8 +63,9 @@ class ClassNode(ASTNode):
|
|||||||
return 1 + sum(base.weight for base in self.bases)
|
return 1 + sum(base.weight for base in self.bases)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def children_types(self) -> Tuple[Type[ASTNode], ...]:
|
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||||
return (ClassNode, FunctionNode, EnumerationNode, ConstantNode)
|
return (ASTNodeType.Class, ASTNodeType.Function,
|
||||||
|
ASTNodeType.Enumeration, ASTNodeType.Constant)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def node_type(self) -> ASTNodeType:
|
def node_type(self) -> ASTNodeType:
|
||||||
@ -72,19 +73,19 @@ class ClassNode(ASTNode):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def classes(self) -> Dict[str, "ClassNode"]:
|
def classes(self) -> Dict[str, "ClassNode"]:
|
||||||
return self._children[ClassNode]
|
return self._children[ASTNodeType.Class]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def functions(self) -> Dict[str, FunctionNode]:
|
def functions(self) -> Dict[str, FunctionNode]:
|
||||||
return self._children[FunctionNode]
|
return self._children[ASTNodeType.Function]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enumerations(self) -> Dict[str, EnumerationNode]:
|
def enumerations(self) -> Dict[str, EnumerationNode]:
|
||||||
return self._children[EnumerationNode]
|
return self._children[ASTNodeType.Enumeration]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def constants(self) -> Dict[str, ConstantNode]:
|
def constants(self) -> Dict[str, ConstantNode]:
|
||||||
return self._children[ConstantNode]
|
return self._children[ASTNodeType.Constant]
|
||||||
|
|
||||||
def add_class(self, name: str,
|
def add_class(self, name: str,
|
||||||
bases: Sequence["weakref.ProxyType[ClassNode]"] = (),
|
bases: Sequence["weakref.ProxyType[ClassNode]"] = (),
|
||||||
@ -179,3 +180,11 @@ class ClassNode(ASTNode):
|
|||||||
self.full_export_name, root.full_export_name, errors
|
self.full_export_name, root.full_export_name, errors
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolClassNode(ClassNode):
|
||||||
|
def __init__(self, name: str, parent: Optional[ASTNode] = None,
|
||||||
|
export_name: Optional[str] = None,
|
||||||
|
properties: Sequence[ClassProperty] = ()) -> None:
|
||||||
|
super().__init__(name, parent, export_name, bases=(),
|
||||||
|
properties=properties)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Type, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from .node import ASTNode, ASTNodeType
|
from .node import ASTNode, ASTNodeType
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ class ConstantNode(ASTNode):
|
|||||||
self._value_type = "int"
|
self._value_type = "int"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def children_types(self) -> Tuple[Type[ASTNode], ...]:
|
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -18,8 +18,8 @@ class EnumerationNode(ASTNode):
|
|||||||
self.is_scoped = is_scoped
|
self.is_scoped = is_scoped
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def children_types(self) -> Tuple[Type[ASTNode], ...]:
|
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||||
return (ConstantNode, )
|
return (ASTNodeType.Constant, )
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def node_type(self) -> ASTNodeType:
|
def node_type(self) -> ASTNodeType:
|
||||||
@ -27,7 +27,7 @@ class EnumerationNode(ASTNode):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def constants(self) -> Dict[str, ConstantNode]:
|
def constants(self) -> Dict[str, ConstantNode]:
|
||||||
return self._children[ConstantNode]
|
return self._children[ASTNodeType.Constant]
|
||||||
|
|
||||||
def add_constant(self, name: str, value: str) -> ConstantNode:
|
def add_constant(self, name: str, value: str) -> ConstantNode:
|
||||||
return self._add_child(ConstantNode, name, value=value)
|
return self._add_child(ConstantNode, name, value=value)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import NamedTuple, Sequence, Type, Optional, Tuple, List
|
from typing import NamedTuple, Sequence, Optional, Tuple, List
|
||||||
|
|
||||||
from .node import ASTNode, ASTNodeType
|
from .node import ASTNode, ASTNodeType
|
||||||
from .type_node import TypeNode, NoneTypeNode, TypeResolutionError
|
from .type_node import TypeNode, NoneTypeNode, TypeResolutionError
|
||||||
@ -98,7 +98,7 @@ class FunctionNode(ASTNode):
|
|||||||
return ASTNodeType.Function
|
return ASTNodeType.Function
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def children_types(self) -> Tuple[Type[ASTNode], ...]:
|
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
def add_overload(self, arguments: Sequence["FunctionNode.Arg"] = (),
|
def add_overload(self, arguments: Sequence["FunctionNode.Arg"] = (),
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import weakref
|
import weakref
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Optional, Sequence, Tuple, Type
|
from typing import Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from .class_node import ClassNode, ClassProperty
|
from .class_node import ClassNode, ClassProperty
|
||||||
from .constant_node import ConstantNode
|
from .constant_node import ConstantNode
|
||||||
@ -33,29 +33,29 @@ class NamespaceNode(ASTNode):
|
|||||||
return ASTNodeType.Namespace
|
return ASTNodeType.Namespace
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def children_types(self) -> Tuple[Type[ASTNode], ...]:
|
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||||
return (NamespaceNode, ClassNode, FunctionNode,
|
return (ASTNodeType.Namespace, ASTNodeType.Class, ASTNodeType.Function,
|
||||||
EnumerationNode, ConstantNode)
|
ASTNodeType.Enumeration, ASTNodeType.Constant)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def namespaces(self) -> Dict[str, "NamespaceNode"]:
|
def namespaces(self) -> Dict[str, "NamespaceNode"]:
|
||||||
return self._children[NamespaceNode]
|
return self._children[ASTNodeType.Namespace]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def classes(self) -> Dict[str, ClassNode]:
|
def classes(self) -> Dict[str, ClassNode]:
|
||||||
return self._children[ClassNode]
|
return self._children[ASTNodeType.Class]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def functions(self) -> Dict[str, FunctionNode]:
|
def functions(self) -> Dict[str, FunctionNode]:
|
||||||
return self._children[FunctionNode]
|
return self._children[ASTNodeType.Function]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enumerations(self) -> Dict[str, EnumerationNode]:
|
def enumerations(self) -> Dict[str, EnumerationNode]:
|
||||||
return self._children[EnumerationNode]
|
return self._children[ASTNodeType.Enumeration]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def constants(self) -> Dict[str, ConstantNode]:
|
def constants(self) -> Dict[str, ConstantNode]:
|
||||||
return self._children[ConstantNode]
|
return self._children[ASTNodeType.Constant]
|
||||||
|
|
||||||
def add_namespace(self, name: str) -> "NamespaceNode":
|
def add_namespace(self, name: str) -> "NamespaceNode":
|
||||||
return self._add_child(NamespaceNode, name)
|
return self._add_child(NamespaceNode, name)
|
||||||
|
@ -70,22 +70,22 @@ class ASTNode:
|
|||||||
self._parent: Optional["ASTNode"] = None
|
self._parent: Optional["ASTNode"] = None
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.is_exported = True
|
self.is_exported = True
|
||||||
self._children: DefaultDict[NodeType, NameToNode] = defaultdict(dict)
|
self._children: DefaultDict[ASTNodeType, NameToNode] = defaultdict(dict)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return "{}('{}' exported as '{}')".format(
|
return "{}('{}' exported as '{}')".format(
|
||||||
type(self).__name__.replace("Node", ""), self.name, self.export_name
|
self.node_type.name, self.name, self.export_name
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return str(self)
|
return str(self)
|
||||||
|
|
||||||
@abc.abstractproperty
|
@abc.abstractproperty
|
||||||
def children_types(self) -> Tuple[Type["ASTNode"], ...]:
|
def children_types(self) -> Tuple[ASTNodeType, ...]:
|
||||||
"""Set of ASTNode types that are allowed to be children of this node
|
"""Set of ASTNode types that are allowed to be children of this node
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Type[ASTNode], ...]: Types of children nodes
|
Tuple[ASTNodeType, ...]: Types of children nodes
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -99,6 +99,9 @@ class ASTNode:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def node_type_name(self) -> str:
|
||||||
|
return f"{self.node_type.name}::{self.name}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return self.__name
|
return self.__name
|
||||||
@ -126,11 +129,11 @@ class ASTNode:
|
|||||||
"but got: {}".format(type(value))
|
"but got: {}".format(type(value))
|
||||||
|
|
||||||
if value is not None:
|
if value is not None:
|
||||||
value.__check_child_before_add(type(self), self.name)
|
value.__check_child_before_add(self, self.name)
|
||||||
|
|
||||||
# Detach from previous parent
|
# Detach from previous parent
|
||||||
if self._parent is not None:
|
if self._parent is not None:
|
||||||
self._parent._children[type(self)].pop(self.name)
|
self._parent._children[self.node_type].pop(self.name)
|
||||||
|
|
||||||
if value is None:
|
if value is None:
|
||||||
self._parent = None
|
self._parent = None
|
||||||
@ -138,28 +141,26 @@ class ASTNode:
|
|||||||
|
|
||||||
# Set a weak reference to a new parent and add self to its children
|
# Set a weak reference to a new parent and add self to its children
|
||||||
self._parent = weakref.proxy(value)
|
self._parent = weakref.proxy(value)
|
||||||
value._children[type(self)][self.name] = self
|
value._children[self.node_type][self.name] = self
|
||||||
|
|
||||||
def __check_child_before_add(self, child_type: Type[ASTNodeSubtype],
|
def __check_child_before_add(self, child: ASTNodeSubtype,
|
||||||
name: str) -> None:
|
name: str) -> None:
|
||||||
assert len(self.children_types) > 0, \
|
assert len(self.children_types) > 0, (
|
||||||
"Trying to add child node '{}::{}' to node '{}::{}' " \
|
f"Trying to add child node '{child.node_type_name}' to node "
|
||||||
"that can't have children nodes".format(child_type.__name__, name,
|
f"'{self.node_type_name}' that can't have children nodes"
|
||||||
type(self).__name__,
|
)
|
||||||
self.name)
|
|
||||||
|
|
||||||
assert child_type in self.children_types, \
|
assert child.node_type in self.children_types, \
|
||||||
"Trying to add child node '{}::{}' to node '{}::{}' " \
|
"Trying to add child node '{}' to node '{}' " \
|
||||||
"that supports only ({}) as its children types".format(
|
"that supports only ({}) as its children types".format(
|
||||||
child_type.__name__, name, type(self).__name__, self.name,
|
child.node_type_name, self.node_type_name,
|
||||||
",".join(t.__name__ for t in self.children_types)
|
",".join(t.name for t in self.children_types)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._find_child(child_type, name) is not None:
|
if self._find_child(child.node_type, name) is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Node '{}::{}' already has a child '{}::{}'".format(
|
f"Node '{self.node_type_name}' already has a "
|
||||||
type(self).__name__, self.name, child_type.__name__, name
|
f"child '{child.node_type_name}'"
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_child(self, child_type: Type[ASTNodeSubtype], name: str,
|
def _add_child(self, child_type: Type[ASTNodeSubtype], name: str,
|
||||||
@ -180,15 +181,14 @@ class ASTNode:
|
|||||||
Returns:
|
Returns:
|
||||||
ASTNodeSubtype: Created ASTNode
|
ASTNodeSubtype: Created ASTNode
|
||||||
"""
|
"""
|
||||||
self.__check_child_before_add(child_type, name)
|
|
||||||
return child_type(name, parent=self, **kwargs)
|
return child_type(name, parent=self, **kwargs)
|
||||||
|
|
||||||
def _find_child(self, child_type: Type[ASTNodeSubtype],
|
def _find_child(self, child_type: ASTNodeType,
|
||||||
name: str) -> Optional[ASTNodeSubtype]:
|
name: str) -> Optional[ASTNodeSubtype]:
|
||||||
"""Looks for child node with the given type and name.
|
"""Looks for child node with the given type and name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
child_type (Type[ASTNodeSubtype]): Type of the child node.
|
child_type (ASTNodeType): Type of the child node.
|
||||||
name (str): Name of the child node.
|
name (str): Name of the child node.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -839,6 +839,21 @@ class CallableTypeNode(AggregatedTypeNode):
|
|||||||
yield from super().required_usage_imports
|
yield from super().required_usage_imports
|
||||||
|
|
||||||
|
|
||||||
|
class ClassTypeNode(ContainerTypeNode):
|
||||||
|
"""Type node representing types themselves (refer to typing.Type)
|
||||||
|
"""
|
||||||
|
def __init__(self, value: TypeNode) -> None:
|
||||||
|
super().__init__(value.ctype_name, (value,))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type_format(self) -> str:
|
||||||
|
return "typing.Type[{}]"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def types_separator(self) -> str:
|
||||||
|
return ", "
|
||||||
|
|
||||||
|
|
||||||
def _resolve_symbol(root: Optional[ASTNode], full_symbol_name: str) -> Optional[ASTNode]:
|
def _resolve_symbol(root: Optional[ASTNode], full_symbol_name: str) -> Optional[ASTNode]:
|
||||||
"""Searches for a symbol with the given full export name in the AST
|
"""Searches for a symbol with the given full export name in the AST
|
||||||
starting from the `root`.
|
starting from the `root`.
|
||||||
|
Loading…
Reference in New Issue
Block a user