mirror of
https://github.com/opencv/opencv.git
synced 2025-01-19 06:53:50 +08:00
Merge pull request #23801 from VadimLevin:dev/vlevin/python-stubs-api-refinement
feat: manual refinement for Python API definition
This commit is contained in:
commit
fc810434de
@ -0,0 +1,48 @@
|
||||
__all__ = [
|
||||
"apply_manual_api_refinement"
|
||||
]
|
||||
|
||||
from typing import Sequence, Callable
|
||||
from .nodes import NamespaceNode, FunctionNode, OptionalTypeNode
|
||||
from .ast_utils import find_function_node, SymbolName
|
||||
|
||||
|
||||
def apply_manual_api_refinement(root: NamespaceNode) -> None:
|
||||
# Export OpenCV exception class
|
||||
builtin_exception = root.add_class("Exception")
|
||||
builtin_exception.is_exported = False
|
||||
root.add_class("error", (builtin_exception, ))
|
||||
for symbol_name, refine_symbol in NODES_TO_REFINE.items():
|
||||
refine_symbol(root, symbol_name)
|
||||
|
||||
|
||||
def make_optional_arg(arg_name: str) -> Callable[[NamespaceNode, SymbolName], None]:
|
||||
def _make_optional_arg(root_node: NamespaceNode,
|
||||
function_symbol_name: SymbolName) -> None:
|
||||
function = find_function_node(root_node, function_symbol_name)
|
||||
for overload in function.overloads:
|
||||
arg_idx = _find_argument_index(overload.arguments, arg_name)
|
||||
# Avoid multiplying optional qualification
|
||||
if isinstance(overload.arguments[arg_idx].type_node, OptionalTypeNode):
|
||||
continue
|
||||
|
||||
overload.arguments[arg_idx].type_node = OptionalTypeNode(
|
||||
overload.arguments[arg_idx].type_node
|
||||
)
|
||||
|
||||
return _make_optional_arg
|
||||
|
||||
|
||||
def _find_argument_index(arguments: Sequence[FunctionNode.Arg], name: str) -> int:
|
||||
for i, arg in enumerate(arguments):
|
||||
if arg.name == name:
|
||||
return i
|
||||
raise RuntimeError(
|
||||
f"Failed to find argument with name: '{name}' in {arguments}"
|
||||
)
|
||||
|
||||
|
||||
NODES_TO_REFINE = {
|
||||
SymbolName(("cv", ), (), "resize"): make_optional_arg("dsize"),
|
||||
SymbolName(("cv", ), (), "calcHist"): make_optional_arg("mask"),
|
||||
}
|
@ -143,13 +143,24 @@ because 'GOpaque' class is not registered yet
|
||||
return scope
|
||||
|
||||
|
||||
def find_class_node(root: NamespaceNode, full_class_name: str,
|
||||
namespaces: Sequence[str]) -> ClassNode:
|
||||
symbol_name = SymbolName.parse(full_class_name, namespaces)
|
||||
scope = find_scope(root, symbol_name)
|
||||
if symbol_name.name not in scope.classes:
|
||||
raise SymbolNotFoundError("Can't find {} in its scope".format(symbol_name))
|
||||
return scope.classes[symbol_name.name]
|
||||
def find_class_node(root: NamespaceNode, class_symbol: SymbolName,
|
||||
create_missing_namespaces: bool = False) -> ClassNode:
|
||||
scope = find_scope(root, class_symbol, create_missing_namespaces)
|
||||
if class_symbol.name not in scope.classes:
|
||||
raise SymbolNotFoundError(
|
||||
"Can't find {} in its scope".format(class_symbol)
|
||||
)
|
||||
return scope.classes[class_symbol.name]
|
||||
|
||||
|
||||
def find_function_node(root: NamespaceNode, function_symbol: SymbolName,
|
||||
create_missing_namespaces: bool = False) -> FunctionNode:
|
||||
scope = find_scope(root, function_symbol, create_missing_namespaces)
|
||||
if function_symbol.name not in scope.functions:
|
||||
raise SymbolNotFoundError(
|
||||
"Can't find {} in its scope".format(function_symbol)
|
||||
)
|
||||
return scope.functions[function_symbol.name]
|
||||
|
||||
|
||||
def create_function_node_in_scope(scope: Union[NamespaceNode, ClassNode],
|
||||
|
@ -10,6 +10,7 @@ import warnings
|
||||
from .ast_utils import get_enclosing_namespace, get_enum_module_and_export_name
|
||||
|
||||
from .predefined_types import PREDEFINED_TYPES
|
||||
from .api_refinement import apply_manual_api_refinement
|
||||
|
||||
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, FunctionNode,
|
||||
EnumerationNode, ConstantNode)
|
||||
@ -47,6 +48,18 @@ def generate_typing_stubs(root: NamespaceNode, output_path: Path):
|
||||
root (NamespaceNode): Root namespace node of the library AST.
|
||||
output_path (Path): Path to output directory.
|
||||
"""
|
||||
# Perform special handling for function arguments that has some conventions
|
||||
# not expressed in their API e.g. optionality of mutually exclusive arguments
|
||||
# without default values:
|
||||
# ```cxx
|
||||
# cv::resize(cv::InputArray src, cv::OutputArray dst, cv::Size dsize,
|
||||
# double fx = 0.0, double fy = 0.0, int interpolation);
|
||||
# ```
|
||||
# should accept `None` as `dsize`:
|
||||
# ```python
|
||||
# cv2.resize(image, dsize=None, fx=0.5, fy=0.5)
|
||||
# ```
|
||||
apply_manual_api_refinement(root)
|
||||
# Most of the time type nodes miss their full name (especially function
|
||||
# arguments and return types), so resolution should start from the narrowest
|
||||
# scope and gradually expanded.
|
||||
|
@ -10,10 +10,12 @@ class FunctionNode(ASTNode):
|
||||
This class defines an overload set rather then function itself, because
|
||||
function without overloads is represented as FunctionNode with 1 overload.
|
||||
"""
|
||||
class Arg(NamedTuple):
|
||||
name: str
|
||||
type_node: Optional[TypeNode] = None
|
||||
default_value: Optional[str] = None
|
||||
class Arg:
|
||||
def __init__(self, name: str, type_node: Optional[TypeNode] = None,
|
||||
default_value: Optional[str] = None) -> None:
|
||||
self.name = name
|
||||
self.type_node = type_node
|
||||
self.default_value = default_value
|
||||
|
||||
@property
|
||||
def typename(self) -> Optional[str]:
|
||||
@ -24,8 +26,18 @@ class FunctionNode(ASTNode):
|
||||
return self.type_node.relative_typename(root)
|
||||
return None
|
||||
|
||||
class RetType(NamedTuple):
|
||||
type_node: TypeNode = NoneTypeNode("void")
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"Arg(name={self.name}, type_node={self.type_node},"
|
||||
f" default_value={self.default_value})"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self)
|
||||
|
||||
class RetType:
|
||||
def __init__(self, type_node: TypeNode = NoneTypeNode("void")) -> None:
|
||||
self.type_node = type_node
|
||||
|
||||
@property
|
||||
def typename(self) -> str:
|
||||
@ -34,6 +46,12 @@ class FunctionNode(ASTNode):
|
||||
def relative_typename(self, root: str) -> Optional[str]:
|
||||
return self.type_node.relative_typename(root)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"RetType(type_node={self.type_node})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self)
|
||||
|
||||
class Overload(NamedTuple):
|
||||
arguments: Sequence["FunctionNode.Arg"] = ()
|
||||
return_type: Optional["FunctionNode.RetType"] = None
|
||||
|
@ -123,7 +123,11 @@ if sys.version_info >= (3, 6):
|
||||
@failures_wrapper.wrap_exceptions_as_warnings(ret_type_on_failure=ClassNodeStub)
|
||||
def find_class_node(self, class_info, namespaces):
|
||||
# type: (Any, Sequence[str]) -> ClassNode
|
||||
return find_class_node(self.cv_root, class_info.full_original_name, namespaces)
|
||||
return find_class_node(
|
||||
self.cv_root,
|
||||
SymbolName.parse(class_info.full_original_name, namespaces),
|
||||
create_missing_namespaces=True
|
||||
)
|
||||
|
||||
@failures_wrapper.wrap_exceptions_as_warnings(ret_type_on_failure=ClassNodeStub)
|
||||
def create_class_node(self, class_info, namespaces):
|
||||
|
Loading…
Reference in New Issue
Block a user