Merge pull request #23801 from VadimLevin:dev/vlevin/python-stubs-api-refinement

feat: manual refinement for Python API definition
This commit is contained in:
Alexander Smorkalov 2023-06-21 10:44:36 +03:00 committed by GitHub
commit fc810434de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 108 additions and 14 deletions

View File

@ -0,0 +1,48 @@
__all__ = [
"apply_manual_api_refinement"
]
from typing import Sequence, Callable
from .nodes import NamespaceNode, FunctionNode, OptionalTypeNode
from .ast_utils import find_function_node, SymbolName
def apply_manual_api_refinement(root: NamespaceNode) -> None:
# Export OpenCV exception class
builtin_exception = root.add_class("Exception")
builtin_exception.is_exported = False
root.add_class("error", (builtin_exception, ))
for symbol_name, refine_symbol in NODES_TO_REFINE.items():
refine_symbol(root, symbol_name)
def make_optional_arg(arg_name: str) -> Callable[[NamespaceNode, SymbolName], None]:
def _make_optional_arg(root_node: NamespaceNode,
function_symbol_name: SymbolName) -> None:
function = find_function_node(root_node, function_symbol_name)
for overload in function.overloads:
arg_idx = _find_argument_index(overload.arguments, arg_name)
# Avoid multiplying optional qualification
if isinstance(overload.arguments[arg_idx].type_node, OptionalTypeNode):
continue
overload.arguments[arg_idx].type_node = OptionalTypeNode(
overload.arguments[arg_idx].type_node
)
return _make_optional_arg
def _find_argument_index(arguments: Sequence[FunctionNode.Arg], name: str) -> int:
for i, arg in enumerate(arguments):
if arg.name == name:
return i
raise RuntimeError(
f"Failed to find argument with name: '{name}' in {arguments}"
)
NODES_TO_REFINE = {
SymbolName(("cv", ), (), "resize"): make_optional_arg("dsize"),
SymbolName(("cv", ), (), "calcHist"): make_optional_arg("mask"),
}

View File

@ -143,13 +143,24 @@ because 'GOpaque' class is not registered yet
return scope
def find_class_node(root: NamespaceNode, full_class_name: str,
namespaces: Sequence[str]) -> ClassNode:
symbol_name = SymbolName.parse(full_class_name, namespaces)
scope = find_scope(root, symbol_name)
if symbol_name.name not in scope.classes:
raise SymbolNotFoundError("Can't find {} in its scope".format(symbol_name))
return scope.classes[symbol_name.name]
def find_class_node(root: NamespaceNode, class_symbol: SymbolName,
create_missing_namespaces: bool = False) -> ClassNode:
scope = find_scope(root, class_symbol, create_missing_namespaces)
if class_symbol.name not in scope.classes:
raise SymbolNotFoundError(
"Can't find {} in its scope".format(class_symbol)
)
return scope.classes[class_symbol.name]
def find_function_node(root: NamespaceNode, function_symbol: SymbolName,
create_missing_namespaces: bool = False) -> FunctionNode:
scope = find_scope(root, function_symbol, create_missing_namespaces)
if function_symbol.name not in scope.functions:
raise SymbolNotFoundError(
"Can't find {} in its scope".format(function_symbol)
)
return scope.functions[function_symbol.name]
def create_function_node_in_scope(scope: Union[NamespaceNode, ClassNode],

View File

@ -10,6 +10,7 @@ import warnings
from .ast_utils import get_enclosing_namespace, get_enum_module_and_export_name
from .predefined_types import PREDEFINED_TYPES
from .api_refinement import apply_manual_api_refinement
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, FunctionNode,
EnumerationNode, ConstantNode)
@ -47,6 +48,18 @@ def generate_typing_stubs(root: NamespaceNode, output_path: Path):
root (NamespaceNode): Root namespace node of the library AST.
output_path (Path): Path to output directory.
"""
# Perform special handling for function arguments that has some conventions
# not expressed in their API e.g. optionality of mutually exclusive arguments
# without default values:
# ```cxx
# cv::resize(cv::InputArray src, cv::OutputArray dst, cv::Size dsize,
# double fx = 0.0, double fy = 0.0, int interpolation);
# ```
# should accept `None` as `dsize`:
# ```python
# cv2.resize(image, dsize=None, fx=0.5, fy=0.5)
# ```
apply_manual_api_refinement(root)
# Most of the time type nodes miss their full name (especially function
# arguments and return types), so resolution should start from the narrowest
# scope and gradually expanded.

View File

@ -10,10 +10,12 @@ class FunctionNode(ASTNode):
This class defines an overload set rather then function itself, because
function without overloads is represented as FunctionNode with 1 overload.
"""
class Arg(NamedTuple):
name: str
type_node: Optional[TypeNode] = None
default_value: Optional[str] = None
class Arg:
def __init__(self, name: str, type_node: Optional[TypeNode] = None,
default_value: Optional[str] = None) -> None:
self.name = name
self.type_node = type_node
self.default_value = default_value
@property
def typename(self) -> Optional[str]:
@ -24,8 +26,18 @@ class FunctionNode(ASTNode):
return self.type_node.relative_typename(root)
return None
class RetType(NamedTuple):
type_node: TypeNode = NoneTypeNode("void")
def __str__(self) -> str:
return (
f"Arg(name={self.name}, type_node={self.type_node},"
f" default_value={self.default_value})"
)
def __repr__(self) -> str:
return str(self)
class RetType:
def __init__(self, type_node: TypeNode = NoneTypeNode("void")) -> None:
self.type_node = type_node
@property
def typename(self) -> str:
@ -34,6 +46,12 @@ class FunctionNode(ASTNode):
def relative_typename(self, root: str) -> Optional[str]:
return self.type_node.relative_typename(root)
def __str__(self) -> str:
return f"RetType(type_node={self.type_node})"
def __repr__(self) -> str:
return str(self)
class Overload(NamedTuple):
arguments: Sequence["FunctionNode.Arg"] = ()
return_type: Optional["FunctionNode.RetType"] = None

View File

@ -123,7 +123,11 @@ if sys.version_info >= (3, 6):
@failures_wrapper.wrap_exceptions_as_warnings(ret_type_on_failure=ClassNodeStub)
def find_class_node(self, class_info, namespaces):
# type: (Any, Sequence[str]) -> ClassNode
return find_class_node(self.cv_root, class_info.full_original_name, namespaces)
return find_class_node(
self.cv_root,
SymbolName.parse(class_info.full_original_name, namespaces),
create_missing_namespaces=True
)
@failures_wrapper.wrap_exceptions_as_warnings(ret_type_on_failure=ClassNodeStub)
def create_class_node(self, class_info, namespaces):