Support wait-for-ready mechanism

* Fixing a segfault & a deadlock along the way
* Patching another loophole in the error path
pull/21803/head
Lidi Zheng 5 years ago
parent b94490bc74
commit 72d6642226
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 47
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  4. 5
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  5. 7
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi
  6. 24
      src/python/grpcio/grpc/experimental/aio/_call.py
  7. 24
      src/python/grpcio/grpc/experimental/aio/_channel.py
  8. 3
      src/python/grpcio_tests/tests_aio/tests.json
  9. 10
      src/python/grpcio_tests/tests_aio/unit/_common.py
  10. 7
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  11. 13
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py
  12. 14
      src/python/grpcio_tests/tests_aio/unit/done_callback_test.py
  13. 136
      src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py

@ -40,6 +40,8 @@ cdef class _AioCall(GrpcCallWrapper):
list _waiters_status list _waiters_status
list _waiters_initial_metadata list _waiters_initial_metadata
int _send_initial_metadata_flags
cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except * cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except *
cdef void _set_status(self, AioRpcStatus status) except * cdef void _set_status(self, AioRpcStatus status) except *
cdef void _set_initial_metadata(self, tuple initial_metadata) except * cdef void _set_initial_metadata(self, tuple initial_metadata) except *

@ -30,10 +30,22 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'>') '>')
cdef int _get_send_initial_metadata_flags(object wait_for_ready) except *:
cdef int flags = 0
# Wait-for-ready can be None, which means using default value in Core.
if wait_for_ready is not None:
flags |= InitialMetadataFlags.wait_for_ready_explicitly_set
if wait_for_ready:
flags |= InitialMetadataFlags.wait_for_ready
flags &= InitialMetadataFlags.used_mask
return flags
cdef class _AioCall(GrpcCallWrapper): cdef class _AioCall(GrpcCallWrapper):
def __cinit__(self, AioChannel channel, object deadline, def __cinit__(self, AioChannel channel, object deadline,
bytes method, CallCredentials call_credentials): bytes method, CallCredentials call_credentials, object wait_for_ready):
self.call = NULL self.call = NULL
self._channel = channel self._channel = channel
self._loop = channel.loop self._loop = channel.loop
@ -45,6 +57,7 @@ cdef class _AioCall(GrpcCallWrapper):
self._done_callbacks = [] self._done_callbacks = []
self._is_locally_cancelled = False self._is_locally_cancelled = False
self._deadline = deadline self._deadline = deadline
self._send_initial_metadata_flags = _get_send_initial_metadata_flags(wait_for_ready)
self._create_grpc_call(deadline, method, call_credentials) self._create_grpc_call(deadline, method, call_credentials)
def __dealloc__(self): def __dealloc__(self):
@ -279,7 +292,7 @@ cdef class _AioCall(GrpcCallWrapper):
cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation( cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation(
outbound_initial_metadata, outbound_initial_metadata,
GRPC_INITIAL_METADATA_USED_MASK) self._send_initial_metadata_flags)
cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS) cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS)
cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS) cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS)
cdef ReceiveInitialMetadataOperation receive_initial_metadata_op = ReceiveInitialMetadataOperation(_EMPTY_FLAGS) cdef ReceiveInitialMetadataOperation receive_initial_metadata_op = ReceiveInitialMetadataOperation(_EMPTY_FLAGS)
@ -366,12 +379,12 @@ cdef class _AioCall(GrpcCallWrapper):
"""Implementation of the start of a unary-stream call.""" """Implementation of the start of a unary-stream call."""
# Peer may prematurely end this RPC at any point. We need a corutine # Peer may prematurely end this RPC at any point. We need a corutine
# that watches if the server sends the final status. # that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received()) status_task = self._loop.create_task(self._handle_status_once_received())
cdef tuple outbound_ops cdef tuple outbound_ops
cdef Operation initial_metadata_op = SendInitialMetadataOperation( cdef Operation initial_metadata_op = SendInitialMetadataOperation(
outbound_initial_metadata, outbound_initial_metadata,
GRPC_INITIAL_METADATA_USED_MASK) self._send_initial_metadata_flags)
cdef Operation send_message_op = SendMessageOperation( cdef Operation send_message_op = SendMessageOperation(
request, request,
_EMPTY_FLAGS) _EMPTY_FLAGS)
@ -384,6 +397,7 @@ cdef class _AioCall(GrpcCallWrapper):
send_close_op, send_close_op,
) )
try:
# Sends out the request message. # Sends out the request message.
await execute_batch(self, await execute_batch(self,
outbound_ops, outbound_ops,
@ -394,6 +408,10 @@ cdef class _AioCall(GrpcCallWrapper):
await _receive_initial_metadata(self, await _receive_initial_metadata(self,
self._loop), self._loop),
) )
except ExecuteBatchError as batch_error:
# Core should explain why this batch failed
await status_task
assert self._status.code() != StatusCode.ok
async def stream_unary(self, async def stream_unary(self,
tuple outbound_initial_metadata, tuple outbound_initial_metadata,
@ -404,9 +422,11 @@ cdef class _AioCall(GrpcCallWrapper):
propagate the final status exception, then we have to raise it. propagate the final status exception, then we have to raise it.
Othersize, it would end normally and raise `StopAsyncIteration()`. Othersize, it would end normally and raise `StopAsyncIteration()`.
""" """
try:
# Sends out initial_metadata ASAP. # Sends out initial_metadata ASAP.
await _send_initial_metadata(self, await _send_initial_metadata(self,
outbound_initial_metadata, outbound_initial_metadata,
self._send_initial_metadata_flags,
self._loop) self._loop)
# Notify upper level that sending messages are allowed now. # Notify upper level that sending messages are allowed now.
metadata_sent_observer() metadata_sent_observer()
@ -415,6 +435,14 @@ cdef class _AioCall(GrpcCallWrapper):
self._set_initial_metadata( self._set_initial_metadata(
await _receive_initial_metadata(self, self._loop) await _receive_initial_metadata(self, self._loop)
) )
except ExecuteBatchError:
# Core should explain why this batch failed
await self._handle_status_once_received()
assert self._status.code() != StatusCode.ok
# Allow upper layer to proceed only if the status is set
metadata_sent_observer()
return None
cdef tuple inbound_ops cdef tuple inbound_ops
cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS) cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS)
@ -452,11 +480,13 @@ cdef class _AioCall(GrpcCallWrapper):
""" """
# Peer may prematurely end this RPC at any point. We need a corutine # Peer may prematurely end this RPC at any point. We need a corutine
# that watches if the server sends the final status. # that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received()) status_task = self._loop.create_task(self._handle_status_once_received())
try:
# Sends out initial_metadata ASAP. # Sends out initial_metadata ASAP.
await _send_initial_metadata(self, await _send_initial_metadata(self,
outbound_initial_metadata, outbound_initial_metadata,
self._send_initial_metadata_flags,
self._loop) self._loop)
# Notify upper level that sending messages are allowed now. # Notify upper level that sending messages are allowed now.
metadata_sent_observer() metadata_sent_observer()
@ -465,3 +495,10 @@ cdef class _AioCall(GrpcCallWrapper):
self._set_initial_metadata( self._set_initial_metadata(
await _receive_initial_metadata(self, self._loop) await _receive_initial_metadata(self, self._loop)
) )
except ExecuteBatchError as batch_error:
# Core should explain why this batch failed
await status_task
assert self._status.code() != StatusCode.ok
# Allow upper layer to proceed only if the status is set
metadata_sent_observer()

@ -164,10 +164,11 @@ async def _send_message(GrpcCallWrapper grpc_call_wrapper,
async def _send_initial_metadata(GrpcCallWrapper grpc_call_wrapper, async def _send_initial_metadata(GrpcCallWrapper grpc_call_wrapper,
tuple metadata, tuple metadata,
int flags,
object loop): object loop):
cdef SendInitialMetadataOperation op = SendInitialMetadataOperation( cdef SendInitialMetadataOperation op = SendInitialMetadataOperation(
metadata, metadata,
_EMPTY_FLAG) flags)
cdef tuple ops = (op,) cdef tuple ops = (op,)
await execute_batch(grpc_call_wrapper, ops, loop) await execute_batch(grpc_call_wrapper, ops, loop)

@ -99,7 +99,8 @@ cdef class AioChannel:
def call(self, def call(self,
bytes method, bytes method,
object deadline, object deadline,
object python_call_credentials): object python_call_credentials,
object wait_for_ready):
"""Assembles a Cython Call object. """Assembles a Cython Call object.
Returns: Returns:
@ -115,4 +116,4 @@ cdef class AioChannel:
else: else:
cython_call_credentials = None cython_call_credentials = None
return _AioCall(self, deadline, method, cython_call_credentials) return _AioCall(self, deadline, method, cython_call_credentials, wait_for_ready)

@ -87,7 +87,7 @@ cdef class _AsyncioSocket:
except Exception as e: except Exception as e:
error = True error = True
error_msg = "%s: %s" % (type(e), str(e)) error_msg = "%s: %s" % (type(e), str(e))
_LOGGER.exception(e) _LOGGER.debug(e)
finally: finally:
self._task_read = None self._task_read = None
@ -167,6 +167,11 @@ cdef class _AsyncioSocket:
self._py_socket.close() self._py_socket.close()
def _new_connection_callback(self, object reader, object writer): def _new_connection_callback(self, object reader, object writer):
# Close the connection if server is not started yet.
if self._grpc_accept_cb == NULL:
writer.close()
return
client_socket = _AsyncioSocket.create( client_socket = _AsyncioSocket.create(
self._grpc_client_socket, self._grpc_client_socket,
reader, reader,

@ -15,6 +15,7 @@
import asyncio import asyncio
from functools import partial from functools import partial
import logging
from typing import AsyncIterable, Awaitable, Dict, Optional from typing import AsyncIterable, Awaitable, Dict, Optional
import grpc import grpc
@ -43,6 +44,8 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tdebug_error_string = "{}"\n' '\tdebug_error_string = "{}"\n'
'>') '>')
_LOGGER = logging.getLogger(__name__)
class AioRpcError(grpc.RpcError): class AioRpcError(grpc.RpcError):
"""An implementation of RpcError to be used by the asynchronous API. """An implementation of RpcError to be used by the asynchronous API.
@ -168,6 +171,7 @@ class Call:
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
def __del__(self) -> None: def __del__(self) -> None:
if hasattr(self, '_cython_call'):
if not self._cython_call.done(): if not self._cython_call.done():
self._cancel(_GC_CANCELLATION_DETAILS) self._cancel(_GC_CANCELLATION_DETAILS)
@ -345,9 +349,15 @@ class _StreamRequestMixin(Call):
async def _consume_request_iterator( async def _consume_request_iterator(
self, request_async_iterator: AsyncIterable[RequestType]) -> None: self, request_async_iterator: AsyncIterable[RequestType]) -> None:
try:
async for request in request_async_iterator: async for request in request_async_iterator:
await self.write(request) await self.write(request)
await self.done_writing() await self.done_writing()
except AioRpcError as rpc_error:
# Rpc status should be exposed through other API. Exceptions raised
# within this Task won't be retrieved by another coroutine. It's
# better to suppress the error than spamming users' screen.
_LOGGER.debug('Exception while consuming of the request_iterator: %s', rpc_error)
async def write(self, request: RequestType) -> None: async def write(self, request: RequestType) -> None:
if self.done(): if self.done():
@ -356,6 +366,8 @@ class _StreamRequestMixin(Call):
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set(): if not self._metadata_sent.is_set():
await self._metadata_sent.wait() await self._metadata_sent.wait()
if self.done():
await self._raise_for_status()
serialized_request = _common.serialize(request, serialized_request = _common.serialize(request,
self._request_serializer) self._request_serializer)
@ -394,11 +406,12 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
def __init__(self, request: RequestType, deadline: Optional[float], def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType, metadata: MetadataType,
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction, response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None: loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata, super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata,
request_serializer, response_deserializer, loop) request_serializer, response_deserializer, loop)
self._request = request self._request = request
self._init_unary_response_mixin(self._invoke()) self._init_unary_response_mixin(self._invoke())
@ -436,11 +449,12 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
def __init__(self, request: RequestType, deadline: Optional[float], def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType, metadata: MetadataType,
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction, response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None: loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata, super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata,
request_serializer, response_deserializer, loop) request_serializer, response_deserializer, loop)
self._request = request self._request = request
self._send_unary_request_task = loop.create_task( self._send_unary_request_task = loop.create_task(
@ -471,11 +485,12 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
request_async_iterator: Optional[AsyncIterable[RequestType]], request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType, deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction, response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None: loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata, super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata,
request_serializer, response_deserializer, loop) request_serializer, response_deserializer, loop)
self._init_stream_request_mixin(request_async_iterator) self._init_stream_request_mixin(request_async_iterator)
@ -509,11 +524,12 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
request_async_iterator: Optional[AsyncIterable[RequestType]], request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType, deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction, response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None: loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata, super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata,
request_serializer, response_deserializer, loop) request_serializer, response_deserializer, loop)
self._initializer = self._loop.create_task(self._prepare_rpc()) self._initializer = self._loop.create_task(self._prepare_rpc())
self._init_stream_request_mixin(request_async_iterator) self._init_stream_request_mixin(request_async_iterator)

@ -101,9 +101,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details. metadata, status code, and details.
""" """
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") raise NotImplementedError("TODO: compression not implemented yet")
@ -112,12 +109,13 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
if not self._interceptors: if not self._interceptors:
return UnaryUnaryCall(request, _timeout_to_deadline(timeout), return UnaryUnaryCall(request, _timeout_to_deadline(timeout),
metadata, credentials, self._channel, metadata, credentials, wait_for_ready, self._channel,
self._method, self._request_serializer, self._method, self._request_serializer,
self._response_deserializer, self._loop) self._response_deserializer, self._loop)
else: else:
return InterceptedUnaryUnaryCall(self._interceptors, request, return InterceptedUnaryUnaryCall(self._interceptors, request,
timeout, metadata, credentials, timeout, metadata, credentials,
wait_for_ready,
self._channel, self._method, self._channel, self._method,
self._request_serializer, self._request_serializer,
self._response_deserializer, self._response_deserializer,
@ -154,10 +152,6 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
Returns: Returns:
A Call object instance which is an awaitable object. A Call object instance which is an awaitable object.
""" """
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") raise NotImplementedError("TODO: compression not implemented yet")
@ -165,7 +159,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
if metadata is None: if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE metadata = _IMMUTABLE_EMPTY_TUPLE
return UnaryStreamCall(request, deadline, metadata, credentials, return UnaryStreamCall(request, deadline, metadata, credentials,wait_for_ready,
self._channel, self._method, self._channel, self._method,
self._request_serializer, self._request_serializer,
self._response_deserializer, self._loop) self._response_deserializer, self._loop)
@ -205,10 +199,6 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details. metadata, status code, and details.
""" """
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") raise NotImplementedError("TODO: compression not implemented yet")
@ -217,7 +207,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
metadata = _IMMUTABLE_EMPTY_TUPLE metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamUnaryCall(request_async_iterator, deadline, metadata, return StreamUnaryCall(request_async_iterator, deadline, metadata,
credentials, self._channel, self._method, credentials, wait_for_ready, self._channel, self._method,
self._request_serializer, self._request_serializer,
self._response_deserializer, self._loop) self._response_deserializer, self._loop)
@ -256,10 +246,6 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details. metadata, status code, and details.
""" """
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") raise NotImplementedError("TODO: compression not implemented yet")
@ -268,7 +254,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
metadata = _IMMUTABLE_EMPTY_TUPLE metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamStreamCall(request_async_iterator, deadline, metadata, return StreamStreamCall(request_async_iterator, deadline, metadata,
credentials, self._channel, self._method, credentials, wait_for_ready, self._channel, self._method,
self._request_serializer, self._request_serializer,
self._response_deserializer, self._loop) self._response_deserializer, self._loop)

@ -15,5 +15,6 @@
"unit.interceptor_test.TestInterceptedUnaryUnaryCall", "unit.interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.interceptor_test.TestUnaryUnaryClientInterceptor", "unit.interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.metadata_test.TestMetadata", "unit.metadata_test.TestMetadata",
"unit.server_test.TestServer" "unit.server_test.TestServer",
"unit.wait_for_ready.TestWaitForReady"
] ]

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import grpc
from grpc.experimental import aio
from grpc.experimental.aio._typing import MetadataType, MetadatumType from grpc.experimental.aio._typing import MetadataType, MetadatumType
@ -22,3 +24,11 @@ def seen_metadata(expected: MetadataType, actual: MetadataType):
def seen_metadatum(expected: MetadatumType, actual: MetadataType): def seen_metadatum(expected: MetadatumType, actual: MetadataType):
metadata_dict = dict(actual) metadata_dict = dict(actual)
return metadata_dict.get(expected[0]) == expected[1] return metadata_dict.get(expected[0]) == expected[1]
async def block_until_certain_state(channel: aio.Channel, expected_state: grpc.ChannelConnectivity):
state = channel.get_state()
while state != expected_state:
import logging;logging.debug('Get %s want %s', state, expected_state)
await channel.wait_for_state_change(state)
state = channel.get_state()

@ -87,7 +87,7 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
response_parameters.size)) response_parameters.size))
async def start_test_server(secure=False): async def start_test_server(port=0, secure=False):
server = aio.server(options=(('grpc.so_reuseport', 0),)) server = aio.server(options=(('grpc.so_reuseport', 0),))
servicer = _TestServiceServicer() servicer = _TestServiceServicer()
test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server) test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)
@ -109,10 +109,11 @@ async def start_test_server(secure=False):
if secure: if secure:
server_credentials = grpc.local_server_credentials( server_credentials = grpc.local_server_credentials(
grpc.LocalConnectionType.LOCAL_TCP) grpc.LocalConnectionType.LOCAL_TCP)
port = server.add_secure_port('[::]:0', server_credentials) port = server.add_secure_port(f'[::]:{port}', server_credentials)
else: else:
port = server.add_insecure_port('[::]:0') port = server.add_insecure_port(f'[::]:{port}')
await server.start() await server.start()
# NOTE(lidizheng) returning the server to prevent it from deallocation # NOTE(lidizheng) returning the server to prevent it from deallocation
return 'localhost:%d' % port, server return 'localhost:%d' % port, server

@ -28,13 +28,6 @@ from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_server import start_test_server
async def _block_until_certain_state(channel, expected_state):
state = channel.get_state()
while state != expected_state:
await channel.wait_for_state_change(state)
state = channel.get_state()
class TestConnectivityState(AioTestBase): class TestConnectivityState(AioTestBase):
async def setUp(self): async def setUp(self):
@ -52,7 +45,7 @@ class TestConnectivityState(AioTestBase):
# Should not time out # Should not time out
await asyncio.wait_for( await asyncio.wait_for(
_block_until_certain_state( _common.block_until_certain_state(
channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE), channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE),
test_constants.SHORT_TIMEOUT) test_constants.SHORT_TIMEOUT)
@ -63,7 +56,7 @@ class TestConnectivityState(AioTestBase):
# Should not time out # Should not time out
await asyncio.wait_for( await asyncio.wait_for(
_block_until_certain_state(channel, _common.block_until_certain_state(channel,
grpc.ChannelConnectivity.READY), grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT) test_constants.SHORT_TIMEOUT)
@ -75,7 +68,7 @@ class TestConnectivityState(AioTestBase):
# If timed out, the function should return None. # If timed out, the function should return None.
with self.assertRaises(asyncio.TimeoutError): with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for( await asyncio.wait_for(
_block_until_certain_state(channel, _common.block_until_certain_state(channel,
grpc.ChannelConnectivity.READY), grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT) test_constants.SHORT_TIMEOUT)

@ -13,20 +13,6 @@
# limitations under the License. # limitations under the License.
"""Testing the done callbacks mechanism.""" """Testing the done callbacks mechanism."""
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio import asyncio
import logging import logging
import unittest import unittest

@ -0,0 +1,136 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing the done callbacks mechanism."""
import asyncio
import logging
import unittest
import time
import gc
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import get_socket
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit import _common
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
async def _perform_unary_unary(stub, wait_for_ready):
await stub.UnaryCall(messages_pb2.SimpleRequest(), timeout=test_constants.SHORT_TIMEOUT, wait_for_ready=wait_for_ready)
async def _perform_unary_stream(stub, wait_for_ready):
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
call = stub.StreamingOutputCall(request, timeout=test_constants.SHORT_TIMEOUT, wait_for_ready=wait_for_ready)
for _ in range(_NUM_STREAM_RESPONSES):
await call.read()
assert await call.code() == grpc.StatusCode.OK
async def _perform_stream_unary(stub, wait_for_ready):
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
async def gen():
for _ in range(_NUM_STREAM_RESPONSES):
yield request
await stub.StreamingInputCall(gen(), timeout=test_constants.SHORT_TIMEOUT, wait_for_ready=wait_for_ready)
async def _perform_stream_stream(stub, wait_for_ready):
call = stub.FullDuplexCall(timeout=test_constants.SHORT_TIMEOUT, wait_for_ready=wait_for_ready)
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
response = await call.read()
assert _RESPONSE_PAYLOAD_SIZE == len(response.payload.body)
await call.done_writing()
assert await call.code() == grpc.StatusCode.OK
_RPC_ACTIONS = (
_perform_unary_unary,
_perform_unary_stream,
_perform_stream_unary,
_perform_stream_stream,
)
class TestWaitForReady(AioTestBase):
async def setUp(self):
address, self._port, self._socket = get_socket(listen=False)
self._channel = aio.insecure_channel(f"{address}:{self._port}")
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
self._socket.close()
async def tearDown(self):
await self._channel.close()
async def _connection_fails_fast(self, wait_for_ready):
for action in _RPC_ACTIONS:
with self.subTest(name=action):
with self.assertRaises(aio.AioRpcError) as exception_context:
await action(self._stub, wait_for_ready)
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code())
async def test_call_wait_for_ready_default(self):
await self._connection_fails_fast(None)
async def test_call_wait_for_ready_disabled(self):
await self._connection_fails_fast(False)
async def test_call_wait_for_ready_enabled(self):
for action in _RPC_ACTIONS:
with self.subTest(name=action.__name__):
# Starts the RPC
action_task = self.loop.create_task(action(self._stub, True))
# Wait for TRANSIENT_FAILURE, and RPC is not aborting
await _common.block_until_certain_state(
self._channel,
grpc.ChannelConnectivity.TRANSIENT_FAILURE)
try:
# Start the server
_, server = await start_test_server(port=self._port)
# The RPC should recover itself
await action_task
finally:
if server is not None:
await server.stop(None)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save