[aio types] Fix some grpc.aio python types (#32475)

With these, it is actually possible to have typed client stubs where the
return type is correctly inferred.

It's only for the non-streaming calls, because there is
`RequestIterableType` for the streaming ones (but it's just Any with
extra steps and would require much more work).

---------

Co-authored-by: Xuan Wang <xuanwn@google.com>
pull/33123/head
Vojtěch Boček 2 years ago committed by GitHub
parent b8a6b4267d
commit 5bd38df3c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      src/python/grpcio/grpc/aio/_base_call.py
  2. 16
      src/python/grpcio/grpc/aio/_base_channel.py
  3. 7
      src/python/grpcio/grpc/aio/_call.py
  4. 10
      src/python/grpcio/grpc/aio/_channel.py

@ -20,7 +20,7 @@ RPC, e.g. cancellation.
from abc import ABCMeta
from abc import abstractmethod
from typing import AsyncIterator, Awaitable, Generic, Optional, Union
from typing import Any, AsyncIterator, Generator, Generic, Optional, Union
import grpc
@ -141,7 +141,7 @@ class UnaryUnaryCall(Generic[RequestType, ResponseType],
"""The abstract base class of an unary-unary RPC on the client-side."""
@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
def __await__(self) -> Generator[Any, None, ResponseType]:
"""Await the response message to be ready.
Returns:
@ -197,7 +197,7 @@ class StreamUnaryCall(Generic[RequestType, ResponseType],
"""
@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
def __await__(self) -> Generator[Any, None, ResponseType]:
"""Await the response message to be ready.
Returns:

@ -14,7 +14,7 @@
"""Abstract base classes for Channel objects and Multicallable objects."""
import abc
from typing import Any, Optional
from typing import Generic, Optional
import grpc
@ -22,23 +22,25 @@ from . import _base_call
from ._typing import DeserializingFunction
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction
class UnaryUnaryMultiCallable(abc.ABC):
class UnaryUnaryMultiCallable(Generic[RequestType, ResponseType], abc.ABC):
"""Enables asynchronous invocation of a unary-call RPC."""
@abc.abstractmethod
def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
"""Asynchronously invokes the underlying RPC.
Args:
@ -63,20 +65,20 @@ class UnaryUnaryMultiCallable(abc.ABC):
"""
class UnaryStreamMultiCallable(abc.ABC):
class UnaryStreamMultiCallable(Generic[RequestType, ResponseType], abc.ABC):
"""Enables asynchronous invocation of a server-streaming RPC."""
@abc.abstractmethod
def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
"""Asynchronously invokes the underlying RPC.
Args:

@ -19,7 +19,7 @@ from functools import partial
import inspect
import logging
import traceback
from typing import AsyncIterator, Optional, Tuple
from typing import Any, AsyncIterator, Generator, Generic, Optional, Tuple
import grpc
from grpc import _common
@ -252,7 +252,7 @@ class _APIStyle(enum.IntEnum):
READER_WRITER = 2
class _UnaryResponseMixin(Call):
class _UnaryResponseMixin(Call, Generic[ResponseType]):
_call_response: asyncio.Task
def _init_unary_response_mixin(self, response_task: asyncio.Task):
@ -265,7 +265,7 @@ class _UnaryResponseMixin(Call):
else:
return False
def __await__(self) -> ResponseType:
def __await__(self) -> Generator[Any, None, ResponseType]:
"""Wait till the ongoing RPC request finishes."""
try:
response = yield from self._call_response
@ -573,6 +573,7 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
await self._raise_for_status()
# pylint: disable=too-many-ancestors
class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
_base_call.StreamUnaryCall):
"""Object for managing stream-unary RPC calls.

@ -42,6 +42,8 @@ from ._metadata import Metadata
from ._typing import ChannelArgumentType
from ._typing import DeserializingFunction
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction
from ._utils import _timeout_to_deadline
@ -121,14 +123,14 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable,
def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
metadata = self._init_metadata(metadata, compression)
if not self._interceptors:
@ -152,14 +154,14 @@ class UnaryStreamMultiCallable(_BaseMultiCallable,
def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)

Loading…
Cancel
Save