From 5bd38df3c28321a7b65f174da77c0f0e7cd2c3b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20Bo=C4=8Dek?= Date: Mon, 15 May 2023 18:15:38 +0200 Subject: [PATCH] [aio types] Fix some grpc.aio python types (#32475) With these, it is actually possible to have typed client stubs where the return type is correctly inferred. It's only for the non-streaming calls, because there is `RequestIterableType` for the streaming ones (but it's just Any with extra steps and would require much more work). --------- Co-authored-by: Xuan Wang --- src/python/grpcio/grpc/aio/_base_call.py | 6 +++--- src/python/grpcio/grpc/aio/_base_channel.py | 16 +++++++++------- src/python/grpcio/grpc/aio/_call.py | 7 ++++--- src/python/grpcio/grpc/aio/_channel.py | 10 ++++++---- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/python/grpcio/grpc/aio/_base_call.py b/src/python/grpcio/grpc/aio/_base_call.py index ff4071758b5..a1226158bac 100644 --- a/src/python/grpcio/grpc/aio/_base_call.py +++ b/src/python/grpcio/grpc/aio/_base_call.py @@ -20,7 +20,7 @@ RPC, e.g. cancellation. from abc import ABCMeta from abc import abstractmethod -from typing import AsyncIterator, Awaitable, Generic, Optional, Union +from typing import Any, AsyncIterator, Generator, Generic, Optional, Union import grpc @@ -141,7 +141,7 @@ class UnaryUnaryCall(Generic[RequestType, ResponseType], """The abstract base class of an unary-unary RPC on the client-side.""" @abstractmethod - def __await__(self) -> Awaitable[ResponseType]: + def __await__(self) -> Generator[Any, None, ResponseType]: """Await the response message to be ready. Returns: @@ -197,7 +197,7 @@ class StreamUnaryCall(Generic[RequestType, ResponseType], """ @abstractmethod - def __await__(self) -> Awaitable[ResponseType]: + def __await__(self) -> Generator[Any, None, ResponseType]: """Await the response message to be ready. Returns: diff --git a/src/python/grpcio/grpc/aio/_base_channel.py b/src/python/grpcio/grpc/aio/_base_channel.py index 4135e4796c7..04b92a42403 100644 --- a/src/python/grpcio/grpc/aio/_base_channel.py +++ b/src/python/grpcio/grpc/aio/_base_channel.py @@ -14,7 +14,7 @@ """Abstract base classes for Channel objects and Multicallable objects.""" import abc -from typing import Any, Optional +from typing import Generic, Optional import grpc @@ -22,23 +22,25 @@ from . import _base_call from ._typing import DeserializingFunction from ._typing import MetadataType from ._typing import RequestIterableType +from ._typing import RequestType +from ._typing import ResponseType from ._typing import SerializingFunction -class UnaryUnaryMultiCallable(abc.ABC): +class UnaryUnaryMultiCallable(Generic[RequestType, ResponseType], abc.ABC): """Enables asynchronous invocation of a unary-call RPC.""" @abc.abstractmethod def __call__( self, - request: Any, + request: RequestType, *, timeout: Optional[float] = None, metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None - ) -> _base_call.UnaryUnaryCall: + ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]: """Asynchronously invokes the underlying RPC. Args: @@ -63,20 +65,20 @@ class UnaryUnaryMultiCallable(abc.ABC): """ -class UnaryStreamMultiCallable(abc.ABC): +class UnaryStreamMultiCallable(Generic[RequestType, ResponseType], abc.ABC): """Enables asynchronous invocation of a server-streaming RPC.""" @abc.abstractmethod def __call__( self, - request: Any, + request: RequestType, *, timeout: Optional[float] = None, metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None - ) -> _base_call.UnaryStreamCall: + ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]: """Asynchronously invokes the underlying RPC. Args: diff --git a/src/python/grpcio/grpc/aio/_call.py b/src/python/grpcio/grpc/aio/_call.py index 37ba945da73..fcc90066c00 100644 --- a/src/python/grpcio/grpc/aio/_call.py +++ b/src/python/grpcio/grpc/aio/_call.py @@ -19,7 +19,7 @@ from functools import partial import inspect import logging import traceback -from typing import AsyncIterator, Optional, Tuple +from typing import Any, AsyncIterator, Generator, Generic, Optional, Tuple import grpc from grpc import _common @@ -252,7 +252,7 @@ class _APIStyle(enum.IntEnum): READER_WRITER = 2 -class _UnaryResponseMixin(Call): +class _UnaryResponseMixin(Call, Generic[ResponseType]): _call_response: asyncio.Task def _init_unary_response_mixin(self, response_task: asyncio.Task): @@ -265,7 +265,7 @@ class _UnaryResponseMixin(Call): else: return False - def __await__(self) -> ResponseType: + def __await__(self) -> Generator[Any, None, ResponseType]: """Wait till the ongoing RPC request finishes.""" try: response = yield from self._call_response @@ -573,6 +573,7 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): await self._raise_for_status() +# pylint: disable=too-many-ancestors class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, _base_call.StreamUnaryCall): """Object for managing stream-unary RPC calls. diff --git a/src/python/grpcio/grpc/aio/_channel.py b/src/python/grpcio/grpc/aio/_channel.py index a6fb2221250..f40e413a487 100644 --- a/src/python/grpcio/grpc/aio/_channel.py +++ b/src/python/grpcio/grpc/aio/_channel.py @@ -42,6 +42,8 @@ from ._metadata import Metadata from ._typing import ChannelArgumentType from ._typing import DeserializingFunction from ._typing import RequestIterableType +from ._typing import RequestType +from ._typing import ResponseType from ._typing import SerializingFunction from ._utils import _timeout_to_deadline @@ -121,14 +123,14 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable, def __call__( self, - request: Any, + request: RequestType, *, timeout: Optional[float] = None, metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None - ) -> _base_call.UnaryUnaryCall: + ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]: metadata = self._init_metadata(metadata, compression) if not self._interceptors: @@ -152,14 +154,14 @@ class UnaryStreamMultiCallable(_BaseMultiCallable, def __call__( self, - request: Any, + request: RequestType, *, timeout: Optional[float] = None, metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None - ) -> _base_call.UnaryStreamCall: + ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]: metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout)