mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 03:33:28 +08:00
Merge pull request #24022 from VadimLevin:dev/vlevin/python-typing-cuda
Fix python typing stubs generation for CUDA modules #24022 resolves #23946 resolves #23945 resolves opencv/opencv-python#871 ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
b0cbd6ef77
commit
9519e67ad2
@ -2,14 +2,17 @@ __all__ = [
|
|||||||
"apply_manual_api_refinement"
|
"apply_manual_api_refinement"
|
||||||
]
|
]
|
||||||
|
|
||||||
from typing import Sequence, Callable
|
from typing import cast, Sequence, Callable, Iterable
|
||||||
|
|
||||||
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode,
|
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode,
|
||||||
ClassProperty, PrimitiveTypeNode)
|
ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode,
|
||||||
from .ast_utils import find_function_node, SymbolName
|
AggregatedTypeNode)
|
||||||
|
from .ast_utils import (find_function_node, SymbolName,
|
||||||
|
for_each_function_overload)
|
||||||
|
|
||||||
|
|
||||||
def apply_manual_api_refinement(root: NamespaceNode) -> None:
|
def apply_manual_api_refinement(root: NamespaceNode) -> None:
|
||||||
|
refine_cuda_module(root)
|
||||||
export_matrix_type_constants(root)
|
export_matrix_type_constants(root)
|
||||||
# Export OpenCV exception class
|
# Export OpenCV exception class
|
||||||
builtin_exception = root.add_class("Exception")
|
builtin_exception = root.add_class("Exception")
|
||||||
@ -57,13 +60,65 @@ def make_optional_arg(arg_name: str) -> Callable[[NamespaceNode, SymbolName], No
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
overload.arguments[arg_idx].type_node = OptionalTypeNode(
|
overload.arguments[arg_idx].type_node = OptionalTypeNode(
|
||||||
overload.arguments[arg_idx].type_node
|
cast(TypeNode, overload.arguments[arg_idx].type_node)
|
||||||
)
|
)
|
||||||
|
|
||||||
return _make_optional_arg
|
return _make_optional_arg
|
||||||
|
|
||||||
|
|
||||||
def _find_argument_index(arguments: Sequence[FunctionNode.Arg], name: str) -> int:
|
def refine_cuda_module(root: NamespaceNode) -> None:
|
||||||
|
def fix_cudaoptflow_enums_names() -> None:
|
||||||
|
for class_name in ("NvidiaOpticalFlow_1_0", "NvidiaOpticalFlow_2_0"):
|
||||||
|
if class_name not in cuda_root.classes:
|
||||||
|
continue
|
||||||
|
opt_flow_class = cuda_root.classes[class_name]
|
||||||
|
_trim_class_name_from_argument_types(
|
||||||
|
for_each_function_overload(opt_flow_class), class_name
|
||||||
|
)
|
||||||
|
|
||||||
|
def fix_namespace_usage_scope(cuda_ns: NamespaceNode) -> None:
|
||||||
|
USED_TYPES = ("GpuMat", "Stream")
|
||||||
|
|
||||||
|
def fix_type_usage(type_node: TypeNode) -> None:
|
||||||
|
if isinstance(type_node, AggregatedTypeNode):
|
||||||
|
for item in type_node.items:
|
||||||
|
fix_type_usage(item)
|
||||||
|
if isinstance(type_node, ASTNodeTypeNode):
|
||||||
|
if type_node._typename in USED_TYPES:
|
||||||
|
type_node._typename = f"cuda_{type_node._typename}"
|
||||||
|
|
||||||
|
for overload in for_each_function_overload(cuda_ns):
|
||||||
|
if overload.return_type is not None:
|
||||||
|
fix_type_usage(overload.return_type.type_node)
|
||||||
|
for type_node in [arg.type_node for arg in overload.arguments
|
||||||
|
if arg.type_node is not None]:
|
||||||
|
fix_type_usage(type_node)
|
||||||
|
|
||||||
|
if "cuda" not in root.namespaces:
|
||||||
|
return
|
||||||
|
cuda_root = root.namespaces["cuda"]
|
||||||
|
fix_cudaoptflow_enums_names()
|
||||||
|
for ns in [ns for ns_name, ns in root.namespaces.items()
|
||||||
|
if ns_name.startswith("cuda")]:
|
||||||
|
fix_namespace_usage_scope(ns)
|
||||||
|
|
||||||
|
|
||||||
|
def _trim_class_name_from_argument_types(
|
||||||
|
overloads: Iterable[FunctionNode.Overload],
|
||||||
|
class_name: str
|
||||||
|
) -> None:
|
||||||
|
separator = f"{class_name}_"
|
||||||
|
for overload in overloads:
|
||||||
|
for arg in [arg for arg in overload.arguments
|
||||||
|
if arg.type_node is not None]:
|
||||||
|
ast_node = cast(ASTNodeTypeNode, arg.type_node)
|
||||||
|
if class_name in ast_node.ctype_name:
|
||||||
|
fixed_name = ast_node._typename.split(separator)[-1]
|
||||||
|
ast_node._typename = fixed_name
|
||||||
|
|
||||||
|
|
||||||
|
def _find_argument_index(arguments: Sequence[FunctionNode.Arg],
|
||||||
|
name: str) -> int:
|
||||||
for i, arg in enumerate(arguments):
|
for i, arg in enumerate(arguments):
|
||||||
if arg.name == name:
|
if arg.name == name:
|
||||||
return i
|
return i
|
||||||
@ -76,6 +131,7 @@ NODES_TO_REFINE = {
|
|||||||
SymbolName(("cv", ), (), "resize"): make_optional_arg("dsize"),
|
SymbolName(("cv", ), (), "resize"): make_optional_arg("dsize"),
|
||||||
SymbolName(("cv", ), (), "calcHist"): make_optional_arg("mask"),
|
SymbolName(("cv", ), (), "calcHist"): make_optional_arg("mask"),
|
||||||
}
|
}
|
||||||
|
|
||||||
ERROR_CLASS_PROPERTIES = (
|
ERROR_CLASS_PROPERTIES = (
|
||||||
ClassProperty("code", PrimitiveTypeNode.int_(), False),
|
ClassProperty("code", PrimitiveTypeNode.int_(), False),
|
||||||
ClassProperty("err", PrimitiveTypeNode.str_(), False),
|
ClassProperty("err", PrimitiveTypeNode.str_(), False),
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from typing import (NamedTuple, Sequence, Tuple, Union, List,
|
from typing import (NamedTuple, Sequence, Tuple, Union, List,
|
||||||
Dict, Callable, Optional)
|
Dict, Callable, Optional, Generator)
|
||||||
import keyword
|
import keyword
|
||||||
|
|
||||||
from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
|
from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
|
||||||
@ -404,6 +404,33 @@ def get_enum_module_and_export_name(enum_node: EnumerationNode) -> Tuple[str, st
|
|||||||
return enum_export_name, namespace_node.full_export_name
|
return enum_export_name, namespace_node.full_export_name
|
||||||
|
|
||||||
|
|
||||||
|
def for_each_class(
|
||||||
|
node: Union[NamespaceNode, ClassNode]
|
||||||
|
) -> Generator[ClassNode, None, None]:
|
||||||
|
for cls in node.classes.values():
|
||||||
|
yield cls
|
||||||
|
if len(cls.classes):
|
||||||
|
yield from for_each_class(cls)
|
||||||
|
|
||||||
|
|
||||||
|
def for_each_function(
|
||||||
|
node: Union[NamespaceNode, ClassNode],
|
||||||
|
traverse_class_nodes: bool = True
|
||||||
|
) -> Generator[FunctionNode, None, None]:
|
||||||
|
yield from node.functions.values()
|
||||||
|
if traverse_class_nodes:
|
||||||
|
for cls in for_each_class(node):
|
||||||
|
yield from for_each_function(cls)
|
||||||
|
|
||||||
|
|
||||||
|
def for_each_function_overload(
|
||||||
|
node: Union[NamespaceNode, ClassNode],
|
||||||
|
traverse_class_nodes: bool = True
|
||||||
|
) -> Generator[FunctionNode.Overload, None, None]:
|
||||||
|
for func in for_each_function(node, traverse_class_nodes):
|
||||||
|
yield from func.overloads
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import doctest
|
import doctest
|
||||||
doctest.testmod()
|
doctest.testmod()
|
||||||
|
@ -3,17 +3,20 @@ __all__ = ("generate_typing_stubs", )
|
|||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
from typing import (Generator, Type, Callable, NamedTuple, Union, Set, Dict,
|
from typing import (Type, Callable, NamedTuple, Union, Set, Dict,
|
||||||
Collection, Tuple, List)
|
Collection, Tuple, List)
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .ast_utils import get_enclosing_namespace, get_enum_module_and_export_name
|
from .ast_utils import (get_enclosing_namespace,
|
||||||
|
get_enum_module_and_export_name,
|
||||||
|
for_each_function_overload,
|
||||||
|
for_each_class)
|
||||||
|
|
||||||
from .predefined_types import PREDEFINED_TYPES
|
from .predefined_types import PREDEFINED_TYPES
|
||||||
from .api_refinement import apply_manual_api_refinement
|
from .api_refinement import apply_manual_api_refinement
|
||||||
|
|
||||||
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, FunctionNode,
|
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode,
|
||||||
EnumerationNode, ConstantNode)
|
FunctionNode, EnumerationNode, ConstantNode)
|
||||||
|
|
||||||
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
|
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
|
||||||
AggregatedTypeNode, ASTNodeTypeNode,
|
AggregatedTypeNode, ASTNodeTypeNode,
|
||||||
@ -105,8 +108,9 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:
|
|||||||
|
|
||||||
# NOTE: Enumerations require special handling, because all enumeration
|
# NOTE: Enumerations require special handling, because all enumeration
|
||||||
# constants are exposed as module attributes
|
# constants are exposed as module attributes
|
||||||
has_enums = _generate_section_stub(StubSection("# Enumerations", EnumerationNode),
|
has_enums = _generate_section_stub(
|
||||||
root, output_stream, 0)
|
StubSection("# Enumerations", EnumerationNode), root, output_stream, 0
|
||||||
|
)
|
||||||
# Collect all enums from class level and export them to module level
|
# Collect all enums from class level and export them to module level
|
||||||
for class_node in root.classes.values():
|
for class_node in root.classes.values():
|
||||||
if _generate_enums_from_classes_tree(class_node, output_stream, indent=0):
|
if _generate_enums_from_classes_tree(class_node, output_stream, indent=0):
|
||||||
@ -536,30 +540,6 @@ def check_overload_presence(node: Union[NamespaceNode, ClassNode]) -> bool:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _for_each_class(node: Union[NamespaceNode, ClassNode]) \
|
|
||||||
-> Generator[ClassNode, None, None]:
|
|
||||||
for cls in node.classes.values():
|
|
||||||
yield cls
|
|
||||||
if len(cls.classes):
|
|
||||||
yield from _for_each_class(cls)
|
|
||||||
|
|
||||||
|
|
||||||
def _for_each_function(node: Union[NamespaceNode, ClassNode]) \
|
|
||||||
-> Generator[FunctionNode, None, None]:
|
|
||||||
for func in node.functions.values():
|
|
||||||
yield func
|
|
||||||
for cls in node.classes.values():
|
|
||||||
yield from _for_each_function(cls)
|
|
||||||
|
|
||||||
|
|
||||||
def _for_each_function_overload(node: Union[NamespaceNode, ClassNode]) \
|
|
||||||
-> Generator[FunctionNode.Overload, None, None]:
|
|
||||||
for func in _for_each_function(node):
|
|
||||||
for overload in func.overloads:
|
|
||||||
yield overload
|
|
||||||
|
|
||||||
|
|
||||||
def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
||||||
"""Collects all imports required for classes and functions typing stubs
|
"""Collects all imports required for classes and functions typing stubs
|
||||||
declarations.
|
declarations.
|
||||||
@ -582,7 +562,7 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
|||||||
has_overload = check_overload_presence(root)
|
has_overload = check_overload_presence(root)
|
||||||
# if there is no module-level functions with overload, check its presence
|
# if there is no module-level functions with overload, check its presence
|
||||||
# during class traversing, including their inner-classes
|
# during class traversing, including their inner-classes
|
||||||
for cls in _for_each_class(root):
|
for cls in for_each_class(root):
|
||||||
if not has_overload and check_overload_presence(cls):
|
if not has_overload and check_overload_presence(cls):
|
||||||
has_overload = True
|
has_overload = True
|
||||||
required_imports.add("import typing")
|
required_imports.add("import typing")
|
||||||
@ -600,8 +580,9 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
|
|||||||
if has_overload:
|
if has_overload:
|
||||||
required_imports.add("import typing")
|
required_imports.add("import typing")
|
||||||
# Importing modules required to resolve functions arguments
|
# Importing modules required to resolve functions arguments
|
||||||
for overload in _for_each_function_overload(root):
|
for overload in for_each_function_overload(root):
|
||||||
for arg in filter(lambda a: a.type_node is not None, overload.arguments):
|
for arg in filter(lambda a: a.type_node is not None,
|
||||||
|
overload.arguments):
|
||||||
_add_required_usage_imports(arg.type_node, required_imports) # type: ignore
|
_add_required_usage_imports(arg.type_node, required_imports) # type: ignore
|
||||||
if overload.return_type is not None:
|
if overload.return_type is not None:
|
||||||
_add_required_usage_imports(overload.return_type.type_node,
|
_add_required_usage_imports(overload.return_type.type_node,
|
||||||
@ -625,11 +606,13 @@ def _populate_reexported_symbols(root: NamespaceNode) -> None:
|
|||||||
|
|
||||||
_reexport_submodule(root)
|
_reexport_submodule(root)
|
||||||
|
|
||||||
# Special cases, symbols defined in possible pure Python submodules should be
|
# Special cases, symbols defined in possible pure Python submodules
|
||||||
|
# should be
|
||||||
root.reexported_submodules_symbols["mat_wrapper"].append("Mat")
|
root.reexported_submodules_symbols["mat_wrapper"].append("Mat")
|
||||||
|
|
||||||
|
|
||||||
def _write_reexported_symbols_section(module: NamespaceNode, output_stream: StringIO) -> None:
|
def _write_reexported_symbols_section(module: NamespaceNode,
|
||||||
|
output_stream: StringIO) -> None:
|
||||||
"""Write re-export section for the given module.
|
"""Write re-export section for the given module.
|
||||||
|
|
||||||
Re-export statements have from `from module_name import smth as smth`.
|
Re-export statements have from `from module_name import smth as smth`.
|
||||||
|
@ -22,6 +22,10 @@ _PREDEFINED_TYPES = (
|
|||||||
PrimitiveTypeNode.int_("uchar"),
|
PrimitiveTypeNode.int_("uchar"),
|
||||||
PrimitiveTypeNode.int_("unsigned"),
|
PrimitiveTypeNode.int_("unsigned"),
|
||||||
PrimitiveTypeNode.int_("int64"),
|
PrimitiveTypeNode.int_("int64"),
|
||||||
|
PrimitiveTypeNode.int_("uint8_t"),
|
||||||
|
PrimitiveTypeNode.int_("int8_t"),
|
||||||
|
PrimitiveTypeNode.int_("int32_t"),
|
||||||
|
PrimitiveTypeNode.int_("uint32_t"),
|
||||||
PrimitiveTypeNode.int_("size_t"),
|
PrimitiveTypeNode.int_("size_t"),
|
||||||
PrimitiveTypeNode.float_("float"),
|
PrimitiveTypeNode.float_("float"),
|
||||||
PrimitiveTypeNode.float_("double"),
|
PrimitiveTypeNode.float_("double"),
|
||||||
|
Loading…
Reference in New Issue
Block a user