mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 03:33:28 +08:00
fix: conditionally define generic NumPy NDArray alias
This commit is contained in:
parent
fe4f5b539e
commit
f20edba925
@ -15,7 +15,8 @@ from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, FunctionNode
|
||||
EnumerationNode, ConstantNode)
|
||||
|
||||
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
|
||||
AggregatedTypeNode, ASTNodeTypeNode)
|
||||
AggregatedTypeNode, ASTNodeTypeNode,
|
||||
ConditionalAliasTypeNode, PrimitiveTypeNode)
|
||||
|
||||
|
||||
def generate_typing_stubs(root: NamespaceNode, output_path: Path):
|
||||
@ -682,28 +683,37 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
|
||||
f"Provided type node '{type_node.ctype_name}' is not an aggregated type"
|
||||
|
||||
for item in filter(lambda i: isinstance(i, AliasRefTypeNode), type_node):
|
||||
register_alias(PREDEFINED_TYPES[item.ctype_name]) # type: ignore
|
||||
type_node = PREDEFINED_TYPES[item.ctype_name]
|
||||
if isinstance(type_node, AliasTypeNode):
|
||||
register_alias(type_node)
|
||||
elif isinstance(type_node, ConditionalAliasTypeNode):
|
||||
conditional_type_nodes[type_node.ctype_name] = type_node
|
||||
|
||||
def create_alias_for_enum_node(enum_node: ASTNode) -> AliasTypeNode:
|
||||
"""Create int alias corresponding to the given enum node.
|
||||
def create_alias_for_enum_node(enum_node_alias: AliasTypeNode) -> ConditionalAliasTypeNode:
|
||||
"""Create conditional int alias corresponding to the given enum node.
|
||||
|
||||
Args:
|
||||
enum_node (ASTNodeTypeNode): Enumeration node to create int alias for.
|
||||
enum_node (AliasTypeNode): Enumeration node to create conditional
|
||||
int alias for.
|
||||
|
||||
Returns:
|
||||
AliasTypeNode: int alias node with same export name as enum.
|
||||
ConditionalAliasTypeNode: conditional int alias node with same
|
||||
export name as enum.
|
||||
"""
|
||||
enum_node = enum_node_alias.ast_node
|
||||
assert enum_node.node_type == ASTNodeType.Enumeration, \
|
||||
f"{enum_node} has wrong node type. Expected type: Enumeration."
|
||||
|
||||
enum_export_name, enum_module_name = get_enum_module_and_export_name(
|
||||
enum_node
|
||||
)
|
||||
enum_full_export_name = f"{enum_module_name}.{enum_export_name}"
|
||||
alias_node = AliasTypeNode.int_(enum_full_export_name,
|
||||
enum_export_name)
|
||||
type_checking_time_definitions.add(alias_node)
|
||||
return alias_node
|
||||
return ConditionalAliasTypeNode(
|
||||
enum_export_name,
|
||||
"typing.TYPE_CHECKING",
|
||||
positive_branch_type=enum_node_alias,
|
||||
negative_branch_type=PrimitiveTypeNode.int_(enum_export_name),
|
||||
condition_required_imports=("import typing", )
|
||||
)
|
||||
|
||||
def register_alias(alias_node: AliasTypeNode) -> None:
|
||||
typename = alias_node.typename
|
||||
@ -726,11 +736,15 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
|
||||
continue
|
||||
if item.ast_node.node_type != ASTNodeType.Enumeration:
|
||||
continue
|
||||
alias_node.value.items[i] = create_alias_for_enum_node(item.ast_node)
|
||||
enum_node = create_alias_for_enum_node(item)
|
||||
alias_node.value.items[i] = enum_node
|
||||
conditional_type_nodes[enum_node.ctype_name] = enum_node
|
||||
|
||||
if isinstance(alias_node.value, ASTNodeTypeNode) \
|
||||
and alias_node.value.ast_node == ASTNodeType.Enumeration:
|
||||
alias_node.value = create_alias_for_enum_node(alias_node.ast_node)
|
||||
enum_node = create_alias_for_enum_node(alias_node.ast_node)
|
||||
conditional_type_nodes[enum_node.ctype_name] = enum_node
|
||||
return
|
||||
|
||||
# Strip module prefix from aliased types
|
||||
aliases[typename] = alias_node.value.full_typename.replace(
|
||||
@ -744,7 +758,7 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
|
||||
|
||||
required_imports: Set[str] = set()
|
||||
aliases: Dict[str, str] = {}
|
||||
type_checking_time_definitions: Set[AliasTypeNode] = set()
|
||||
conditional_type_nodes: Dict[str, ConditionalAliasTypeNode] = {}
|
||||
|
||||
# Resolve each node and register aliases
|
||||
TypeNode.compatible_to_runtime_usage = True
|
||||
@ -752,6 +766,12 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
|
||||
node.resolve(root)
|
||||
if isinstance(node, AliasTypeNode):
|
||||
register_alias(node)
|
||||
elif isinstance(node, ConditionalAliasTypeNode):
|
||||
conditional_type_nodes[node.ctype_name] = node
|
||||
|
||||
for node in conditional_type_nodes.values():
|
||||
for required_import in node.required_definition_imports:
|
||||
required_imports.add(required_import)
|
||||
|
||||
output_stream = StringIO()
|
||||
output_stream.write("__all__ = [\n")
|
||||
@ -762,12 +782,10 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
|
||||
_write_required_imports(required_imports, output_stream)
|
||||
|
||||
# Add type checking time definitions as generated __init__.py content
|
||||
for alias in type_checking_time_definitions:
|
||||
output_stream.write("if typing.TYPE_CHECKING:\n ")
|
||||
output_stream.write(f"{alias.typename} = {alias.ctype_name}\nelse:\n")
|
||||
output_stream.write(f" {alias.typename} = {alias.value.ctype_name}\n")
|
||||
if type_checking_time_definitions:
|
||||
output_stream.write("\n\n")
|
||||
for _, type_node in conditional_type_nodes.items():
|
||||
output_stream.write(f"if {type_node.condition}:\n ")
|
||||
output_stream.write(f"{type_node.typename} = {type_node.positive_branch_type.full_typename}\nelse:\n")
|
||||
output_stream.write(f" {type_node.typename} = {type_node.negative_branch_type.full_typename}\n\n\n")
|
||||
|
||||
for alias_name, alias_type in aliases.items():
|
||||
output_stream.write(f"{alias_name} = {alias_type}\n")
|
||||
|
@ -307,14 +307,31 @@ class AliasTypeNode(TypeNode):
|
||||
return cls(ctype_name, PrimitiveTypeNode.float_(), export_name, doc)
|
||||
|
||||
@classmethod
|
||||
def array_(cls, ctype_name: str, shape: Optional[Tuple[int, ...]],
|
||||
dtype: Optional[str] = None, export_name: Optional[str] = None,
|
||||
doc: Optional[str] = None):
|
||||
def array_ref_(cls, ctype_name: str, array_ref_name: str,
|
||||
shape: Optional[Tuple[int, ...]],
|
||||
dtype: Optional[str] = None,
|
||||
export_name: Optional[str] = None,
|
||||
doc: Optional[str] = None):
|
||||
"""Create alias to array reference alias `array_ref_name`.
|
||||
|
||||
This is required to preserve backward compatibility with Python < 3.9
|
||||
and NumPy 1.20, when NumPy module introduces generics support.
|
||||
|
||||
Args:
|
||||
ctype_name (str): Name of the alias.
|
||||
array_ref_name (str): Name of the conditional array alias.
|
||||
shape (Optional[Tuple[int, ...]]): Array shape.
|
||||
dtype (Optional[str], optional): Array type. Defaults to None.
|
||||
export_name (Optional[str], optional): Alias export name.
|
||||
Defaults to None.
|
||||
doc (Optional[str], optional): Documentation string for alias.
|
||||
Defaults to None.
|
||||
"""
|
||||
if doc is None:
|
||||
doc = "Shape: " + str(shape)
|
||||
doc = f"NDArray(shape={shape}, dtype={dtype})"
|
||||
else:
|
||||
doc += ". Shape: " + str(shape)
|
||||
return cls(ctype_name, NDArrayTypeNode(ctype_name, shape, dtype),
|
||||
doc += f". NDArray(shape={shape}, dtype={dtype})"
|
||||
return cls(ctype_name, AliasRefTypeNode(array_ref_name),
|
||||
export_name, doc)
|
||||
|
||||
@classmethod
|
||||
@ -376,24 +393,112 @@ class AliasTypeNode(TypeNode):
|
||||
export_name, doc)
|
||||
|
||||
|
||||
class NDArrayTypeNode(TypeNode):
|
||||
"""Type node representing NumPy ndarray.
|
||||
class ConditionalAliasTypeNode(TypeNode):
|
||||
"""Type node representing an alias protected by condition checked in runtime.
|
||||
Example:
|
||||
```python
|
||||
if numpy.lib.NumpyVersion(numpy.__version__) > "1.20.0" and sys.version_info >= (3, 9)
|
||||
NumPyArray = numpy.ndarray[typing.Any, numpy.dtype[numpy.generic]]
|
||||
else:
|
||||
NumPyArray = numpy.ndarray
|
||||
```
|
||||
is defined as follows:
|
||||
```python
|
||||
|
||||
ConditionalAliasTypeNode(
|
||||
"NumPyArray",
|
||||
'numpy.lib.NumpyVersion(numpy.__version__) > "1.20.0" and sys.version_info >= (3, 9)',
|
||||
NDArrayTypeNode("NumPyArray"),
|
||||
NDArrayTypeNode("NumPyArray", use_numpy_generics=False),
|
||||
condition_required_imports=("import numpy", "import sys")
|
||||
)
|
||||
```
|
||||
"""
|
||||
def __init__(self, ctype_name: str, shape: Optional[Tuple[int, ...]] = None,
|
||||
dtype: Optional[str] = None) -> None:
|
||||
def __init__(self, ctype_name: str, condition: str,
|
||||
positive_branch_type: TypeNode,
|
||||
negative_branch_type: TypeNode,
|
||||
export_name: Optional[str] = None,
|
||||
condition_required_imports: Sequence[str] = ()) -> None:
|
||||
super().__init__(ctype_name)
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self.condition = condition
|
||||
self.positive_branch_type = positive_branch_type
|
||||
self.positive_branch_type.ctype_name = self.ctype_name
|
||||
self.negative_branch_type = negative_branch_type
|
||||
self.negative_branch_type.ctype_name = self.ctype_name
|
||||
self._export_name = export_name
|
||||
self._condition_required_imports = condition_required_imports
|
||||
|
||||
@property
|
||||
def typename(self) -> str:
|
||||
return "numpy.ndarray[{shape}, numpy.dtype[{dtype}]]".format(
|
||||
# NOTE: Shape is not fully supported yet
|
||||
# shape=self.shape if self.shape is not None else "typing.Any",
|
||||
shape="typing.Any",
|
||||
dtype=self.dtype if self.dtype is not None else "numpy.generic"
|
||||
if self._export_name is not None:
|
||||
return self._export_name
|
||||
return self.ctype_name
|
||||
|
||||
@property
|
||||
def full_typename(self) -> str:
|
||||
return "cv2.typing." + self.typename
|
||||
|
||||
@property
|
||||
def required_definition_imports(self) -> Generator[str, None, None]:
|
||||
yield from self.positive_branch_type.required_usage_imports
|
||||
yield from self.negative_branch_type.required_usage_imports
|
||||
yield from self._condition_required_imports
|
||||
|
||||
@property
|
||||
def required_usage_imports(self) -> Generator[str, None, None]:
|
||||
yield "import cv2.typing"
|
||||
|
||||
@property
|
||||
def is_resolved(self) -> bool:
|
||||
return self.positive_branch_type.is_resolved \
|
||||
and self.negative_branch_type.is_resolved
|
||||
|
||||
def resolve(self, root: ASTNode):
|
||||
try:
|
||||
self.positive_branch_type.resolve(root)
|
||||
self.negative_branch_type.resolve(root)
|
||||
except TypeResolutionError as e:
|
||||
raise TypeResolutionError(
|
||||
'Failed to resolve alias "{}" exposed as "{}"'.format(
|
||||
self.ctype_name, self.typename
|
||||
)
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def numpy_array_(cls, ctype_name: str, export_name: Optional[str] = None,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
dtype: Optional[str] = None):
|
||||
return cls(
|
||||
ctype_name,
|
||||
('numpy.lib.NumpyVersion(numpy.__version__) > "1.20.0" '
|
||||
'and sys.version_info >= (3, 9)'),
|
||||
NDArrayTypeNode(ctype_name, shape, dtype),
|
||||
NDArrayTypeNode(ctype_name, shape, dtype,
|
||||
use_numpy_generics=False),
|
||||
condition_required_imports=("import numpy", "import sys")
|
||||
)
|
||||
|
||||
|
||||
class NDArrayTypeNode(TypeNode):
|
||||
"""Type node representing NumPy ndarray.
|
||||
"""
|
||||
def __init__(self, ctype_name: str,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
dtype: Optional[str] = None,
|
||||
use_numpy_generics: bool = True) -> None:
|
||||
super().__init__(ctype_name)
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self._use_numpy_generics = use_numpy_generics
|
||||
|
||||
@property
|
||||
def typename(self) -> str:
|
||||
if self._use_numpy_generics:
|
||||
# NOTE: Shape is not fully supported yet
|
||||
dtype = self.dtype if self.dtype is not None else "numpy.generic"
|
||||
return f"numpy.ndarray[typing.Any, numpy.dtype[{dtype}]]"
|
||||
return "numpy.ndarray"
|
||||
|
||||
@property
|
||||
def required_usage_imports(self) -> Generator[str, None, None]:
|
||||
yield "import numpy"
|
||||
|
@ -1,7 +1,7 @@
|
||||
from .nodes.type_node import (
|
||||
AliasTypeNode, AliasRefTypeNode, PrimitiveTypeNode,
|
||||
ASTNodeTypeNode, NDArrayTypeNode, NoneTypeNode, SequenceTypeNode,
|
||||
TupleTypeNode, UnionTypeNode, AnyTypeNode
|
||||
TupleTypeNode, UnionTypeNode, AnyTypeNode, ConditionalAliasTypeNode
|
||||
)
|
||||
|
||||
# Set of predefined types used to cover cases when library doesn't
|
||||
@ -30,12 +30,15 @@ _PREDEFINED_TYPES = (
|
||||
PrimitiveTypeNode.str_("char"),
|
||||
PrimitiveTypeNode.str_("String"),
|
||||
PrimitiveTypeNode.str_("c_string"),
|
||||
ConditionalAliasTypeNode.numpy_array_("NumPyArrayGeneric"),
|
||||
ConditionalAliasTypeNode.numpy_array_("NumPyArrayFloat32", dtype="numpy.float32"),
|
||||
ConditionalAliasTypeNode.numpy_array_("NumPyArrayFloat64", dtype="numpy.float64"),
|
||||
NoneTypeNode("void"),
|
||||
AliasTypeNode.int_("void*", "IntPointer", "Represents an arbitrary pointer"),
|
||||
AliasTypeNode.union_(
|
||||
"Mat",
|
||||
items=(ASTNodeTypeNode("Mat", module_name="cv2.mat_wrapper"),
|
||||
NDArrayTypeNode("Mat")),
|
||||
AliasRefTypeNode("NumPyArrayGeneric")),
|
||||
export_name="MatLike"
|
||||
),
|
||||
AliasTypeNode.sequence_("MatShape", PrimitiveTypeNode.int_()),
|
||||
@ -137,10 +140,22 @@ _PREDEFINED_TYPES = (
|
||||
ASTNodeTypeNode("gapi.wip.draw.Mosaic"),
|
||||
ASTNodeTypeNode("gapi.wip.draw.Poly"))),
|
||||
SequenceTypeNode("Prims", AliasRefTypeNode("Prim")),
|
||||
AliasTypeNode.array_("Matx33f", (3, 3), "numpy.float32"),
|
||||
AliasTypeNode.array_("Matx33d", (3, 3), "numpy.float64"),
|
||||
AliasTypeNode.array_("Matx44f", (4, 4), "numpy.float32"),
|
||||
AliasTypeNode.array_("Matx44d", (4, 4), "numpy.float64"),
|
||||
AliasTypeNode.array_ref_("Matx33f",
|
||||
array_ref_name="NumPyArrayFloat32",
|
||||
shape=(3, 3),
|
||||
dtype="numpy.float32"),
|
||||
AliasTypeNode.array_ref_("Matx33d",
|
||||
array_ref_name="NumPyArrayFloat64",
|
||||
shape=(3, 3),
|
||||
dtype="numpy.float64"),
|
||||
AliasTypeNode.array_ref_("Matx44f",
|
||||
array_ref_name="NumPyArrayFloat32",
|
||||
shape=(4, 4),
|
||||
dtype="numpy.float32"),
|
||||
AliasTypeNode.array_ref_("Matx44d",
|
||||
array_ref_name="NumPyArrayFloat64",
|
||||
shape=(4, 4),
|
||||
dtype="numpy.float64"),
|
||||
NDArrayTypeNode("vector<uchar>", dtype="numpy.uint8"),
|
||||
NDArrayTypeNode("vector_uchar", dtype="numpy.uint8"),
|
||||
TupleTypeNode("GMat2", items=(ASTNodeTypeNode("GMat"),
|
||||
|
Loading…
Reference in New Issue
Block a user