Cancel the async response poller task when the RPC finishes to avoid races.

pull/37355/head
Kyle Brooks 4 months ago
parent 36123761a1
commit a3ab7c73aa
  1. 8
      src/python/grpcio/grpc/aio/_call.py
  2. 15
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  3. 34
      src/python/grpcio_tests/tests_aio/unit/channel_test.py

@ -420,6 +420,14 @@ class _StreamRequestMixin(Call):
self._async_request_poller = self._loop.create_task( self._async_request_poller = self._loop.create_task(
self._consume_request_iterator(request_iterator) self._consume_request_iterator(request_iterator)
) )
# Cancel the Task when the RPC is done.
# If the RPC fails immediately but this Task is still pending or running,
# these errors will occur:
# - Task was destroyed but it is pending!
# - aclose(): asynchronous generator is already running
self.add_done_callback(
lambda call: call._async_request_poller.cancel()
)
self._request_style = _APIStyle.ASYNC_GENERATOR self._request_style = _APIStyle.ASYNC_GENERATOR
else: else:
self._async_request_poller = None self._async_request_poller = None

@ -120,6 +120,14 @@ class TestServiceServicer(test_pb2_grpc.TestServiceServicer):
aggregated_payload_size=aggregate_size aggregated_payload_size=aggregate_size
) )
async def StreamingInputCallWithUnavailable(
self, unused_request_async_iterator, context
):
self._append_to_log()
await context.abort(
grpc.StatusCode.UNAVAILABLE, "Service is unavailable"
)
async def FullDuplexCall(self, request_async_iterator, context): async def FullDuplexCall(self, request_async_iterator, context):
self._append_to_log() self._append_to_log()
await _maybe_echo_metadata(context) await _maybe_echo_metadata(context)
@ -151,7 +159,12 @@ def _create_extra_generic_handler(servicer: TestServiceServicer):
servicer.UnaryCallWithSleep, servicer.UnaryCallWithSleep,
request_deserializer=messages_pb2.SimpleRequest.FromString, request_deserializer=messages_pb2.SimpleRequest.FromString,
response_serializer=messages_pb2.SimpleResponse.SerializeToString, response_serializer=messages_pb2.SimpleResponse.SerializeToString,
) ),
"StreamingInputCallWithUnavailable": grpc.stream_unary_rpc_method_handler(
servicer.StreamingInputCallWithUnavailable,
request_deserializer=messages_pb2.StreamingInputCallRequest.FromString,
response_serializer=messages_pb2.StreamingInputCallResponse.SerializeToString,
),
} }
return grpc.method_handlers_generic_handler( return grpc.method_handlers_generic_handler(
"grpc.testing.TestService", rpc_method_handlers "grpc.testing.TestService", rpc_method_handlers

@ -30,6 +30,9 @@ from tests_aio.unit._test_server import start_test_server
_UNARY_CALL_METHOD = "/grpc.testing.TestService/UnaryCall" _UNARY_CALL_METHOD = "/grpc.testing.TestService/UnaryCall"
_UNARY_CALL_METHOD_WITH_SLEEP = "/grpc.testing.TestService/UnaryCallWithSleep" _UNARY_CALL_METHOD_WITH_SLEEP = "/grpc.testing.TestService/UnaryCallWithSleep"
_STREAMING_INPUT_CALL_METHOD_WITH_UNAVAILABLE = (
"/grpc.testing.TestService/StreamingInputCallWithUnavailable"
)
_STREAMING_OUTPUT_CALL_METHOD = "/grpc.testing.TestService/StreamingOutputCall" _STREAMING_OUTPUT_CALL_METHOD = "/grpc.testing.TestService/StreamingOutputCall"
_INVOCATION_METADATA = ( _INVOCATION_METADATA = (
@ -198,6 +201,37 @@ class TestChannel(AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close() await channel.close()
async def test_stream_unary_unavailable_using_async_gen(self):
channel = aio.insecure_channel(self._server_target)
hi = channel.stream_unary(
_STREAMING_INPUT_CALL_METHOD_WITH_UNAVAILABLE,
request_serializer=messages_pb2.StreamingInputCallRequest.SerializeToString,
response_deserializer=messages_pb2.StreamingInputCallResponse.FromString,
)
# Prepares the request
payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
async def gen():
for _ in range(_NUM_STREAM_RESPONSES):
yield request
# Invokes the actual RPC
call: aio.StreamUnaryCall = hi(gen())
# Validates the responses
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertEqual(
grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code(),
)
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
# Verify the async request poller task is done
self.assertTrue(call._async_request_poller.done())
async def test_stream_stream_using_read_write(self): async def test_stream_stream_using_read_write(self):
channel = aio.insecure_channel(self._server_target) channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)

Loading…
Cancel
Save