[aio types] Fix some grpc.aio python types (#32475)

With these, it is actually possible to have typed client stubs where the
return type is correctly inferred.

It's only for the non-streaming calls, because there is
`RequestIterableType` for the streaming ones (but it's just Any with
extra steps and would require much more work).

---------

Co-authored-by: Xuan Wang <xuanwn@google.com>
pull/33123/head
Vojtěch Boček 2 years ago committed by GitHub
parent b8a6b4267d
commit 5bd38df3c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      src/python/grpcio/grpc/aio/_base_call.py
  2. 16
      src/python/grpcio/grpc/aio/_base_channel.py
  3. 7
      src/python/grpcio/grpc/aio/_call.py
  4. 10
      src/python/grpcio/grpc/aio/_channel.py

@ -20,7 +20,7 @@ RPC, e.g. cancellation.
from abc import ABCMeta from abc import ABCMeta
from abc import abstractmethod from abc import abstractmethod
from typing import AsyncIterator, Awaitable, Generic, Optional, Union from typing import Any, AsyncIterator, Generator, Generic, Optional, Union
import grpc import grpc
@ -141,7 +141,7 @@ class UnaryUnaryCall(Generic[RequestType, ResponseType],
"""The abstract base class of an unary-unary RPC on the client-side.""" """The abstract base class of an unary-unary RPC on the client-side."""
@abstractmethod @abstractmethod
def __await__(self) -> Awaitable[ResponseType]: def __await__(self) -> Generator[Any, None, ResponseType]:
"""Await the response message to be ready. """Await the response message to be ready.
Returns: Returns:
@ -197,7 +197,7 @@ class StreamUnaryCall(Generic[RequestType, ResponseType],
""" """
@abstractmethod @abstractmethod
def __await__(self) -> Awaitable[ResponseType]: def __await__(self) -> Generator[Any, None, ResponseType]:
"""Await the response message to be ready. """Await the response message to be ready.
Returns: Returns:

@ -14,7 +14,7 @@
"""Abstract base classes for Channel objects and Multicallable objects.""" """Abstract base classes for Channel objects and Multicallable objects."""
import abc import abc
from typing import Any, Optional from typing import Generic, Optional
import grpc import grpc
@ -22,23 +22,25 @@ from . import _base_call
from ._typing import DeserializingFunction from ._typing import DeserializingFunction
from ._typing import MetadataType from ._typing import MetadataType
from ._typing import RequestIterableType from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction from ._typing import SerializingFunction
class UnaryUnaryMultiCallable(abc.ABC): class UnaryUnaryMultiCallable(Generic[RequestType, ResponseType], abc.ABC):
"""Enables asynchronous invocation of a unary-call RPC.""" """Enables asynchronous invocation of a unary-call RPC."""
@abc.abstractmethod @abc.abstractmethod
def __call__( def __call__(
self, self,
request: Any, request: RequestType,
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None, metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall: ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
"""Asynchronously invokes the underlying RPC. """Asynchronously invokes the underlying RPC.
Args: Args:
@ -63,20 +65,20 @@ class UnaryUnaryMultiCallable(abc.ABC):
""" """
class UnaryStreamMultiCallable(abc.ABC): class UnaryStreamMultiCallable(Generic[RequestType, ResponseType], abc.ABC):
"""Enables asynchronous invocation of a server-streaming RPC.""" """Enables asynchronous invocation of a server-streaming RPC."""
@abc.abstractmethod @abc.abstractmethod
def __call__( def __call__(
self, self,
request: Any, request: RequestType,
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None, metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall: ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
"""Asynchronously invokes the underlying RPC. """Asynchronously invokes the underlying RPC.
Args: Args:

@ -19,7 +19,7 @@ from functools import partial
import inspect import inspect
import logging import logging
import traceback import traceback
from typing import AsyncIterator, Optional, Tuple from typing import Any, AsyncIterator, Generator, Generic, Optional, Tuple
import grpc import grpc
from grpc import _common from grpc import _common
@ -252,7 +252,7 @@ class _APIStyle(enum.IntEnum):
READER_WRITER = 2 READER_WRITER = 2
class _UnaryResponseMixin(Call): class _UnaryResponseMixin(Call, Generic[ResponseType]):
_call_response: asyncio.Task _call_response: asyncio.Task
def _init_unary_response_mixin(self, response_task: asyncio.Task): def _init_unary_response_mixin(self, response_task: asyncio.Task):
@ -265,7 +265,7 @@ class _UnaryResponseMixin(Call):
else: else:
return False return False
def __await__(self) -> ResponseType: def __await__(self) -> Generator[Any, None, ResponseType]:
"""Wait till the ongoing RPC request finishes.""" """Wait till the ongoing RPC request finishes."""
try: try:
response = yield from self._call_response response = yield from self._call_response
@ -573,6 +573,7 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
await self._raise_for_status() await self._raise_for_status()
# pylint: disable=too-many-ancestors
class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
_base_call.StreamUnaryCall): _base_call.StreamUnaryCall):
"""Object for managing stream-unary RPC calls. """Object for managing stream-unary RPC calls.

@ -42,6 +42,8 @@ from ._metadata import Metadata
from ._typing import ChannelArgumentType from ._typing import ChannelArgumentType
from ._typing import DeserializingFunction from ._typing import DeserializingFunction
from ._typing import RequestIterableType from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction from ._typing import SerializingFunction
from ._utils import _timeout_to_deadline from ._utils import _timeout_to_deadline
@ -121,14 +123,14 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable,
def __call__( def __call__(
self, self,
request: Any, request: RequestType,
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[Metadata] = None, metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall: ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
metadata = self._init_metadata(metadata, compression) metadata = self._init_metadata(metadata, compression)
if not self._interceptors: if not self._interceptors:
@ -152,14 +154,14 @@ class UnaryStreamMultiCallable(_BaseMultiCallable,
def __call__( def __call__(
self, self,
request: Any, request: RequestType,
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[Metadata] = None, metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall: ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
metadata = self._init_metadata(metadata, compression) metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)

Loading…
Cancel
Save