diff --git a/src/python/grpcio/grpc/aio/_call.py b/src/python/grpcio/grpc/aio/_call.py index 24f2090651a..be8855d059b 100644 --- a/src/python/grpcio/grpc/aio/_call.py +++ b/src/python/grpcio/grpc/aio/_call.py @@ -420,6 +420,14 @@ class _StreamRequestMixin(Call): self._async_request_poller = self._loop.create_task( self._consume_request_iterator(request_iterator) ) + # Cancel the Task when the RPC is done. + # If the RPC fails immediately but this Task is still pending or running, + # these errors will occur: + # - Task was destroyed but it is pending! + # - aclose(): asynchronous generator is already running + self.add_done_callback( + lambda call: call._async_request_poller.cancel() + ) self._request_style = _APIStyle.ASYNC_GENERATOR else: self._async_request_poller = None diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 5d25272f300..849a14cca0a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -120,6 +120,14 @@ class TestServiceServicer(test_pb2_grpc.TestServiceServicer): aggregated_payload_size=aggregate_size ) + async def StreamingInputCallWithUnavailable( + self, unused_request_async_iterator, context + ): + self._append_to_log() + await context.abort( + grpc.StatusCode.UNAVAILABLE, "Service is unavailable" + ) + async def FullDuplexCall(self, request_async_iterator, context): self._append_to_log() await _maybe_echo_metadata(context) @@ -151,7 +159,12 @@ def _create_extra_generic_handler(servicer: TestServiceServicer): servicer.UnaryCallWithSleep, request_deserializer=messages_pb2.SimpleRequest.FromString, response_serializer=messages_pb2.SimpleResponse.SerializeToString, - ) + ), + "StreamingInputCallWithUnavailable": grpc.stream_unary_rpc_method_handler( + servicer.StreamingInputCallWithUnavailable, + request_deserializer=messages_pb2.StreamingInputCallRequest.FromString, + response_serializer=messages_pb2.StreamingInputCallResponse.SerializeToString, + ), } return grpc.method_handlers_generic_handler( "grpc.testing.TestService", rpc_method_handlers diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 26ef006fa15..5dac1fbd587 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -30,6 +30,9 @@ from tests_aio.unit._test_server import start_test_server _UNARY_CALL_METHOD = "/grpc.testing.TestService/UnaryCall" _UNARY_CALL_METHOD_WITH_SLEEP = "/grpc.testing.TestService/UnaryCallWithSleep" +_STREAMING_INPUT_CALL_METHOD_WITH_UNAVAILABLE = ( + "/grpc.testing.TestService/StreamingInputCallWithUnavailable" +) _STREAMING_OUTPUT_CALL_METHOD = "/grpc.testing.TestService/StreamingOutputCall" _INVOCATION_METADATA = ( @@ -198,6 +201,37 @@ class TestChannel(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.OK) await channel.close() + async def test_stream_unary_unavailable_using_async_gen(self): + channel = aio.insecure_channel(self._server_target) + hi = channel.stream_unary( + _STREAMING_INPUT_CALL_METHOD_WITH_UNAVAILABLE, + request_serializer=messages_pb2.StreamingInputCallRequest.SerializeToString, + response_deserializer=messages_pb2.StreamingInputCallResponse.FromString, + ) + # Prepares the request + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + # Invokes the actual RPC + call: aio.StreamUnaryCall = hi(gen()) + + # Validates the responses + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + self.assertEqual( + grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code(), + ) + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + + # Verify the async request poller task is done + self.assertTrue(call._async_request_poller.done()) + async def test_stream_stream_using_read_write(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel)