Merge pull request #23813 from VadimLevin:dev/vlevin/runtime-typing-module

fix: typing module enums references
pull/23821/head
Alexander Smorkalov 2 years ago committed by GitHub
commit 003d048b0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 33
      modules/python/src2/typing_stubs_generation/ast_utils.py
  2. 61
      modules/python/src2/typing_stubs_generation/generation.py
  3. 2
      modules/python/src2/typing_stubs_generation/nodes/__init__.py
  4. 2
      modules/python/src2/typing_stubs_generation/nodes/node.py
  5. 4
      modules/python/src2/typing_stubs_generation/nodes/type_node.py
  6. 7
      modules/python/src2/typing_stubs_generator.py

@ -1,4 +1,5 @@
from typing import NamedTuple, Sequence, Tuple, Union, List, Dict from typing import (NamedTuple, Sequence, Tuple, Union, List,
Dict, Callable, Optional)
import keyword import keyword
from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode, from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
@ -323,12 +324,18 @@ def resolve_enum_scopes(root: NamespaceNode,
enum_node.parent = scope enum_node.parent = scope
def get_enclosing_namespace(node: ASTNode) -> NamespaceNode: def get_enclosing_namespace(
node: ASTNode,
class_node_callback: Optional[Callable[[ClassNode], None]] = None
) -> NamespaceNode:
"""Traverses up nodes hierarchy to find closest enclosing namespace of the """Traverses up nodes hierarchy to find closest enclosing namespace of the
passed node passed node
Args: Args:
node (ASTNode): Node to find a namespace for. node (ASTNode): Node to find a namespace for.
class_node_callback (Optional[Callable[[ClassNode], None]]): Optional
callable object invoked for each traversed class node in bottom-up
order. Defaults: None.
Returns: Returns:
NamespaceNode: Closest enclosing namespace of the provided node. NamespaceNode: Closest enclosing namespace of the provided node.
@ -360,10 +367,32 @@ def get_enclosing_namespace(node: ASTNode) -> NamespaceNode:
"Can't find enclosing namespace for '{}' known as: '{}'".format( "Can't find enclosing namespace for '{}' known as: '{}'".format(
node.full_export_name, node.native_name node.full_export_name, node.native_name
) )
if class_node_callback:
class_node_callback(parent_node)
parent_node = parent_node.parent parent_node = parent_node.parent
return parent_node return parent_node
def get_enum_module_and_export_name(enum_node: EnumerationNode) -> Tuple[str, str]:
"""Get export name of the enum node with its module name.
Note: Enumeration export names are prefixed with enclosing class names.
Args:
enum_node (EnumerationNode): Enumeration node to construct name for.
Returns:
Tuple[str, str]: a pair of enum export name and its full module name.
"""
def update_full_export_name(class_node: ClassNode) -> None:
nonlocal enum_export_name
enum_export_name = class_node.export_name + "_" + enum_export_name
enum_export_name = enum_node.export_name
namespace_node = get_enclosing_namespace(enum_node, update_full_export_name)
return enum_export_name, namespace_node.full_export_name
if __name__ == '__main__': if __name__ == '__main__':
import doctest import doctest
doctest.testmod() doctest.testmod()

@ -6,15 +6,15 @@ from typing import (Generator, Type, Callable, NamedTuple, Union, Set, Dict,
Collection) Collection)
import warnings import warnings
from .ast_utils import get_enclosing_namespace from .ast_utils import get_enclosing_namespace, get_enum_module_and_export_name
from .predefined_types import PREDEFINED_TYPES from .predefined_types import PREDEFINED_TYPES
from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode, from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, FunctionNode,
EnumerationNode, ConstantNode) EnumerationNode, ConstantNode)
from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode, from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
AggregatedTypeNode) AggregatedTypeNode, ASTNodeTypeNode)
def generate_typing_stubs(root: NamespaceNode, output_path: Path): def generate_typing_stubs(root: NamespaceNode, output_path: Path):
@ -616,29 +616,67 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
for item in filter(lambda i: isinstance(i, AliasRefTypeNode), type_node): for item in filter(lambda i: isinstance(i, AliasRefTypeNode), type_node):
register_alias(PREDEFINED_TYPES[item.ctype_name]) # type: ignore register_alias(PREDEFINED_TYPES[item.ctype_name]) # type: ignore
def create_alias_for_enum_node(enum_node: ASTNode) -> AliasTypeNode:
"""Create int alias corresponding to the given enum node.
Args:
enum_node (ASTNodeTypeNode): Enumeration node to create int alias for.
Returns:
AliasTypeNode: int alias node with same export name as enum.
"""
assert enum_node.node_type == ASTNodeType.Enumeration, \
f"{enum_node} has wrong node type. Expected type: Enumeration."
enum_export_name, enum_module_name = get_enum_module_and_export_name(
enum_node
)
enum_full_export_name = f"{enum_module_name}.{enum_export_name}"
alias_node = AliasTypeNode.int_(enum_full_export_name,
enum_export_name)
type_checking_time_definitions.add(alias_node)
return alias_node
def register_alias(alias_node: AliasTypeNode) -> None: def register_alias(alias_node: AliasTypeNode) -> None:
typename = alias_node.typename typename = alias_node.typename
# Check if alias is already registered # Check if alias is already registered
if typename in aliases: if typename in aliases:
return return
# Collect required imports for alias definition
for required_import in alias_node.required_definition_imports:
required_imports.add(required_import)
if isinstance(alias_node.value, AggregatedTypeNode): if isinstance(alias_node.value, AggregatedTypeNode):
# Check if collection contains a link to another alias # Check if collection contains a link to another alias
register_alias_links_from_aggregated_type(alias_node.value) register_alias_links_from_aggregated_type(alias_node.value)
# Remove references to alias nodes
for i, item in enumerate(alias_node.value.items):
# Process enumerations only
if not isinstance(item, ASTNodeTypeNode) or item.ast_node is None:
continue
if item.ast_node.node_type != ASTNodeType.Enumeration:
continue
alias_node.value.items[i] = create_alias_for_enum_node(item.ast_node)
if isinstance(alias_node.value, ASTNodeTypeNode) \
and alias_node.value.ast_node == ASTNodeType.Enumeration:
alias_node.value = create_alias_for_enum_node(alias_node.ast_node)
# Strip module prefix from aliased types # Strip module prefix from aliased types
aliases[typename] = alias_node.value.full_typename.replace( aliases[typename] = alias_node.value.full_typename.replace(
root.export_name + ".typing.", "" root.export_name + ".typing.", ""
) )
if alias_node.doc is not None: if alias_node.doc is not None:
aliases[typename] += f'\n"""{alias_node.doc}"""' aliases[typename] += f'\n"""{alias_node.doc}"""'
for required_import in alias_node.required_definition_imports:
required_imports.add(required_import)
output_path = Path(output_path) / root.export_name / "typing" output_path = Path(output_path) / root.export_name / "typing"
output_path.mkdir(parents=True, exist_ok=True) output_path.mkdir(parents=True, exist_ok=True)
required_imports: Set[str] = set() required_imports: Set[str] = set()
aliases: Dict[str, str] = {} aliases: Dict[str, str] = {}
type_checking_time_definitions: Set[AliasTypeNode] = set()
# Resolve each node and register aliases # Resolve each node and register aliases
TypeNode.compatible_to_runtime_usage = True TypeNode.compatible_to_runtime_usage = True
@ -655,11 +693,16 @@ def _generate_typing_module(root: NamespaceNode, output_path: Path) -> None:
_write_required_imports(required_imports, output_stream) _write_required_imports(required_imports, output_stream)
# Add type checking time definitions as generated __init__.py content
for alias in type_checking_time_definitions:
output_stream.write("if typing.TYPE_CHECKING:\n ")
output_stream.write(f"{alias.typename} = {alias.ctype_name}\nelse:\n")
output_stream.write(f" {alias.typename} = {alias.value.ctype_name}\n")
if type_checking_time_definitions:
output_stream.write("\n\n")
for alias_name, alias_type in aliases.items(): for alias_name, alias_type in aliases.items():
output_stream.write(alias_name) output_stream.write(f"{alias_name} = {alias_type}\n")
output_stream.write(" = ")
output_stream.write(alias_type)
output_stream.write("\n")
TypeNode.compatible_to_runtime_usage = False TypeNode.compatible_to_runtime_usage = False
(output_path / "__init__.py").write_text(output_stream.getvalue()) (output_path / "__init__.py").write_text(output_stream.getvalue())

@ -1,4 +1,4 @@
from .node import ASTNode from .node import ASTNode, ASTNodeType
from .namespace_node import NamespaceNode from .namespace_node import NamespaceNode
from .class_node import ClassNode, ClassProperty from .class_node import ClassNode, ClassProperty
from .function_node import FunctionNode from .function_node import FunctionNode

@ -1,7 +1,7 @@
import abc import abc
import enum import enum
import itertools import itertools
from typing import (Iterator, Type, TypeVar, Iterable, Dict, from typing import (Iterator, Type, TypeVar, Dict,
Optional, Tuple, DefaultDict) Optional, Tuple, DefaultDict)
from collections import defaultdict from collections import defaultdict

@ -417,6 +417,10 @@ class ASTNodeTypeNode(TypeNode):
self._module_name = module_name self._module_name = module_name
self._ast_node: Optional[weakref.ProxyType[ASTNode]] = None self._ast_node: Optional[weakref.ProxyType[ASTNode]] = None
@property
def ast_node(self):
return self._ast_node
@property @property
def typename(self) -> str: def typename(self) -> str:
if self._ast_node is None: if self._ast_node is None:

@ -12,6 +12,7 @@ if sys.version_info >= (3, 6):
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, Set, Any, Sequence, Generator, Union from typing import Dict, Set, Any, Sequence, Generator, Union
import traceback
from pathlib import Path from pathlib import Path
@ -46,10 +47,12 @@ if sys.version_info >= (3, 6):
try: try:
ret_type = func(*args, **kwargs) ret_type = func(*args, **kwargs)
except Exception as e: except Exception:
self.has_failure = True self.has_failure = True
warnings.warn( warnings.warn(
'Typing stubs generation has failed. Reason: {}'.format(e) "Typing stubs generation has failed.\n{}".format(
traceback.format_exc()
)
) )
if ret_type_on_failure is None: if ret_type_on_failure is None:
return None return None

Loading…
Cancel
Save