Merge pull request #24801 from lidizheng/aio-stream-empty-ping-pong

[Aio] Fix the emtpy response handling in streaming RPC
pull/24732/head
Lidi Zheng 4 years ago committed by GitHub
commit 3b87bf09af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  3. 22
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  4. 31
      src/python/grpcio_tests/tests_aio/unit/call_test.py

@ -360,7 +360,7 @@ cdef class _AioCall(GrpcCallWrapper):
self,
self._loop
)
if received_message:
if received_message is not None:
return received_message
else:
return EOF

@ -130,6 +130,8 @@ async def _receive_message(GrpcCallWrapper grpc_call_wrapper,
#
# Since they all indicates finish, they are better be merged.
_LOGGER.debug('Failed to receive any message from Core')
# NOTE(lidiz) The returned message might be an empty bytes (aka. b'').
# Please explicitly check if it is None or falsey string object!
return receive_op.message()

@ -67,10 +67,13 @@ class TestServiceServicer(test_pb2_grpc.TestServiceServicer):
await asyncio.sleep(
datetime.timedelta(microseconds=response_parameters.
interval_us).total_seconds())
yield messages_pb2.StreamingOutputCallResponse(
payload=messages_pb2.Payload(type=request.response_type,
body=b'\x00' *
response_parameters.size))
if response_parameters.size != 0:
yield messages_pb2.StreamingOutputCallResponse(
payload=messages_pb2.Payload(type=request.response_type,
body=b'\x00' *
response_parameters.size))
else:
yield messages_pb2.StreamingOutputCallResponse()
# Next methods are extra ones that are registred programatically
# when the sever is instantiated. They are not being provided by
@ -96,10 +99,13 @@ class TestServiceServicer(test_pb2_grpc.TestServiceServicer):
await asyncio.sleep(
datetime.timedelta(microseconds=response_parameters.
interval_us).total_seconds())
yield messages_pb2.StreamingOutputCallResponse(
payload=messages_pb2.Payload(type=request.payload.type,
body=b'\x00' *
response_parameters.size))
if response_parameters.size != 0:
yield messages_pb2.StreamingOutputCallResponse(
payload=messages_pb2.Payload(type=request.payload.type,
body=b'\x00' *
response_parameters.size))
else:
yield messages_pb2.StreamingOutputCallResponse()
def _create_extra_generic_handler(servicer: TestServiceServicer):

@ -472,6 +472,24 @@ class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_empty_responses(self):
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters())
# Invokes the actual RPC
call = self._stub.StreamingOutputCall(request)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(b'', response.SerializeToString())
self.assertEqual(grpc.StatusCode.OK, await call.code())
class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
@ -624,6 +642,10 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE = messages_pb2.StreamingOutputCallRequest(
)
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE.response_parameters.append(
messages_pb2.ResponseParameters())
class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
@ -808,6 +830,15 @@ class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_empty_ping_pong(self):
call = self._stub.FullDuplexCall()
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE)
response = await call.read()
self.assertEqual(b'', response.SerializeToString())
await call.done_writing()
self.assertEqual(await call.code(), grpc.StatusCode.OK)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save