Merge pull request #21803 from lidizheng/aio-wait

[Aio] Support wait-for-ready mechanism
pull/21865/head
Lidi Zheng 5 years ago committed by GitHub
commit 41bc9b9910
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 110
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  4. 5
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  5. 7
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi
  6. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  7. 60
      src/python/grpcio/grpc/experimental/aio/_call.py
  8. 41
      src/python/grpcio/grpc/experimental/aio/_channel.py
  9. 31
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  10. 3
      src/python/grpcio_tests/tests_aio/tests.json
  11. 10
      src/python/grpcio_tests/tests_aio/unit/_common.py
  12. 7
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  13. 7
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  14. 18
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py
  15. 14
      src/python/grpcio_tests/tests_aio/unit/done_callback_test.py
  16. 10
      src/python/grpcio_tests/tests_aio/unit/interceptor_test.py
  17. 146
      src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py

@ -40,6 +40,8 @@ cdef class _AioCall(GrpcCallWrapper):
list _waiters_status
list _waiters_initial_metadata
int _send_initial_metadata_flags
cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except *
cdef void _set_status(self, AioRpcStatus status) except *
cdef void _set_initial_metadata(self, tuple initial_metadata) except *

@ -30,10 +30,22 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'>')
cdef int _get_send_initial_metadata_flags(object wait_for_ready) except *:
cdef int flags = 0
# Wait-for-ready can be None, which means using default value in Core.
if wait_for_ready is not None:
flags |= InitialMetadataFlags.wait_for_ready_explicitly_set
if wait_for_ready:
flags |= InitialMetadataFlags.wait_for_ready
flags &= InitialMetadataFlags.used_mask
return flags
cdef class _AioCall(GrpcCallWrapper):
def __cinit__(self, AioChannel channel, object deadline,
bytes method, CallCredentials call_credentials):
bytes method, CallCredentials call_credentials, object wait_for_ready):
self.call = NULL
self._channel = channel
self._loop = channel.loop
@ -45,6 +57,7 @@ cdef class _AioCall(GrpcCallWrapper):
self._done_callbacks = []
self._is_locally_cancelled = False
self._deadline = deadline
self._send_initial_metadata_flags = _get_send_initial_metadata_flags(wait_for_ready)
self._create_grpc_call(deadline, method, call_credentials)
def __dealloc__(self):
@ -279,7 +292,7 @@ cdef class _AioCall(GrpcCallWrapper):
cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation(
outbound_initial_metadata,
GRPC_INITIAL_METADATA_USED_MASK)
self._send_initial_metadata_flags)
cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS)
cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS)
cdef ReceiveInitialMetadataOperation receive_initial_metadata_op = ReceiveInitialMetadataOperation(_EMPTY_FLAGS)
@ -366,12 +379,12 @@ cdef class _AioCall(GrpcCallWrapper):
"""Implementation of the start of a unary-stream call."""
# Peer may prematurely end this RPC at any point. We need a corutine
# that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received())
status_task = self._loop.create_task(self._handle_status_once_received())
cdef tuple outbound_ops
cdef Operation initial_metadata_op = SendInitialMetadataOperation(
outbound_initial_metadata,
GRPC_INITIAL_METADATA_USED_MASK)
self._send_initial_metadata_flags)
cdef Operation send_message_op = SendMessageOperation(
request,
_EMPTY_FLAGS)
@ -384,16 +397,20 @@ cdef class _AioCall(GrpcCallWrapper):
send_close_op,
)
# Sends out the request message.
await execute_batch(self,
outbound_ops,
self._loop)
# Receives initial metadata.
self._set_initial_metadata(
await _receive_initial_metadata(self,
self._loop),
)
try:
# Sends out the request message.
await execute_batch(self,
outbound_ops,
self._loop)
# Receives initial metadata.
self._set_initial_metadata(
await _receive_initial_metadata(self,
self._loop),
)
except ExecuteBatchError as batch_error:
# Core should explain why this batch failed
await status_task
async def stream_unary(self,
tuple outbound_initial_metadata,
@ -404,17 +421,26 @@ cdef class _AioCall(GrpcCallWrapper):
propagate the final status exception, then we have to raise it.
Othersize, it would end normally and raise `StopAsyncIteration()`.
"""
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
outbound_initial_metadata,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
# Receives initial metadata.
self._set_initial_metadata(
await _receive_initial_metadata(self, self._loop)
)
try:
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
outbound_initial_metadata,
self._send_initial_metadata_flags,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
# Receives initial metadata.
self._set_initial_metadata(
await _receive_initial_metadata(self, self._loop)
)
except ExecuteBatchError:
# Core should explain why this batch failed
await self._handle_status_once_received()
# Allow upper layer to proceed only if the status is set
metadata_sent_observer()
return None
cdef tuple inbound_ops
cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS)
@ -452,16 +478,24 @@ cdef class _AioCall(GrpcCallWrapper):
"""
# Peer may prematurely end this RPC at any point. We need a corutine
# that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received())
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
outbound_initial_metadata,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
# Receives initial metadata.
self._set_initial_metadata(
await _receive_initial_metadata(self, self._loop)
)
status_task = self._loop.create_task(self._handle_status_once_received())
try:
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
outbound_initial_metadata,
self._send_initial_metadata_flags,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
# Receives initial metadata.
self._set_initial_metadata(
await _receive_initial_metadata(self, self._loop)
)
except ExecuteBatchError as batch_error:
# Core should explain why this batch failed
await status_task
# Allow upper layer to proceed only if the status is set
metadata_sent_observer()

@ -164,10 +164,11 @@ async def _send_message(GrpcCallWrapper grpc_call_wrapper,
async def _send_initial_metadata(GrpcCallWrapper grpc_call_wrapper,
tuple metadata,
int flags,
object loop):
cdef SendInitialMetadataOperation op = SendInitialMetadataOperation(
metadata,
_EMPTY_FLAG)
flags)
cdef tuple ops = (op,)
await execute_batch(grpc_call_wrapper, ops, loop)

@ -99,7 +99,8 @@ cdef class AioChannel:
def call(self,
bytes method,
object deadline,
object python_call_credentials):
object python_call_credentials,
object wait_for_ready):
"""Assembles a Cython Call object.
Returns:
@ -115,4 +116,4 @@ cdef class AioChannel:
else:
cython_call_credentials = None
return _AioCall(self, deadline, method, cython_call_credentials)
return _AioCall(self, deadline, method, cython_call_credentials, wait_for_ready)

@ -87,7 +87,7 @@ cdef class _AsyncioSocket:
except Exception as e:
error = True
error_msg = "%s: %s" % (type(e), str(e))
_LOGGER.exception(e)
_LOGGER.debug(e)
finally:
self._task_read = None
@ -167,6 +167,11 @@ cdef class _AsyncioSocket:
self._py_socket.close()
def _new_connection_callback(self, object reader, object writer):
# Close the connection if server is not started yet.
if self._grpc_accept_cb == NULL:
writer.close()
return
client_socket = _AsyncioSocket.create(
self._grpc_client_socket,
reader,

@ -125,7 +125,7 @@ cdef class _ServicerContext:
if self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent')
else:
await _send_initial_metadata(self._rpc_state, metadata, self._loop)
await _send_initial_metadata(self._rpc_state, metadata, _EMPTY_FLAG, self._loop)
self._rpc_state.metadata_sent = True
async def abort(self,

@ -15,6 +15,7 @@
import asyncio
from functools import partial
import logging
from typing import AsyncIterable, Awaitable, Dict, Optional
import grpc
@ -43,6 +44,8 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tdebug_error_string = "{}"\n'
'>')
_LOGGER = logging.getLogger(__name__)
class AioRpcError(grpc.RpcError):
"""An implementation of RpcError to be used by the asynchronous API.
@ -168,8 +171,10 @@ class Call:
self._response_deserializer = response_deserializer
def __del__(self) -> None:
if not self._cython_call.done():
self._cancel(_GC_CANCELLATION_DETAILS)
# The '_cython_call' object might be destructed before Call object
if hasattr(self, '_cython_call'):
if not self._cython_call.done():
self._cancel(_GC_CANCELLATION_DETAILS)
def cancelled(self) -> bool:
return self._cython_call.cancelled()
@ -345,9 +350,16 @@ class _StreamRequestMixin(Call):
async def _consume_request_iterator(
self, request_async_iterator: AsyncIterable[RequestType]) -> None:
async for request in request_async_iterator:
await self.write(request)
await self.done_writing()
try:
async for request in request_async_iterator:
await self.write(request)
await self.done_writing()
except AioRpcError as rpc_error:
# Rpc status should be exposed through other API. Exceptions raised
# within this Task won't be retrieved by another coroutine. It's
# better to suppress the error than spamming users' screen.
_LOGGER.debug('Exception while consuming the request_iterator: %s',
rpc_error)
async def write(self, request: RequestType) -> None:
if self.done():
@ -356,6 +368,8 @@ class _StreamRequestMixin(Call):
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set():
await self._metadata_sent.wait()
if self.done():
await self._raise_for_status()
serialized_request = _common.serialize(request,
self._request_serializer)
@ -394,12 +408,13 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
request_serializer, response_deserializer, loop)
super().__init__(
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
self._request = request
self._init_unary_response_mixin(self._invoke())
@ -436,12 +451,13 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
request_serializer, response_deserializer, loop)
super().__init__(
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
self._request = request
self._send_unary_request_task = loop.create_task(
self._send_unary_request())
@ -471,12 +487,13 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
request_serializer, response_deserializer, loop)
super().__init__(
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
self._init_stream_request_mixin(request_async_iterator)
self._init_unary_response_mixin(self._conduct_rpc())
@ -509,12 +526,13 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
request_serializer, response_deserializer, loop)
super().__init__(
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
self._initializer = self._loop.create_task(self._prepare_rpc())
self._init_stream_request_mixin(request_async_iterator)
self._init_stream_response_mixin(self._initializer)

@ -101,9 +101,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
@ -112,16 +109,16 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
if not self._interceptors:
return UnaryUnaryCall(request, _timeout_to_deadline(timeout),
metadata, credentials, self._channel,
self._method, self._request_serializer,
metadata, credentials, wait_for_ready,
self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
else:
return InterceptedUnaryUnaryCall(self._interceptors, request,
timeout, metadata, credentials,
self._channel, self._method,
self._request_serializer,
self._response_deserializer,
self._loop)
return InterceptedUnaryUnaryCall(
self._interceptors, request, timeout, metadata, credentials,
wait_for_ready, self._channel, self._method,
self._request_serializer, self._response_deserializer,
self._loop)
class UnaryStreamMultiCallable(_BaseMultiCallable):
@ -154,10 +151,6 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
Returns:
A Call object instance which is an awaitable object.
"""
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
@ -166,7 +159,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
metadata = _IMMUTABLE_EMPTY_TUPLE
return UnaryStreamCall(request, deadline, metadata, credentials,
self._channel, self._method,
wait_for_ready, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
@ -205,10 +198,6 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
@ -217,8 +206,8 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamUnaryCall(request_async_iterator, deadline, metadata,
credentials, self._channel, self._method,
self._request_serializer,
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
@ -256,10 +245,6 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
@ -268,8 +253,8 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamStreamCall(request_async_iterator, deadline, metadata,
credentials, self._channel, self._method,
self._request_serializer,
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)

@ -33,13 +33,14 @@ _LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
class ClientCallDetails(
collections.namedtuple(
'ClientCallDetails',
('method', 'timeout', 'metadata', 'credentials')),
('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')),
grpc.ClientCallDetails):
method: Text
timeout: Optional[float]
metadata: Optional[MetadataType]
credentials: Optional[grpc.CallCredentials]
wait_for_ready: Optional[bool]
class UnaryUnaryClientInterceptor(metaclass=ABCMeta):
@ -108,28 +109,29 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
request: RequestType, timeout: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._channel = channel
self._loop = loop
self._interceptors_task = asyncio.ensure_future(self._invoke(
interceptors, method, timeout, metadata, credentials, request,
request_serializer, response_deserializer),
interceptors, method, timeout, metadata, credentials,
wait_for_ready, request, request_serializer, response_deserializer),
loop=loop)
def __del__(self):
self.cancel()
# pylint: disable=too-many-arguments
async def _invoke(
self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
credentials: Optional[grpc.CallCredentials], request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> UnaryUnaryCall:
async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction
) -> UnaryUnaryCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
@ -154,12 +156,13 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
return UnaryUnaryCall(
request, _timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials, self._channel,
client_call_details.credentials,
client_call_details.wait_for_ready, self._channel,
client_call_details.method, request_serializer,
response_deserializer, self._loop)
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials)
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
request)

@ -15,5 +15,6 @@
"unit.interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.metadata_test.TestMetadata",
"unit.server_test.TestServer"
"unit.server_test.TestServer",
"unit.wait_for_ready_test.TestWaitForReady"
]

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import grpc
from grpc.experimental import aio
from grpc.experimental.aio._typing import MetadataType, MetadatumType
@ -22,3 +24,11 @@ def seen_metadata(expected: MetadataType, actual: MetadataType):
def seen_metadatum(expected: MetadatumType, actual: MetadataType):
metadata_dict = dict(actual)
return metadata_dict.get(expected[0]) == expected[1]
async def block_until_certain_state(channel: aio.Channel,
expected_state: grpc.ChannelConnectivity):
state = channel.get_state()
while state != expected_state:
await channel.wait_for_state_change(state)
state = channel.get_state()

@ -87,7 +87,7 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
response_parameters.size))
async def start_test_server(secure=False):
async def start_test_server(port=0, secure=False):
server = aio.server(options=(('grpc.so_reuseport', 0),))
servicer = _TestServiceServicer()
test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)
@ -109,10 +109,11 @@ async def start_test_server(secure=False):
if secure:
server_credentials = grpc.local_server_credentials(
grpc.LocalConnectionType.LOCAL_TCP)
port = server.add_secure_port('[::]:0', server_credentials)
port = server.add_secure_port(f'[::]:{port}', server_credentials)
else:
port = server.add_insecure_port('[::]:0')
port = server.add_insecure_port(f'[::]:{port}')
await server.start()
# NOTE(lidizheng) returning the server to prevent it from deallocation
return 'localhost:%d' % port, server

@ -80,17 +80,16 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.UnaryCall(messages_pb2.SimpleRequest(), timeout=0.1)
call = stub.UnaryCall(messages_pb2.SimpleRequest())
with self.assertRaises(grpc.RpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await
call.code())
self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
# Exception is cached at call object level, reentrance
# returns again the same exception

@ -23,18 +23,12 @@ import grpc
from grpc.experimental import aio
from tests.unit.framework.common import test_constants
from tests_aio.unit import _common
from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
async def _block_until_certain_state(channel, expected_state):
state = channel.get_state()
while state != expected_state:
await channel.wait_for_state_change(state)
state = channel.get_state()
class TestConnectivityState(AioTestBase):
async def setUp(self):
@ -52,7 +46,7 @@ class TestConnectivityState(AioTestBase):
# Should not time out
await asyncio.wait_for(
_block_until_certain_state(
_common.block_until_certain_state(
channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE),
test_constants.SHORT_TIMEOUT)
@ -63,8 +57,8 @@ class TestConnectivityState(AioTestBase):
# Should not time out
await asyncio.wait_for(
_block_until_certain_state(channel,
grpc.ChannelConnectivity.READY),
_common.block_until_certain_state(
channel, grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT)
async def test_timeout(self):
@ -75,8 +69,8 @@ class TestConnectivityState(AioTestBase):
# If timed out, the function should return None.
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(
_block_until_certain_state(channel,
grpc.ChannelConnectivity.READY),
_common.block_until_certain_state(
channel, grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT)
async def test_shutdown(self):

@ -13,20 +13,6 @@
# limitations under the License.
"""Testing the done callbacks mechanism."""
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import unittest

@ -132,7 +132,8 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
method=client_call_details.method,
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready)
return await continuation(new_client_call_details, request)
interceptor = TimeoutInterceptor()
@ -173,7 +174,8 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
method=client_call_details.method,
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready)
try:
call = await continuation(new_client_call_details, request)
@ -187,7 +189,8 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
method=client_call_details.method,
timeout=None,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready)
call = await continuation(new_client_call_details, request)
self.calls.append(call)
@ -552,6 +555,7 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
metadata=client_call_details.metadata +
_INITIAL_METADATA_TO_INJECT,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready,
)
return await continuation(new_details, request)

@ -0,0 +1,146 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing the done callbacks mechanism."""
import asyncio
import logging
import unittest
import time
import gc
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import get_socket
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit import _common
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
async def _perform_unary_unary(stub, wait_for_ready):
await stub.UnaryCall(messages_pb2.SimpleRequest(),
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready)
async def _perform_unary_stream(stub, wait_for_ready):
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
call = stub.StreamingOutputCall(request,
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready)
for _ in range(_NUM_STREAM_RESPONSES):
await call.read()
assert await call.code() == grpc.StatusCode.OK
async def _perform_stream_unary(stub, wait_for_ready):
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
async def gen():
for _ in range(_NUM_STREAM_RESPONSES):
yield request
await stub.StreamingInputCall(gen(),
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready)
async def _perform_stream_stream(stub, wait_for_ready):
call = stub.FullDuplexCall(timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready)
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
response = await call.read()
assert _RESPONSE_PAYLOAD_SIZE == len(response.payload.body)
await call.done_writing()
assert await call.code() == grpc.StatusCode.OK
_RPC_ACTIONS = (
_perform_unary_unary,
_perform_unary_stream,
_perform_stream_unary,
_perform_stream_stream,
)
class TestWaitForReady(AioTestBase):
async def setUp(self):
address, self._port, self._socket = get_socket(listen=False)
self._channel = aio.insecure_channel(f"{address}:{self._port}")
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
self._socket.close()
async def tearDown(self):
await self._channel.close()
async def _connection_fails_fast(self, wait_for_ready):
for action in _RPC_ACTIONS:
with self.subTest(name=action):
with self.assertRaises(aio.AioRpcError) as exception_context:
await action(self._stub, wait_for_ready)
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code())
async def test_call_wait_for_ready_default(self):
"""RPC should fail immediately after connection failed."""
await self._connection_fails_fast(None)
async def test_call_wait_for_ready_disabled(self):
"""RPC should fail immediately after connection failed."""
await self._connection_fails_fast(False)
async def test_call_wait_for_ready_enabled(self):
"""RPC will wait until the connection is ready."""
for action in _RPC_ACTIONS:
with self.subTest(name=action.__name__):
# Starts the RPC
action_task = self.loop.create_task(action(self._stub, True))
# Wait for TRANSIENT_FAILURE, and RPC is not aborting
await _common.block_until_certain_state(
self._channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE)
try:
# Start the server
_, server = await start_test_server(port=self._port)
# The RPC should recover itself
await action_task
finally:
if server is not None:
await server.stop(None)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save