mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Merge pull request #24022 from VadimLevin:dev/vlevin/python-typing-cuda
Fix python typing stubs generation for CUDA modules #24022 resolves #23946 resolves #23945 resolves opencv/opencv-python#871 ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
b0cbd6ef77
commit
9519e67ad2
@ -2,14 +2,17 @@ __all__ = [
|
||||
"apply_manual_api_refinement"
|
||||
]
|
||||
|
||||
from typing import Sequence, Callable
|
||||
from typing import cast, Sequence, Callable, Iterable
|
||||
|
||||
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode,
|
||||
ClassProperty, PrimitiveTypeNode)
|
||||
from .ast_utils import find_function_node, SymbolName
|
||||
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode,
|
||||
ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode,
|
||||
AggregatedTypeNode)
|
||||
from .ast_utils import (find_function_node, SymbolName,
|
||||
for_each_function_overload)
|
||||
|
||||
|
||||
def apply_manual_api_refinement(root: NamespaceNode) -> None:
|
||||
refine_cuda_module(root)
|
||||
export_matrix_type_constants(root)
|
||||
# Export OpenCV exception class
|
||||
builtin_exception = root.add_class("Exception")
|
||||
@ -57,13 +60,65 @@ def make_optional_arg(arg_name: str) -> Callable[[NamespaceNode, SymbolName], No
|
||||
continue
|
||||
|
||||
overload.arguments[arg_idx].type_node = OptionalTypeNode(
|
||||
overload.arguments[arg_idx].type_node
|
||||
cast(TypeNode, overload.arguments[arg_idx].type_node)
|
||||
)
|
||||
|
||||
return _make_optional_arg
|
||||
|
||||
|
||||
def _find_argument_index(arguments: Sequence[FunctionNode.Arg], name: str) -> int:
|
||||
def refine_cuda_module(root: NamespaceNode) -> None:
|
||||
def fix_cudaoptflow_enums_names() -> None:
|
||||
for class_name in ("NvidiaOpticalFlow_1_0", "NvidiaOpticalFlow_2_0"):
|
||||
if class_name not in cuda_root.classes:
|
||||
continue
|
||||
opt_flow_class = cuda_root.classes[class_name]
|
||||
_trim_class_name_from_argument_types(
|
||||
for_each_function_overload(opt_flow_class), class_name
|
||||
)
|
||||
|
||||
def fix_namespace_usage_scope(cuda_ns: NamespaceNode) -> None:
|
||||
USED_TYPES = ("GpuMat", "Stream")
|
||||
|
||||
def fix_type_usage(type_node: TypeNode) -> None:
|
||||
if isinstance(type_node, AggregatedTypeNode):
|
||||
for item in type_node.items:
|
||||
fix_type_usage(item)
|
||||
if isinstance(type_node, ASTNodeTypeNode):
|
||||
if type_node._typename in USED_TYPES:
|
||||
type_node._typename = f"cuda_{type_node._typename}"
|
||||
|
||||
for overload in for_each_function_overload(cuda_ns):
|
||||
if overload.return_type is not None:
|
||||
fix_type_usage(overload.return_type.type_node)
|
||||
for type_node in [arg.type_node for arg in overload.arguments
|
||||
if arg.type_node is not None]:
|
||||
fix_type_usage(type_node)
|
||||
|
||||
if "cuda" not in root.namespaces:
|
||||
return
|
||||
cuda_root = root.namespaces["cuda"]
|
||||
fix_cudaoptflow_enums_names()
|
||||
for ns in [ns for ns_name, ns in root.namespaces.items()
|
||||
if ns_name.startswith("cuda")]:
|
||||
fix_namespace_usage_scope(ns)
|
||||
|
||||
|
||||
def _trim_class_name_from_argument_types(
|
||||
overloads: Iterable[FunctionNode.Overload],
|
||||
class_name: str
|
||||
) -> None:
|
||||
separator = f"{class_name}_"
|
||||
for overload in overloads:
|
||||
for arg in [arg for arg in overload.arguments
|
||||
if arg.type_node is not None]:
|
||||
ast_node = cast(ASTNodeTypeNode, arg.type_node)
|
||||
if class_name in ast_node.ctype_name:
|
||||
fixed_name = ast_node._typename.split(separator)[-1]
|
||||
ast_node._typename = fixed_name
|
||||
|
||||
|
||||
def _find_argument_index(arguments: Sequence[FunctionNode.Arg],
|
||||
name: str) -> int:
|
||||
for i, arg in enumerate(arguments):
|
||||
if arg.name == name:
|
||||
return i
|
||||
@ -76,6 +131,7 @@ NODES_TO_REFINE = {
|
||||
SymbolName(("cv", ), (), "resize"): make_optional_arg("dsize"),
|
||||
SymbolName(("cv", ), (), "calcHist"): make_optional_arg("mask"),
|
||||
}
|
||||
|
||||
ERROR_CLASS_PROPERTIES = (
|
||||
ClassProperty("code", PrimitiveTypeNode.int_(), False),
|
||||
ClassProperty("err", PrimitiveTypeNode.str_(), False),
|
||||
|
@ -1,5 +1,5 @@
|
||||
from typing import (NamedTuple, Sequence, Tuple, Union, List,
|
||||
Dict, Callable, Optional)
|
||||
Dict, Callable, Optional, Generator)
|
||||
import keyword
|
||||
|
||||
from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
|
||||
@ -404,6 +404,33 @@ def get_enum_module_and_export_name(enum_node: EnumerationNode) -> Tuple[str, st
|
||||
return enum_export_name, namespace_node.full_export_name
|
||||
|
||||
|
||||
def for_each_class(
|
||||
node: Union[NamespaceNode, ClassNode]
|
||||
) -> Generator[ClassNode, None, None]:
|
||||
for cls in node.classes.values():
|
||||
yield cls
|
||||
if len(cls.classes):
|
||||
yield from for_each_class(cls)
|
||||
|
||||
|
||||
def for_each_function(
|
||||
node: Union[NamespaceNode, ClassNode],
|
||||
traverse_class_nodes: bool = True
|
||||
) -> Generator[FunctionNode, None, None]:
|
||||
yield from node.functions.values()
|
||||
if traverse_class_nodes:
|
||||
for cls in for_each_class(node):
|
||||
yield from for_each_function(cls)
|
||||
|
||||
|
||||
def for_each_function_overload(
|
||||
node: Union[NamespaceNode, ClassNode],
|
||||
traverse_class_nodes: bool = True
|
||||
) -> Generator[FunctionNode.Overload, None, None]:
|
||||
for func in for_each_function(node, traverse_class_nodes):
|
||||
yield from func.overloads
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
|
@ -3,17 +3,20 @@ __all__ = ("generate_typing_stubs", )
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import (Generator, Type, Callable, NamedTuple, Union, Set, Dict,
|
||||
from typing import (Type, Callable, NamedTuple, Union, Set, Dict,
|
||||
Collection, Tuple, List)
|
||||
import warnings
|
||||
|
||||
from .ast_utils import get_enclosing_namespace, get_enum_module_and_export_name
|
||||
from .ast_utils import (get_enclosing_namespace,
|
||||
get_enum_module_and_export_name,
|
||||
for_each_function_overload,
|
||||
for_each_class)
|
||||
|
||||
from .predefined_types import PREDEFINED_TYPES
|
||||
from .api_refinement import apply_manual_api_refinement
|
||||
|
||||
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, FunctionNode,
|
||||
EnumerationNode, ConstantNode)
|
||||
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode,
|
||||
FunctionNode, EnumerationNode, ConstantNode)
|
||||
|
||||
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
|
||||
AggregatedTypeNode, ASTNodeTypeNode,
|
||||
@ -105,8 +108,9 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:
|
||||
|
||||
# NOTE: Enumerations require special handling, because all enumeration
|
||||
# constants are exposed as module attributes
|
||||
has_enums = _generate_section_stub(StubSection("# Enumerations", EnumerationNode),
|
||||
root, output_stream, 0)
|
||||
has_enums = _generate_section_stub(
|
||||
StubSection("# Enumerations", EnumerationNode), root, output_stream, 0
|
||||
)
|
||||
# Collect all enums from class level and export them to module level
|
||||
for class_node in root.classes.values():
|
||||
if _generate_enums_from_classes_tree(class_node, output_stream, indent=0):
|
||||
@ -536,30 +540,6 @@ def check_overload_presence(node: Union[NamespaceNode, ClassNode]) -> bool:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _for_each_class(node: Union[NamespaceNode, ClassNode]) \
|
||||
-> Generator[ClassNode, None, None]:
|
||||
for cls in node.classes.values():
|
||||
yield cls
|
||||
if len(cls.classes):
|
||||
yield from _for_each_class(cls)
|
||||
|
||||
|
||||
def _for_each_function(node: Union[NamespaceNode, ClassNode]) \
|
||||
-> Generator[FunctionNode, None, None]:
|
||||
for func in node.functions.values():
|
||||
yield func
|
||||
for cls in node.classes.values():
|
||||
yield from _for_each_function(cls)
|
||||
|
||||
|
||||
def _for_each_function_overload(node: Union[NamespaceNode, ClassNode]) \
|
||||
-> Generator[FunctionNode.Overload, None, None]:
|
||||
for func in _for_each_function(node):
|
||||
for overload in func.overloads:
|
||||
yield overload
|
||||
|
||||
|
||||
def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
||||
"""Collects all imports required for classes and functions typing stubs
|
||||
declarations.
|
||||
@ -582,7 +562,7 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
||||
has_overload = check_overload_presence(root)
|
||||
# if there is no module-level functions with overload, check its presence
|
||||
# during class traversing, including their inner-classes
|
||||
for cls in _for_each_class(root):
|
||||
for cls in for_each_class(root):
|
||||
if not has_overload and check_overload_presence(cls):
|
||||
has_overload = True
|
||||
required_imports.add("import typing")
|
||||
@ -600,8 +580,9 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
||||
if has_overload:
|
||||
required_imports.add("import typing")
|
||||
# Importing modules required to resolve functions arguments
|
||||
for overload in _for_each_function_overload(root):
|
||||
for arg in filter(lambda a: a.type_node is not None, overload.arguments):
|
||||
for overload in for_each_function_overload(root):
|
||||
for arg in filter(lambda a: a.type_node is not None,
|
||||
overload.arguments):
|
||||
_add_required_usage_imports(arg.type_node, required_imports) # type: ignore
|
||||
if overload.return_type is not None:
|
||||
_add_required_usage_imports(overload.return_type.type_node,
|
||||
@ -625,11 +606,13 @@ def _populate_reexported_symbols(root: NamespaceNode) -> None:
|
||||
|
||||
_reexport_submodule(root)
|
||||
|
||||
# Special cases, symbols defined in possible pure Python submodules should be
|
||||
# Special cases, symbols defined in possible pure Python submodules
|
||||
# should be
|
||||
root.reexported_submodules_symbols["mat_wrapper"].append("Mat")
|
||||
|
||||
|
||||
def _write_reexported_symbols_section(module: NamespaceNode, output_stream: StringIO) -> None:
|
||||
def _write_reexported_symbols_section(module: NamespaceNode,
|
||||
output_stream: StringIO) -> None:
|
||||
"""Write re-export section for the given module.
|
||||
|
||||
Re-export statements have from `from module_name import smth as smth`.
|
||||
|
@ -22,6 +22,10 @@ _PREDEFINED_TYPES = (
|
||||
PrimitiveTypeNode.int_("uchar"),
|
||||
PrimitiveTypeNode.int_("unsigned"),
|
||||
PrimitiveTypeNode.int_("int64"),
|
||||
PrimitiveTypeNode.int_("uint8_t"),
|
||||
PrimitiveTypeNode.int_("int8_t"),
|
||||
PrimitiveTypeNode.int_("int32_t"),
|
||||
PrimitiveTypeNode.int_("uint32_t"),
|
||||
PrimitiveTypeNode.int_("size_t"),
|
||||
PrimitiveTypeNode.float_("float"),
|
||||
PrimitiveTypeNode.float_("double"),
|
||||
|
Loading…
Reference in New Issue
Block a user