diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi index c1987b55ff7..aae9daefd99 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi @@ -40,6 +40,8 @@ cdef class _AioCall(GrpcCallWrapper): list _waiters_status list _waiters_initial_metadata + int _send_initial_metadata_flags + cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except * cdef void _set_status(self, AioRpcStatus status) except * cdef void _set_initial_metadata(self, tuple initial_metadata) except * diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 6de1fa0b834..e74b4e6ba05 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -30,10 +30,22 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' '>') +cdef int _get_send_initial_metadata_flags(object wait_for_ready) except *: + cdef int flags = 0 + # Wait-for-ready can be None, which means using default value in Core. + if wait_for_ready is not None: + flags |= InitialMetadataFlags.wait_for_ready_explicitly_set + if wait_for_ready: + flags |= InitialMetadataFlags.wait_for_ready + + flags &= InitialMetadataFlags.used_mask + return flags + + cdef class _AioCall(GrpcCallWrapper): def __cinit__(self, AioChannel channel, object deadline, - bytes method, CallCredentials call_credentials): + bytes method, CallCredentials call_credentials, object wait_for_ready): self.call = NULL self._channel = channel self._loop = channel.loop @@ -45,6 +57,7 @@ cdef class _AioCall(GrpcCallWrapper): self._done_callbacks = [] self._is_locally_cancelled = False self._deadline = deadline + self._send_initial_metadata_flags = _get_send_initial_metadata_flags(wait_for_ready) self._create_grpc_call(deadline, method, call_credentials) def __dealloc__(self): @@ -279,7 +292,7 @@ cdef class _AioCall(GrpcCallWrapper): cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation( outbound_initial_metadata, - GRPC_INITIAL_METADATA_USED_MASK) + self._send_initial_metadata_flags) cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS) cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS) cdef ReceiveInitialMetadataOperation receive_initial_metadata_op = ReceiveInitialMetadataOperation(_EMPTY_FLAGS) @@ -366,12 +379,12 @@ cdef class _AioCall(GrpcCallWrapper): """Implementation of the start of a unary-stream call.""" # Peer may prematurely end this RPC at any point. We need a corutine # that watches if the server sends the final status. - self._loop.create_task(self._handle_status_once_received()) + status_task = self._loop.create_task(self._handle_status_once_received()) cdef tuple outbound_ops cdef Operation initial_metadata_op = SendInitialMetadataOperation( outbound_initial_metadata, - GRPC_INITIAL_METADATA_USED_MASK) + self._send_initial_metadata_flags) cdef Operation send_message_op = SendMessageOperation( request, _EMPTY_FLAGS) @@ -384,16 +397,21 @@ cdef class _AioCall(GrpcCallWrapper): send_close_op, ) - # Sends out the request message. - await execute_batch(self, - outbound_ops, - self._loop) - - # Receives initial metadata. - self._set_initial_metadata( - await _receive_initial_metadata(self, - self._loop), - ) + try: + # Sends out the request message. + await execute_batch(self, + outbound_ops, + self._loop) + + # Receives initial metadata. + self._set_initial_metadata( + await _receive_initial_metadata(self, + self._loop), + ) + except ExecuteBatchError as batch_error: + # Core should explain why this batch failed + await status_task + assert self._status.code() != StatusCode.ok async def stream_unary(self, tuple outbound_initial_metadata, @@ -404,17 +422,27 @@ cdef class _AioCall(GrpcCallWrapper): propagate the final status exception, then we have to raise it. Othersize, it would end normally and raise `StopAsyncIteration()`. """ - # Sends out initial_metadata ASAP. - await _send_initial_metadata(self, - outbound_initial_metadata, - self._loop) - # Notify upper level that sending messages are allowed now. - metadata_sent_observer() - - # Receives initial metadata. - self._set_initial_metadata( - await _receive_initial_metadata(self, self._loop) - ) + try: + # Sends out initial_metadata ASAP. + await _send_initial_metadata(self, + outbound_initial_metadata, + self._send_initial_metadata_flags, + self._loop) + # Notify upper level that sending messages are allowed now. + metadata_sent_observer() + + # Receives initial metadata. + self._set_initial_metadata( + await _receive_initial_metadata(self, self._loop) + ) + except ExecuteBatchError: + # Core should explain why this batch failed + await self._handle_status_once_received() + assert self._status.code() != StatusCode.ok + + # Allow upper layer to proceed only if the status is set + metadata_sent_observer() + return None cdef tuple inbound_ops cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS) @@ -452,16 +480,25 @@ cdef class _AioCall(GrpcCallWrapper): """ # Peer may prematurely end this RPC at any point. We need a corutine # that watches if the server sends the final status. - self._loop.create_task(self._handle_status_once_received()) - - # Sends out initial_metadata ASAP. - await _send_initial_metadata(self, - outbound_initial_metadata, - self._loop) - # Notify upper level that sending messages are allowed now. - metadata_sent_observer() - - # Receives initial metadata. - self._set_initial_metadata( - await _receive_initial_metadata(self, self._loop) - ) + status_task = self._loop.create_task(self._handle_status_once_received()) + + try: + # Sends out initial_metadata ASAP. + await _send_initial_metadata(self, + outbound_initial_metadata, + self._send_initial_metadata_flags, + self._loop) + # Notify upper level that sending messages are allowed now. + metadata_sent_observer() + + # Receives initial metadata. + self._set_initial_metadata( + await _receive_initial_metadata(self, self._loop) + ) + except ExecuteBatchError as batch_error: + # Core should explain why this batch failed + await status_task + assert self._status.code() != StatusCode.ok + + # Allow upper layer to proceed only if the status is set + metadata_sent_observer() diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi index f40951060f0..69f3fcffbbf 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi @@ -164,10 +164,11 @@ async def _send_message(GrpcCallWrapper grpc_call_wrapper, async def _send_initial_metadata(GrpcCallWrapper grpc_call_wrapper, tuple metadata, + int flags, object loop): cdef SendInitialMetadataOperation op = SendInitialMetadataOperation( metadata, - _EMPTY_FLAG) + flags) cdef tuple ops = (op,) await execute_batch(grpc_call_wrapper, ops, loop) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index 6c4b8422cdd..fe4a4e14954 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -99,7 +99,8 @@ cdef class AioChannel: def call(self, bytes method, object deadline, - object python_call_credentials): + object python_call_credentials, + object wait_for_ready): """Assembles a Cython Call object. Returns: @@ -115,4 +116,4 @@ cdef class AioChannel: else: cython_call_credentials = None - return _AioCall(self, deadline, method, cython_call_credentials) + return _AioCall(self, deadline, method, cython_call_credentials, wait_for_ready) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi index e27613c01e7..af609603c60 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi @@ -87,7 +87,7 @@ cdef class _AsyncioSocket: except Exception as e: error = True error_msg = "%s: %s" % (type(e), str(e)) - _LOGGER.exception(e) + _LOGGER.debug(e) finally: self._task_read = None @@ -167,6 +167,11 @@ cdef class _AsyncioSocket: self._py_socket.close() def _new_connection_callback(self, object reader, object writer): + # Close the connection if server is not started yet. + if self._grpc_accept_cb == NULL: + writer.close() + return + client_socket = _AsyncioSocket.create( self._grpc_client_socket, reader, diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 8186fa95697..3add587be6c 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -15,6 +15,7 @@ import asyncio from functools import partial +import logging from typing import AsyncIterable, Awaitable, Dict, Optional import grpc @@ -43,6 +44,8 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' '\tdebug_error_string = "{}"\n' '>') +_LOGGER = logging.getLogger(__name__) + class AioRpcError(grpc.RpcError): """An implementation of RpcError to be used by the asynchronous API. @@ -168,8 +171,9 @@ class Call: self._response_deserializer = response_deserializer def __del__(self) -> None: - if not self._cython_call.done(): - self._cancel(_GC_CANCELLATION_DETAILS) + if hasattr(self, '_cython_call'): + if not self._cython_call.done(): + self._cancel(_GC_CANCELLATION_DETAILS) def cancelled(self) -> bool: return self._cython_call.cancelled() @@ -345,9 +349,15 @@ class _StreamRequestMixin(Call): async def _consume_request_iterator( self, request_async_iterator: AsyncIterable[RequestType]) -> None: - async for request in request_async_iterator: - await self.write(request) - await self.done_writing() + try: + async for request in request_async_iterator: + await self.write(request) + await self.done_writing() + except AioRpcError as rpc_error: + # Rpc status should be exposed through other API. Exceptions raised + # within this Task won't be retrieved by another coroutine. It's + # better to suppress the error than spamming users' screen. + _LOGGER.debug('Exception while consuming of the request_iterator: %s', rpc_error) async def write(self, request: RequestType) -> None: if self.done(): @@ -356,6 +366,8 @@ class _StreamRequestMixin(Call): raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) if not self._metadata_sent.is_set(): await self._metadata_sent.wait() + if self.done(): + await self._raise_for_status() serialized_request = _common.serialize(request, self._request_serializer) @@ -394,11 +406,12 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): def __init__(self, request: RequestType, deadline: Optional[float], metadata: MetadataType, credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, loop: asyncio.AbstractEventLoop) -> None: - super().__init__(channel.call(method, deadline, credentials), metadata, + super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata, request_serializer, response_deserializer, loop) self._request = request self._init_unary_response_mixin(self._invoke()) @@ -436,11 +449,12 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): def __init__(self, request: RequestType, deadline: Optional[float], metadata: MetadataType, credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, loop: asyncio.AbstractEventLoop) -> None: - super().__init__(channel.call(method, deadline, credentials), metadata, + super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata, request_serializer, response_deserializer, loop) self._request = request self._send_unary_request_task = loop.create_task( @@ -471,11 +485,12 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, request_async_iterator: Optional[AsyncIterable[RequestType]], deadline: Optional[float], metadata: MetadataType, credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, loop: asyncio.AbstractEventLoop) -> None: - super().__init__(channel.call(method, deadline, credentials), metadata, + super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata, request_serializer, response_deserializer, loop) self._init_stream_request_mixin(request_async_iterator) @@ -509,11 +524,12 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, request_async_iterator: Optional[AsyncIterable[RequestType]], deadline: Optional[float], metadata: MetadataType, credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, loop: asyncio.AbstractEventLoop) -> None: - super().__init__(channel.call(method, deadline, credentials), metadata, + super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata, request_serializer, response_deserializer, loop) self._initializer = self._loop.create_task(self._prepare_rpc()) self._init_stream_request_mixin(request_async_iterator) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 2788f4416e0..547a5d83b29 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -101,9 +101,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): raised RpcError will also be a Call for the RPC affording the RPC's metadata, status code, and details. """ - if wait_for_ready: - raise NotImplementedError( - "TODO: wait_for_ready not implemented yet") if compression: raise NotImplementedError("TODO: compression not implemented yet") @@ -112,12 +109,13 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): if not self._interceptors: return UnaryUnaryCall(request, _timeout_to_deadline(timeout), - metadata, credentials, self._channel, + metadata, credentials, wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) else: return InterceptedUnaryUnaryCall(self._interceptors, request, timeout, metadata, credentials, + wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, @@ -154,10 +152,6 @@ class UnaryStreamMultiCallable(_BaseMultiCallable): Returns: A Call object instance which is an awaitable object. """ - if wait_for_ready: - raise NotImplementedError( - "TODO: wait_for_ready not implemented yet") - if compression: raise NotImplementedError("TODO: compression not implemented yet") @@ -165,7 +159,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable): if metadata is None: metadata = _IMMUTABLE_EMPTY_TUPLE - return UnaryStreamCall(request, deadline, metadata, credentials, + return UnaryStreamCall(request, deadline, metadata, credentials,wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) @@ -205,10 +199,6 @@ class StreamUnaryMultiCallable(_BaseMultiCallable): raised RpcError will also be a Call for the RPC affording the RPC's metadata, status code, and details. """ - if wait_for_ready: - raise NotImplementedError( - "TODO: wait_for_ready not implemented yet") - if compression: raise NotImplementedError("TODO: compression not implemented yet") @@ -217,7 +207,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable): metadata = _IMMUTABLE_EMPTY_TUPLE return StreamUnaryCall(request_async_iterator, deadline, metadata, - credentials, self._channel, self._method, + credentials, wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) @@ -256,10 +246,6 @@ class StreamStreamMultiCallable(_BaseMultiCallable): raised RpcError will also be a Call for the RPC affording the RPC's metadata, status code, and details. """ - if wait_for_ready: - raise NotImplementedError( - "TODO: wait_for_ready not implemented yet") - if compression: raise NotImplementedError("TODO: compression not implemented yet") @@ -268,7 +254,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable): metadata = _IMMUTABLE_EMPTY_TUPLE return StreamStreamCall(request_async_iterator, deadline, metadata, - credentials, self._channel, self._method, + credentials, wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 884d7c98f1c..4bf281b2565 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -15,5 +15,6 @@ "unit.interceptor_test.TestInterceptedUnaryUnaryCall", "unit.interceptor_test.TestUnaryUnaryClientInterceptor", "unit.metadata_test.TestMetadata", - "unit.server_test.TestServer" + "unit.server_test.TestServer", + "unit.wait_for_ready.TestWaitForReady" ] diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index 68b75332daa..535f3d1b157 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import grpc +from grpc.experimental import aio from grpc.experimental.aio._typing import MetadataType, MetadatumType @@ -22,3 +24,11 @@ def seen_metadata(expected: MetadataType, actual: MetadataType): def seen_metadatum(expected: MetadatumType, actual: MetadataType): metadata_dict = dict(actual) return metadata_dict.get(expected[0]) == expected[1] + + +async def block_until_certain_state(channel: aio.Channel, expected_state: grpc.ChannelConnectivity): + state = channel.get_state() + while state != expected_state: + import logging;logging.debug('Get %s want %s', state, expected_state) + await channel.wait_for_state_change(state) + state = channel.get_state() diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 03aea81ec91..b289b67dbb0 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -87,7 +87,7 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): response_parameters.size)) -async def start_test_server(secure=False): +async def start_test_server(port=0, secure=False): server = aio.server(options=(('grpc.so_reuseport', 0),)) servicer = _TestServiceServicer() test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server) @@ -109,10 +109,11 @@ async def start_test_server(secure=False): if secure: server_credentials = grpc.local_server_credentials( grpc.LocalConnectionType.LOCAL_TCP) - port = server.add_secure_port('[::]:0', server_credentials) + port = server.add_secure_port(f'[::]:{port}', server_credentials) else: - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port(f'[::]:{port}') await server.start() + # NOTE(lidizheng) returning the server to prevent it from deallocation return 'localhost:%d' % port, server diff --git a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py index 95a819b2b5f..388be97b2b5 100644 --- a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py @@ -28,13 +28,6 @@ from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_server import start_test_server -async def _block_until_certain_state(channel, expected_state): - state = channel.get_state() - while state != expected_state: - await channel.wait_for_state_change(state) - state = channel.get_state() - - class TestConnectivityState(AioTestBase): async def setUp(self): @@ -52,7 +45,7 @@ class TestConnectivityState(AioTestBase): # Should not time out await asyncio.wait_for( - _block_until_certain_state( + _common.block_until_certain_state( channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE), test_constants.SHORT_TIMEOUT) @@ -63,7 +56,7 @@ class TestConnectivityState(AioTestBase): # Should not time out await asyncio.wait_for( - _block_until_certain_state(channel, + _common.block_until_certain_state(channel, grpc.ChannelConnectivity.READY), test_constants.SHORT_TIMEOUT) @@ -75,7 +68,7 @@ class TestConnectivityState(AioTestBase): # If timed out, the function should return None. with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for( - _block_until_certain_state(channel, + _common.block_until_certain_state(channel, grpc.ChannelConnectivity.READY), test_constants.SHORT_TIMEOUT) diff --git a/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py b/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py index 93eddcf0917..a312e45711f 100644 --- a/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py @@ -13,20 +13,6 @@ # limitations under the License. """Testing the done callbacks mechanism.""" -# Copyright 2019 The gRPC Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import asyncio import logging import unittest diff --git a/src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py b/src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py new file mode 100644 index 00000000000..b5decf7dd01 --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py @@ -0,0 +1,136 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing the done callbacks mechanism.""" + +import asyncio +import logging +import unittest +import time +import gc + +import grpc +from grpc.experimental import aio +from tests_aio.unit._test_base import AioTestBase +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import get_socket +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit import _common + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 + +async def _perform_unary_unary(stub, wait_for_ready): + await stub.UnaryCall(messages_pb2.SimpleRequest(), timeout=test_constants.SHORT_TIMEOUT, wait_for_ready=wait_for_ready) + + +async def _perform_unary_stream(stub, wait_for_ready): + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + call = stub.StreamingOutputCall(request, timeout=test_constants.SHORT_TIMEOUT, wait_for_ready=wait_for_ready) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.read() + assert await call.code() == grpc.StatusCode.OK + + +async def _perform_stream_unary(stub, wait_for_ready): + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + await stub.StreamingInputCall(gen(), timeout=test_constants.SHORT_TIMEOUT, wait_for_ready=wait_for_ready) + + +async def _perform_stream_stream(stub, wait_for_ready): + call = stub.FullDuplexCall(timeout=test_constants.SHORT_TIMEOUT, wait_for_ready=wait_for_ready) + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + assert _RESPONSE_PAYLOAD_SIZE == len(response.payload.body) + + await call.done_writing() + assert await call.code() == grpc.StatusCode.OK + + +_RPC_ACTIONS = ( + _perform_unary_unary, + _perform_unary_stream, + _perform_stream_unary, + _perform_stream_stream, +) + + +class TestWaitForReady(AioTestBase): + + async def setUp(self): + address, self._port, self._socket = get_socket(listen=False) + self._channel = aio.insecure_channel(f"{address}:{self._port}") + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + self._socket.close() + + async def tearDown(self): + await self._channel.close() + + async def _connection_fails_fast(self, wait_for_ready): + for action in _RPC_ACTIONS: + with self.subTest(name=action): + with self.assertRaises(aio.AioRpcError) as exception_context: + await action(self._stub, wait_for_ready) + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) + + async def test_call_wait_for_ready_default(self): + await self._connection_fails_fast(None) + + async def test_call_wait_for_ready_disabled(self): + await self._connection_fails_fast(False) + + async def test_call_wait_for_ready_enabled(self): + for action in _RPC_ACTIONS: + with self.subTest(name=action.__name__): + # Starts the RPC + action_task = self.loop.create_task(action(self._stub, True)) + + # Wait for TRANSIENT_FAILURE, and RPC is not aborting + await _common.block_until_certain_state( + self._channel, + grpc.ChannelConnectivity.TRANSIENT_FAILURE) + + try: + # Start the server + _, server = await start_test_server(port=self._port) + + # The RPC should recover itself + await action_task + finally: + if server is not None: + await server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2)