Merge pull request #24022 from VadimLevin:dev/vlevin/python-typing-cuda

Fix python typing stubs generation for CUDA modules #24022

resolves #23946
resolves #23945
resolves opencv/opencv-python#871

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
pull/24025/head
Vadim Levin 2 years ago committed by GitHub
parent b0cbd6ef77
commit 9519e67ad2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 68
      modules/python/src2/typing_stubs_generation/api_refinement.py
  2. 29
      modules/python/src2/typing_stubs_generation/ast_utils.py
  3. 53
      modules/python/src2/typing_stubs_generation/generation.py
  4. 4
      modules/python/src2/typing_stubs_generation/predefined_types.py

@ -2,14 +2,17 @@ __all__ = [
"apply_manual_api_refinement"
]
from typing import Sequence, Callable
from typing import cast, Sequence, Callable, Iterable
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode,
ClassProperty, PrimitiveTypeNode)
from .ast_utils import find_function_node, SymbolName
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode,
ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode,
AggregatedTypeNode)
from .ast_utils import (find_function_node, SymbolName,
for_each_function_overload)
def apply_manual_api_refinement(root: NamespaceNode) -> None:
refine_cuda_module(root)
export_matrix_type_constants(root)
# Export OpenCV exception class
builtin_exception = root.add_class("Exception")
@ -57,13 +60,65 @@ def make_optional_arg(arg_name: str) -> Callable[[NamespaceNode, SymbolName], No
continue
overload.arguments[arg_idx].type_node = OptionalTypeNode(
overload.arguments[arg_idx].type_node
cast(TypeNode, overload.arguments[arg_idx].type_node)
)
return _make_optional_arg
def _find_argument_index(arguments: Sequence[FunctionNode.Arg], name: str) -> int:
def refine_cuda_module(root: NamespaceNode) -> None:
def fix_cudaoptflow_enums_names() -> None:
for class_name in ("NvidiaOpticalFlow_1_0", "NvidiaOpticalFlow_2_0"):
if class_name not in cuda_root.classes:
continue
opt_flow_class = cuda_root.classes[class_name]
_trim_class_name_from_argument_types(
for_each_function_overload(opt_flow_class), class_name
)
def fix_namespace_usage_scope(cuda_ns: NamespaceNode) -> None:
USED_TYPES = ("GpuMat", "Stream")
def fix_type_usage(type_node: TypeNode) -> None:
if isinstance(type_node, AggregatedTypeNode):
for item in type_node.items:
fix_type_usage(item)
if isinstance(type_node, ASTNodeTypeNode):
if type_node._typename in USED_TYPES:
type_node._typename = f"cuda_{type_node._typename}"
for overload in for_each_function_overload(cuda_ns):
if overload.return_type is not None:
fix_type_usage(overload.return_type.type_node)
for type_node in [arg.type_node for arg in overload.arguments
if arg.type_node is not None]:
fix_type_usage(type_node)
if "cuda" not in root.namespaces:
return
cuda_root = root.namespaces["cuda"]
fix_cudaoptflow_enums_names()
for ns in [ns for ns_name, ns in root.namespaces.items()
if ns_name.startswith("cuda")]:
fix_namespace_usage_scope(ns)
def _trim_class_name_from_argument_types(
overloads: Iterable[FunctionNode.Overload],
class_name: str
) -> None:
separator = f"{class_name}_"
for overload in overloads:
for arg in [arg for arg in overload.arguments
if arg.type_node is not None]:
ast_node = cast(ASTNodeTypeNode, arg.type_node)
if class_name in ast_node.ctype_name:
fixed_name = ast_node._typename.split(separator)[-1]
ast_node._typename = fixed_name
def _find_argument_index(arguments: Sequence[FunctionNode.Arg],
name: str) -> int:
for i, arg in enumerate(arguments):
if arg.name == name:
return i
@ -76,6 +131,7 @@ NODES_TO_REFINE = {
SymbolName(("cv", ), (), "resize"): make_optional_arg("dsize"),
SymbolName(("cv", ), (), "calcHist"): make_optional_arg("mask"),
}
ERROR_CLASS_PROPERTIES = (
ClassProperty("code", PrimitiveTypeNode.int_(), False),
ClassProperty("err", PrimitiveTypeNode.str_(), False),

@ -1,5 +1,5 @@
from typing import (NamedTuple, Sequence, Tuple, Union, List,
Dict, Callable, Optional)
Dict, Callable, Optional, Generator)
import keyword
from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
@ -404,6 +404,33 @@ def get_enum_module_and_export_name(enum_node: EnumerationNode) -> Tuple[str, st
return enum_export_name, namespace_node.full_export_name
def for_each_class(
node: Union[NamespaceNode, ClassNode]
) -> Generator[ClassNode, None, None]:
for cls in node.classes.values():
yield cls
if len(cls.classes):
yield from for_each_class(cls)
def for_each_function(
node: Union[NamespaceNode, ClassNode],
traverse_class_nodes: bool = True
) -> Generator[FunctionNode, None, None]:
yield from node.functions.values()
if traverse_class_nodes:
for cls in for_each_class(node):
yield from for_each_function(cls)
def for_each_function_overload(
node: Union[NamespaceNode, ClassNode],
traverse_class_nodes: bool = True
) -> Generator[FunctionNode.Overload, None, None]:
for func in for_each_function(node, traverse_class_nodes):
yield from func.overloads
if __name__ == '__main__':
import doctest
doctest.testmod()

@ -3,17 +3,20 @@ __all__ = ("generate_typing_stubs", )
from io import StringIO
from pathlib import Path
import re
from typing import (Generator, Type, Callable, NamedTuple, Union, Set, Dict,
from typing import (Type, Callable, NamedTuple, Union, Set, Dict,
Collection, Tuple, List)
import warnings
from .ast_utils import get_enclosing_namespace, get_enum_module_and_export_name
from .ast_utils import (get_enclosing_namespace,
get_enum_module_and_export_name,
for_each_function_overload,
for_each_class)
from .predefined_types import PREDEFINED_TYPES
from .api_refinement import apply_manual_api_refinement
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, FunctionNode,
EnumerationNode, ConstantNode)
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode,
FunctionNode, EnumerationNode, ConstantNode)
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
AggregatedTypeNode, ASTNodeTypeNode,
@ -105,8 +108,9 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:
# NOTE: Enumerations require special handling, because all enumeration
# constants are exposed as module attributes
has_enums = _generate_section_stub(StubSection("# Enumerations", EnumerationNode),
root, output_stream, 0)
has_enums = _generate_section_stub(
StubSection("# Enumerations", EnumerationNode), root, output_stream, 0
)
# Collect all enums from class level and export them to module level
for class_node in root.classes.values():
if _generate_enums_from_classes_tree(class_node, output_stream, indent=0):
@ -536,30 +540,6 @@ def check_overload_presence(node: Union[NamespaceNode, ClassNode]) -> bool:
return True
return False
def _for_each_class(node: Union[NamespaceNode, ClassNode]) \
-> Generator[ClassNode, None, None]:
for cls in node.classes.values():
yield cls
if len(cls.classes):
yield from _for_each_class(cls)
def _for_each_function(node: Union[NamespaceNode, ClassNode]) \
-> Generator[FunctionNode, None, None]:
for func in node.functions.values():
yield func
for cls in node.classes.values():
yield from _for_each_function(cls)
def _for_each_function_overload(node: Union[NamespaceNode, ClassNode]) \
-> Generator[FunctionNode.Overload, None, None]:
for func in _for_each_function(node):
for overload in func.overloads:
yield overload
def _collect_required_imports(root: NamespaceNode) -> Set[str]:
"""Collects all imports required for classes and functions typing stubs
declarations.
@ -582,7 +562,7 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
has_overload = check_overload_presence(root)
# if there is no module-level functions with overload, check its presence
# during class traversing, including their inner-classes
for cls in _for_each_class(root):
for cls in for_each_class(root):
if not has_overload and check_overload_presence(cls):
has_overload = True
required_imports.add("import typing")
@ -600,8 +580,9 @@ def _collect_required_imports(root: NamespaceNode) -> Set[str]:
if has_overload:
required_imports.add("import typing")
# Importing modules required to resolve functions arguments
for overload in _for_each_function_overload(root):
for arg in filter(lambda a: a.type_node is not None, overload.arguments):
for overload in for_each_function_overload(root):
for arg in filter(lambda a: a.type_node is not None,
overload.arguments):
_add_required_usage_imports(arg.type_node, required_imports) # type: ignore
if overload.return_type is not None:
_add_required_usage_imports(overload.return_type.type_node,
@ -625,11 +606,13 @@ def _populate_reexported_symbols(root: NamespaceNode) -> None:
_reexport_submodule(root)
# Special cases, symbols defined in possible pure Python submodules should be
# Special cases, symbols defined in possible pure Python submodules
# should be
root.reexported_submodules_symbols["mat_wrapper"].append("Mat")
def _write_reexported_symbols_section(module: NamespaceNode, output_stream: StringIO) -> None:
def _write_reexported_symbols_section(module: NamespaceNode,
output_stream: StringIO) -> None:
"""Write re-export section for the given module.
Re-export statements have from `from module_name import smth as smth`.

@ -22,6 +22,10 @@ _PREDEFINED_TYPES = (
PrimitiveTypeNode.int_("uchar"),
PrimitiveTypeNode.int_("unsigned"),
PrimitiveTypeNode.int_("int64"),
PrimitiveTypeNode.int_("uint8_t"),
PrimitiveTypeNode.int_("int8_t"),
PrimitiveTypeNode.int_("int32_t"),
PrimitiveTypeNode.int_("uint32_t"),
PrimitiveTypeNode.int_("size_t"),
PrimitiveTypeNode.float_("float"),
PrimitiveTypeNode.float_("double"),

Loading…
Cancel
Save