Apply review feedback

pull/23092/head
Pau Freixes 5 years ago
parent b3425f6dbf
commit f9d9793c96
  1. 49
      src/python/grpcio/grpc/experimental/aio/_channel.py
  2. 41
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  3. 3
      src/python/grpcio/grpc/experimental/aio/_typing.py
  4. 17
      src/python/grpcio_tests/tests_aio/unit/_common.py

@ -242,38 +242,23 @@ class Channel(_base_channel.Channel):
self._stream_unary_interceptors = []
self._stream_stream_interceptors = []
if interceptors:
attrs_and_interceptor_classes = ((self._unary_unary_interceptors,
UnaryUnaryClientInterceptor),
(self._unary_stream_interceptors,
UnaryStreamClientInterceptor),
(self._stream_unary_interceptors,
StreamUnaryClientInterceptor),
(self._stream_stream_interceptors,
StreamStreamClientInterceptor))
# pylint: disable=cell-var-from-loop
for attr, interceptor_class in attrs_and_interceptor_classes:
attr.extend([
interceptor for interceptor in interceptors
if isinstance(interceptor, interceptor_class)
])
invalid_interceptors = set(interceptors) - set(
self._unary_unary_interceptors) - set(
self._unary_stream_interceptors) - set(
self._stream_unary_interceptors) - set(
self._stream_stream_interceptors)
if invalid_interceptors:
raise ValueError(
"Interceptor must be " +
"{} or ".format(UnaryUnaryClientInterceptor.__name__) +
"{} or ".format(UnaryStreamClientInterceptor.__name__) +
"{} or ".format(StreamUnaryClientInterceptor.__name__) +
"{}. ".format(StreamStreamClientInterceptor.__name__) +
"The following are invalid: {}".format(invalid_interceptors)
)
if interceptors is not None:
for interceptor in interceptors:
if isinstance(interceptor, UnaryUnaryClientInterceptor):
self._unary_unary_interceptors.append(interceptor)
elif isinstance(interceptor, UnaryStreamClientInterceptor):
self._unary_stream_interceptors.append(interceptor)
elif isinstance(interceptor, StreamUnaryClientInterceptor):
self._stream_unary_interceptors.append(interceptor)
elif isinstance(interceptor, StreamStreamClientInterceptor):
self._stream_stream_interceptors.append(interceptor)
else:
raise ValueError(
"Interceptor {} must be ".format(interceptor) +
"{} or ".format(UnaryUnaryClientInterceptor.__name__) +
"{} or ".format(UnaryStreamClientInterceptor.__name__) +
"{} or ".format(StreamUnaryClientInterceptor.__name__) +
"{}. ".format(StreamStreamClientInterceptor.__name__))
self._loop = asyncio.get_event_loop()
self._channel = cygrpc.AioChannel(

@ -28,7 +28,7 @@ from ._call import _API_STYLE_ERROR
from ._utils import _timeout_to_deadline
from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
MetadataType, ResponseType, DoneCallbackType,
RequestIterableType)
RequestIterableType, ResponseIterableType)
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
@ -132,7 +132,7 @@ class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
self, continuation: Callable[[ClientCallDetails, RequestType],
UnaryStreamCall],
client_call_details: ClientCallDetails, request: RequestType
) -> Union[AsyncIterable[ResponseType], UnaryStreamCall]:
) -> Union[ResponseIterableType, UnaryStreamCall]:
"""Intercepts a unary-stream invocation asynchronously.
The function could return the call object or an asynchronous
@ -212,7 +212,7 @@ class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
UnaryStreamCall],
client_call_details: ClientCallDetails,
request_iterator: RequestIterableType,
) -> Union[AsyncIterable[ResponseType], StreamStreamCall]:
) -> Union[ResponseIterableType, StreamStreamCall]:
"""Intercepts a stream-stream invocation asynchronously.
Within the interceptor the usage of the call methods like `write` or
@ -434,11 +434,12 @@ class _InterceptedUnaryResponseMixin:
class _InterceptedStreamResponseMixin:
_response_aiter: AsyncIterable[ResponseType]
_response_aiter: Optional[AsyncIterable[ResponseType]]
def _init_stream_response_mixin(self) -> None:
self._response_aiter = self._wait_for_interceptor_task_response_iterator(
)
# Is initalized later, otherwise if the iterator is not finnally
# consumed a logging warning is emmited by Asyncio.
self._response_aiter = None
async def _wait_for_interceptor_task_response_iterator(self
) -> ResponseType:
@ -447,14 +448,17 @@ class _InterceptedStreamResponseMixin:
yield response
def __aiter__(self) -> AsyncIterable[ResponseType]:
if self._response_aiter is None:
self._response_aiter = self._wait_for_interceptor_task_response_iterator(
)
return self._response_aiter
async def read(self) -> ResponseType:
if self._response_aiter is None:
self._response_aiter = self._wait_for_interceptor_task_response_iterator(
)
return await self._response_aiter.asend(None)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
class _InterceptedStreamRequestMixin:
@ -945,32 +949,37 @@ class _StreamCallResponseIterator:
async def wait_for_connection(self) -> None:
return await self._call.wait_for_connection()
async def read(self) -> ResponseType:
# Behind the scenes everyting goes through the
# async iterator. So this path should not be reached.
raise Exception()
class UnaryStreamCallResponseIterator(_StreamCallResponseIterator,
_base_call.UnaryStreamCall):
"""UnaryStreamCall class wich uses an alternative response iterator."""
async def read(self) -> ResponseType:
# Behind the scenes everyting goes through the
# async iterator. So this path should not be reached.
raise NotImplementedError()
class StreamStreamCallResponseIterator(_StreamCallResponseIterator,
_base_call.StreamStreamCall):
"""UnaryStreamCall class wich uses an alternative response iterator."""
async def read(self) -> ResponseType:
# Behind the scenes everyting goes through the
# async iterator. So this path should not be reached.
raise NotImplementedError()
async def write(self, request: RequestType) -> None:
# Behind the scenes everyting goes through the
# async iterator provided by the InterceptedStreamStreamCall.
# So this path should not be reached.
raise Exception()
raise NotImplementedError()
async def done_writing(self) -> None:
# Behind the scenes everyting goes through the
# async iterator provided by the InterceptedStreamStreamCall.
# So this path should not be reached.
raise Exception()
raise NotImplementedError()
@property
def _done_writing_flag(self) -> bool:

@ -27,4 +27,5 @@ MetadataType = Sequence[MetadatumType]
ChannelArgumentType = Sequence[Tuple[str, Any]]
EOFType = type(EOF)
DoneCallbackType = Callable[[Any], None]
RequestIterableType = Union[Iterable[Any], AsyncIterable[Any]]
RequestIterableType = Union[Iterable[RequestType], AsyncIterable[RequestType]]
ResponseIterableType = AsyncIterable[ResponseType]

@ -15,7 +15,8 @@
import asyncio
import grpc
from grpc.experimental import aio
from grpc.experimental.aio._typing import MetadataType, MetadatumType
from grpc.experimental.aio._typing import MetadataType, MetadatumType, RequestIterableType
from grpc.experimental.aio._typing import ResponseIterableType, RequestType, ResponseType
from tests.unit.framework.common import test_constants
@ -37,7 +38,7 @@ async def block_until_certain_state(channel: aio.Channel,
state = channel.get_state()
def inject_callbacks(call):
def inject_callbacks(call: aio.Call):
first_callback_ran = asyncio.Event()
def first_callback(call):
@ -68,29 +69,29 @@ def inject_callbacks(call):
class CountingRequestIterator:
def __init__(self, request_iterator):
def __init__(self, request_iterator: RequestIterableType) -> None:
self.request_cnt = 0
self._request_iterator = request_iterator
async def _forward_requests(self):
async def _forward_requests(self) -> RequestType:
async for request in self._request_iterator:
self.request_cnt += 1
yield request
def __aiter__(self):
def __aiter__(self) -> RequestIterableType:
return self._forward_requests()
class CountingResponseIterator:
def __init__(self, response_iterator):
def __init__(self, response_iterator: ResponseIterableType) -> None:
self.response_cnt = 0
self._response_iterator = response_iterator
async def _forward_responses(self):
async def _forward_responses(self) -> ResponseType:
async for response in self._response_iterator:
self.response_cnt += 1
yield response
def __aiter__(self):
def __aiter__(self) -> ResponseIterableType:
return self._forward_responses()

Loading…
Cancel
Save