From 9519e67ad21e33fb7f7ee4b5f2e04fabb89def82 Mon Sep 17 00:00:00 2001 From: Vadim Levin Date: Wed, 19 Jul 2023 16:51:41 +0300 Subject: [PATCH] Merge pull request #24022 from VadimLevin:dev/vlevin/python-typing-cuda Fix python typing stubs generation for CUDA modules #24022 resolves #23946 resolves #23945 resolves opencv/opencv-python#871 ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake --- .../typing_stubs_generation/api_refinement.py | 68 +++++++++++++++++-- .../src2/typing_stubs_generation/ast_utils.py | 29 +++++++- .../typing_stubs_generation/generation.py | 53 +++++---------- .../predefined_types.py | 4 ++ 4 files changed, 112 insertions(+), 42 deletions(-) diff --git a/modules/python/src2/typing_stubs_generation/api_refinement.py b/modules/python/src2/typing_stubs_generation/api_refinement.py index f7d3a9dd08..a879d0ac4b 100644 --- a/modules/python/src2/typing_stubs_generation/api_refinement.py +++ b/modules/python/src2/typing_stubs_generation/api_refinement.py @@ -2,14 +2,17 @@ __all__ = [ "apply_manual_api_refinement" ] -from typing import Sequence, Callable +from typing import cast, Sequence, Callable, Iterable -from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, - ClassProperty, PrimitiveTypeNode) -from .ast_utils import find_function_node, SymbolName +from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode, + ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode, + AggregatedTypeNode) +from .ast_utils import (find_function_node, SymbolName, + for_each_function_overload) def apply_manual_api_refinement(root: NamespaceNode) -> None: + refine_cuda_module(root) export_matrix_type_constants(root) # Export OpenCV exception class builtin_exception = root.add_class("Exception") @@ -57,13 +60,65 @@ def make_optional_arg(arg_name: str) -> Callable[[NamespaceNode, SymbolName], No continue overload.arguments[arg_idx].type_node = OptionalTypeNode( - overload.arguments[arg_idx].type_node + cast(TypeNode, overload.arguments[arg_idx].type_node) ) return _make_optional_arg -def _find_argument_index(arguments: Sequence[FunctionNode.Arg], name: str) -> int: +def refine_cuda_module(root: NamespaceNode) -> None: + def fix_cudaoptflow_enums_names() -> None: + for class_name in ("NvidiaOpticalFlow_1_0", "NvidiaOpticalFlow_2_0"): + if class_name not in cuda_root.classes: + continue + opt_flow_class = cuda_root.classes[class_name] + _trim_class_name_from_argument_types( + for_each_function_overload(opt_flow_class), class_name + ) + + def fix_namespace_usage_scope(cuda_ns: NamespaceNode) -> None: + USED_TYPES = ("GpuMat", "Stream") + + def fix_type_usage(type_node: TypeNode) -> None: + if isinstance(type_node, AggregatedTypeNode): + for item in type_node.items: + fix_type_usage(item) + if isinstance(type_node, ASTNodeTypeNode): + if type_node._typename in USED_TYPES: + type_node._typename = f"cuda_{type_node._typename}" + + for overload in for_each_function_overload(cuda_ns): + if overload.return_type is not None: + fix_type_usage(overload.return_type.type_node) + for type_node in [arg.type_node for arg in overload.arguments + if arg.type_node is not None]: + fix_type_usage(type_node) + + if "cuda" not in root.namespaces: + return + cuda_root = root.namespaces["cuda"] + fix_cudaoptflow_enums_names() + for ns in [ns for ns_name, ns in root.namespaces.items() + if ns_name.startswith("cuda")]: + fix_namespace_usage_scope(ns) + + +def _trim_class_name_from_argument_types( + overloads: Iterable[FunctionNode.Overload], + class_name: str +) -> None: + separator = f"{class_name}_" + for overload in overloads: + for arg in [arg for arg in overload.arguments + if arg.type_node is not None]: + ast_node = cast(ASTNodeTypeNode, arg.type_node) + if class_name in ast_node.ctype_name: + fixed_name = ast_node._typename.split(separator)[-1] + ast_node._typename = fixed_name + + +def _find_argument_index(arguments: Sequence[FunctionNode.Arg], + name: str) -> int: for i, arg in enumerate(arguments): if arg.name == name: return i @@ -76,6 +131,7 @@ NODES_TO_REFINE = { SymbolName(("cv", ), (), "resize"): make_optional_arg("dsize"), SymbolName(("cv", ), (), "calcHist"): make_optional_arg("mask"), } + ERROR_CLASS_PROPERTIES = ( ClassProperty("code", PrimitiveTypeNode.int_(), False), ClassProperty("err", PrimitiveTypeNode.str_(), False), diff --git a/modules/python/src2/typing_stubs_generation/ast_utils.py b/modules/python/src2/typing_stubs_generation/ast_utils.py index 4cdf807260..47c06571a5 100644 --- a/modules/python/src2/typing_stubs_generation/ast_utils.py +++ b/modules/python/src2/typing_stubs_generation/ast_utils.py @@ -1,5 +1,5 @@ from typing import (NamedTuple, Sequence, Tuple, Union, List, - Dict, Callable, Optional) + Dict, Callable, Optional, Generator) import keyword from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode, @@ -404,6 +404,33 @@ def get_enum_module_and_export_name(enum_node: EnumerationNode) -> Tuple[str, st return enum_export_name, namespace_node.full_export_name +def for_each_class( + node: Union[NamespaceNode, ClassNode] +) -> Generator[ClassNode, None, None]: + for cls in node.classes.values(): + yield cls + if len(cls.classes): + yield from for_each_class(cls) + + +def for_each_function( + node: Union[NamespaceNode, ClassNode], + traverse_class_nodes: bool = True +) -> Generator[FunctionNode, None, None]: + yield from node.functions.values() + if traverse_class_nodes: + for cls in for_each_class(node): + yield from for_each_function(cls) + + +def for_each_function_overload( + node: Union[NamespaceNode, ClassNode], + traverse_class_nodes: bool = True +) -> Generator[FunctionNode.Overload, None, None]: + for func in for_each_function(node, traverse_class_nodes): + yield from func.overloads + + if __name__ == '__main__': import doctest doctest.testmod() diff --git a/modules/python/src2/typing_stubs_generation/generation.py b/modules/python/src2/typing_stubs_generation/generation.py index 2d6d4d338e..b4e8cb22c6 100644 --- a/modules/python/src2/typing_stubs_generation/generation.py +++ b/modules/python/src2/typing_stubs_generation/generation.py @@ -3,17 +3,20 @@ __all__ = ("generate_typing_stubs", ) from io import StringIO from pathlib import Path import re -from typing import (Generator, Type, Callable, NamedTuple, Union, Set, Dict, +from typing import (Type, Callable, NamedTuple, Union, Set, Dict, Collection, Tuple, List) import warnings -from .ast_utils import get_enclosing_namespace, get_enum_module_and_export_name +from .ast_utils import (get_enclosing_namespace, + get_enum_module_and_export_name, + for_each_function_overload, + for_each_class) from .predefined_types import PREDEFINED_TYPES from .api_refinement import apply_manual_api_refinement -from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, FunctionNode, - EnumerationNode, ConstantNode) +from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, + FunctionNode, EnumerationNode, ConstantNode) from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode, AggregatedTypeNode, ASTNodeTypeNode, @@ -105,8 +108,9 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None: # NOTE: Enumerations require special handling, because all enumeration # constants are exposed as module attributes - has_enums = _generate_section_stub(StubSection("# Enumerations", EnumerationNode), - root, output_stream, 0) + has_enums = _generate_section_stub( + StubSection("# Enumerations", EnumerationNode), root, output_stream, 0 + ) # Collect all enums from class level and export them to module level for class_node in root.classes.values(): if _generate_enums_from_classes_tree(class_node, output_stream, indent=0): @@ -536,30 +540,6 @@ def check_overload_presence(node: Union[NamespaceNode, ClassNode]) -> bool: return True return False - -def _for_each_class(node: Union[NamespaceNode, ClassNode]) \ - -> Generator[ClassNode, None, None]: - for cls in node.classes.values(): - yield cls - if len(cls.classes): - yield from _for_each_class(cls) - - -def _for_each_function(node: Union[NamespaceNode, ClassNode]) \ - -> Generator[FunctionNode, None, None]: - for func in node.functions.values(): - yield func - for cls in node.classes.values(): - yield from _for_each_function(cls) - - -def _for_each_function_overload(node: Union[NamespaceNode, ClassNode]) \ - -> Generator[FunctionNode.Overload, None, None]: - for func in _for_each_function(node): - for overload in func.overloads: - yield overload - - def _collect_required_imports(root: NamespaceNode) -> Set[str]: """Collects all imports required for classes and functions typing stubs declarations. @@ -582,7 +562,7 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]: has_overload = check_overload_presence(root) # if there is no module-level functions with overload, check its presence # during class traversing, including their inner-classes - for cls in _for_each_class(root): + for cls in for_each_class(root): if not has_overload and check_overload_presence(cls): has_overload = True required_imports.add("import typing") @@ -600,8 +580,9 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]: if has_overload: required_imports.add("import typing") # Importing modules required to resolve functions arguments - for overload in _for_each_function_overload(root): - for arg in filter(lambda a: a.type_node is not None, overload.arguments): + for overload in for_each_function_overload(root): + for arg in filter(lambda a: a.type_node is not None, + overload.arguments): _add_required_usage_imports(arg.type_node, required_imports) # type: ignore if overload.return_type is not None: _add_required_usage_imports(overload.return_type.type_node, @@ -625,11 +606,13 @@ def _populate_reexported_symbols(root: NamespaceNode) -> None: _reexport_submodule(root) - # Special cases, symbols defined in possible pure Python submodules should be + # Special cases, symbols defined in possible pure Python submodules + # should be root.reexported_submodules_symbols["mat_wrapper"].append("Mat") -def _write_reexported_symbols_section(module: NamespaceNode, output_stream: StringIO) -> None: +def _write_reexported_symbols_section(module: NamespaceNode, + output_stream: StringIO) -> None: """Write re-export section for the given module. Re-export statements have from `from module_name import smth as smth`. diff --git a/modules/python/src2/typing_stubs_generation/predefined_types.py b/modules/python/src2/typing_stubs_generation/predefined_types.py index fe9a37a45e..ce4f901e79 100644 --- a/modules/python/src2/typing_stubs_generation/predefined_types.py +++ b/modules/python/src2/typing_stubs_generation/predefined_types.py @@ -22,6 +22,10 @@ _PREDEFINED_TYPES = ( PrimitiveTypeNode.int_("uchar"), PrimitiveTypeNode.int_("unsigned"), PrimitiveTypeNode.int_("int64"), + PrimitiveTypeNode.int_("uint8_t"), + PrimitiveTypeNode.int_("int8_t"), + PrimitiveTypeNode.int_("int32_t"), + PrimitiveTypeNode.int_("uint32_t"), PrimitiveTypeNode.int_("size_t"), PrimitiveTypeNode.float_("float"), PrimitiveTypeNode.float_("double"),