From 46e963f8bceb795a90890e86c4903516dab4a7a0 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Thu, 5 Dec 2019 16:51:05 -0800 Subject: [PATCH] Let streaming RPC start immediately --- .../grpc/_cython/_cygrpc/aio/call.pyx.pxi | 91 +++++++++++-------- .../grpcio/grpc/experimental/aio/_call.py | 11 ++- .../tests_aio/unit/server_test.py | 12 +++ 3 files changed, 71 insertions(+), 43 deletions(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 43557e68733..14a7e6df197 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -32,6 +32,9 @@ cdef class _AioCall: self._status_received = asyncio.Event(loop=self._loop) + def __dealloc__(self): + self._destroy_grpc_call() + def __repr__(self): class_name = self.__class__.__name__ id_ = id(self) @@ -68,9 +71,13 @@ cdef class _AioCall: grpc_slice_unref(method_slice) cdef void _destroy_grpc_call(self): - """Destroys the corresponding Core object for this RPC.""" + """Destroys the corresponding Core object for this RPC. + + This method is idempotent. Multiple calls should not result in crashes. + """ if self._grpc_call_wrapper.call != NULL: grpc_call_unref(self._grpc_call_wrapper.call) + self._grpc_call_wrapper.call = NULL cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future): """Cancels the RPC in C-Core, and return the final RPC status.""" @@ -183,6 +190,7 @@ cdef class _AioCall: ) status_observer(status) self._status_received.set() + self._destroy_grpc_call() def _handle_cancellation_from_application(self, object cancellation_future, @@ -190,9 +198,33 @@ cdef class _AioCall: def _cancellation_action(finished_future): status = self._cancel_and_create_status(finished_future) status_observer(status) + self._status_received.set() + self._destroy_grpc_call() cancellation_future.add_done_callback(_cancellation_action) + async def _message_async_generator(self): + cdef bytes received_message + + # Infinitely receiving messages, until: + # * EOF, no more messages to read; + # * The client application cancells; + # * The server sends final status. + while True: + if self._status_received.is_set(): + return + + received_message = await _receive_message( + self._grpc_call_wrapper, + self._loop + ) + if received_message is None: + # The read operation failed, C-Core should explain why it fails + await self._status_received.wait() + return + else: + yield received_message + async def unary_stream(self, bytes method, bytes request, @@ -206,7 +238,6 @@ cdef class _AioCall: propagate the final status exception, then we have to raise it. Othersize, it would end normally and raise `StopAsyncIteration()`. """ - cdef bytes received_message cdef tuple outbound_ops cdef Operation initial_metadata_op = SendInitialMetadataOperation( _EMPTY_METADATA, @@ -223,45 +254,25 @@ cdef class _AioCall: send_close_op, ) - # NOTE(lidiz) Not catching CancelledError here, because async - # generators do not have "cancel" method. - try: - self._create_grpc_call(deadline, method) + # Creates the grpc_call C-Core object, it needs to be deleted explicitly + # through _destroy_grpc_call call in other methods. + self._create_grpc_call(deadline, method) - await callback_start_batch( - self._grpc_call_wrapper, - outbound_ops, - self._loop) + # Actually sends out the request message. + await callback_start_batch(self._grpc_call_wrapper, + outbound_ops, + self._loop) - # Peer may prematurely end this RPC at any point. We need a mechanism - # that handles both the normal case and the error case. - self._loop.create_task(self._handle_status_once_received(status_observer)) - self._handle_cancellation_from_application(cancellation_future, - status_observer) + # Peer may prematurely end this RPC at any point. We need a mechanism + # that handles both the normal case and the error case. + self._loop.create_task(self._handle_status_once_received(status_observer)) + self._handle_cancellation_from_application(cancellation_future, + status_observer) - # Receives initial metadata. - initial_metadata_observer( - await _receive_initial_metadata(self._grpc_call_wrapper, - self._loop), - ) + # Receives initial metadata. + initial_metadata_observer( + await _receive_initial_metadata(self._grpc_call_wrapper, + self._loop), + ) - # Infinitely receiving messages, until: - # * EOF, no more messages to read; - # * The client application cancells; - # * The server sends final status. - while True: - if self._status_received.is_set(): - return - - received_message = await _receive_message( - self._grpc_call_wrapper, - self._loop - ) - if received_message is None: - # The read operation failed, wait for status from C-Core. - await self._status_received.wait() - return - else: - yield received_message - finally: - self._destroy_grpc_call() + return self._message_async_generator() diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index d3489506a82..395e57756a6 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -334,6 +334,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): _method: bytes _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction + _call: asyncio.Task _aiter: AsyncIterable[ResponseType] def __init__(self, request: RequestType, deadline: Optional[float], @@ -347,7 +348,8 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer - self._aiter = self._invoke() + self._call = self._loop.create_task(self._invoke()) + self._aiter = self._process() def __del__(self) -> None: if not self._status.done(): @@ -359,7 +361,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): serialized_request = _common.serialize(self._request, self._request_serializer) - async_gen = self._channel.unary_stream( + self._aiter = await self._channel.unary_stream( self._method, serialized_request, self._deadline, @@ -367,7 +369,10 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): self._set_initial_metadata, self._set_status, ) - async for serialized_response in async_gen: + + async def _process(self) -> ResponseType: + await self._call + async for serialized_response in self._aiter: if self._cancellation.done(): await self._status if self._status.done(): diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py index 962ab520ca9..1ba00c0312a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -105,6 +105,12 @@ class TestServer(AioTestBase): async with aio.insecure_channel(self._server_target) as channel: unary_stream_call = channel.unary_stream(_UNARY_STREAM_ASYNC_GEN) call = unary_stream_call(_REQUEST) + await self._generic_handler.wait_for_call() + + # Expecting the request message to reach server before retriving + # any responses. + await asyncio.wait_for(self._generic_handler.wait_for_call(), + test_constants.SHORT_TIMEOUT) response_cnt = 0 async for response in call: @@ -118,6 +124,12 @@ class TestServer(AioTestBase): async with aio.insecure_channel(self._server_target) as channel: unary_stream_call = channel.unary_stream(_UNARY_STREAM_ASYNC_GEN) call = unary_stream_call(_REQUEST) + await self._generic_handler.wait_for_call() + + # Expecting the request message to reach server before retriving + # any responses. + await asyncio.wait_for(self._generic_handler.wait_for_call(), + test_constants.SHORT_TIMEOUT) for _ in range(_NUM_STREAM_RESPONSES): response = await call.read()