[Python AIO] Return EOF from UnaryStreamCall.read() as documented (#36660)

Fix: https://github.com/grpc/grpc/issues/36581

Based on our documentation, we should return `grpc.aio.EOF` to indicate the end of the stream: fb6a57b153/src/python/grpcio/grpc/aio/_base_call.py (L166-L178)

But if the call was intercepted, we're raising `StopAsyncIteration`, This Pr changes the return to match the documentation (Which is also the behavior for non-intercepted call).

Closes #36660

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/36660 from XuanWang-Amos:fix_aio_read_return 4b679ba429
PiperOrigin-RevId: 638681673
pull/36758/head^2
Xuan Wang 6 months ago committed by Copybara-Service
parent 39afbf49f2
commit 1a96ce7620
  1. 13
      src/python/grpcio/grpc/aio/_call.py
  2. 12
      src/python/grpcio/grpc/aio/_interceptor.py
  3. 48
      src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py

@ -19,7 +19,15 @@ from functools import partial
import inspect
import logging
import traceback
from typing import Any, AsyncIterator, Generator, Generic, Optional, Tuple
from typing import (
Any,
AsyncIterator,
Generator,
Generic,
Optional,
Tuple,
Union,
)
import grpc
from grpc import _common
@ -29,6 +37,7 @@ from . import _base_call
from ._metadata import Metadata
from ._typing import DeserializingFunction
from ._typing import DoneCallbackType
from ._typing import EOFType
from ._typing import MetadatumType
from ._typing import RequestIterableType
from ._typing import RequestType
@ -380,7 +389,7 @@ class _StreamResponseMixin(Call):
raw_response, self._response_deserializer
)
async def read(self) -> ResponseType:
async def read(self) -> Union[EOFType, ResponseType]:
if self.done():
await self._raise_for_status()
return cygrpc.EOF

@ -43,6 +43,7 @@ from ._call import _RPC_HALF_CLOSED_DETAILS
from ._metadata import Metadata
from ._typing import DeserializingFunction
from ._typing import DoneCallbackType
from ._typing import EOFType
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseIterableType
@ -494,12 +495,15 @@ class _InterceptedStreamResponseMixin:
)
return self._response_aiter
async def read(self) -> ResponseType:
async def read(self) -> Union[EOFType, ResponseType]:
if self._response_aiter is None:
self._response_aiter = (
self._wait_for_interceptor_task_response_iterator()
)
return await self._response_aiter.asend(None)
try:
return await self._response_aiter.asend(None)
except StopAsyncIteration:
return cygrpc.EOF
class _InterceptedStreamRequestMixin:
@ -1141,7 +1145,7 @@ class UnaryStreamCallResponseIterator(
):
"""UnaryStreamCall class wich uses an alternative response iterator."""
async def read(self) -> ResponseType:
async def read(self) -> Union[EOFType, ResponseType]:
# Behind the scenes everyting goes through the
# async iterator. So this path should not be reached.
raise NotImplementedError()
@ -1152,7 +1156,7 @@ class StreamStreamCallResponseIterator(
):
"""StreamStreamCall class wich uses an alternative response iterator."""
async def read(self) -> ResponseType:
async def read(self) -> Union[EOFType, ResponseType]:
# Behind the scenes everyting goes through the
# async iterator. So this path should not be reached.
raise NotImplementedError()

@ -223,6 +223,54 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
await channel.close()
async def test_too_many_reads(self):
for interceptor_class in (
[_UnaryStreamInterceptorEmpty],
[_UnaryStreamInterceptorWithResponseIterator],
[],
):
with self.subTest(name=interceptor_class):
if interceptor_class:
interceptor = interceptor_class[0]()
channel = aio.insecure_channel(
self._server_target, interceptors=[interceptor]
)
else:
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.extend(
[
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE
)
]
* _NUM_STREAM_RESPONSES
)
call = stub.StreamingOutputCall(request)
response_cnt = 0
for response in range(_NUM_STREAM_RESPONSES):
response = await call.read()
response_cnt += 1
self.assertIs(
type(response), messages_pb2.StreamingOutputCallResponse
)
self.assertEqual(
_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
)
# Additional read() should return EOF
self.assertIs(await call.read(), aio.EOF)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# After the RPC finished, the read should also produce EOF
self.assertIs(await call.read(), aio.EOF)
await channel.close()
async def test_multiple_interceptors_response_iterator(self):
for interceptor_class in (
_UnaryStreamInterceptorEmpty,

Loading…
Cancel
Save