Support wait-for-ready mechanism

* Fixing a segfault & a deadlock along the way
* Patching another loophole in the error path
pull/21803/head
Lidi Zheng 5 years ago
parent b94490bc74
commit 72d6642226
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 47
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  4. 5
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  5. 7
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi
  6. 24
      src/python/grpcio/grpc/experimental/aio/_call.py
  7. 24
      src/python/grpcio/grpc/experimental/aio/_channel.py
  8. 3
      src/python/grpcio_tests/tests_aio/tests.json
  9. 10
      src/python/grpcio_tests/tests_aio/unit/_common.py
  10. 7
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  11. 13
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py
  12. 14
      src/python/grpcio_tests/tests_aio/unit/done_callback_test.py
  13. 136
      src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py

@ -40,6 +40,8 @@ cdef class _AioCall(GrpcCallWrapper):
list _waiters_status
list _waiters_initial_metadata
int _send_initial_metadata_flags
cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except *
cdef void _set_status(self, AioRpcStatus status) except *
cdef void _set_initial_metadata(self, tuple initial_metadata) except *

@ -30,10 +30,22 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'>')
cdef int _get_send_initial_metadata_flags(object wait_for_ready) except *:
    # Converts the tri-state `wait_for_ready` argument (None / falsy / truthy)
    # into the bit flags passed along with the send-initial-metadata operation.
    cdef int flags = 0
    # None keeps both bits cleared, which means using the default value in Core.
    if wait_for_ready is not None:
        # The caller made an explicit choice; record that fact first.
        flags = InitialMetadataFlags.wait_for_ready_explicitly_set
        if wait_for_ready:
            flags = flags | InitialMetadataFlags.wait_for_ready
    # Only the bits covered by the used mask are handed back.
    return flags & InitialMetadataFlags.used_mask
cdef class _AioCall(GrpcCallWrapper):
def __cinit__(self, AioChannel channel, object deadline,
bytes method, CallCredentials call_credentials):
bytes method, CallCredentials call_credentials, object wait_for_ready):
self.call = NULL
self._channel = channel
self._loop = channel.loop
@ -45,6 +57,7 @@ cdef class _AioCall(GrpcCallWrapper):
self._done_callbacks = []
self._is_locally_cancelled = False
self._deadline = deadline
self._send_initial_metadata_flags = _get_send_initial_metadata_flags(wait_for_ready)
self._create_grpc_call(deadline, method, call_credentials)
def __dealloc__(self):
@ -279,7 +292,7 @@ cdef class _AioCall(GrpcCallWrapper):
cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation(
outbound_initial_metadata,
GRPC_INITIAL_METADATA_USED_MASK)
self._send_initial_metadata_flags)
cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS)
cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS)
cdef ReceiveInitialMetadataOperation receive_initial_metadata_op = ReceiveInitialMetadataOperation(_EMPTY_FLAGS)
@ -366,12 +379,12 @@ cdef class _AioCall(GrpcCallWrapper):
"""Implementation of the start of a unary-stream call."""
# Peer may prematurely end this RPC at any point. We need a coroutine
# that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received())
status_task = self._loop.create_task(self._handle_status_once_received())
cdef tuple outbound_ops
cdef Operation initial_metadata_op = SendInitialMetadataOperation(
outbound_initial_metadata,
GRPC_INITIAL_METADATA_USED_MASK)
self._send_initial_metadata_flags)
cdef Operation send_message_op = SendMessageOperation(
request,
_EMPTY_FLAGS)
@ -384,6 +397,7 @@ cdef class _AioCall(GrpcCallWrapper):
send_close_op,
)
try:
# Sends out the request message.
await execute_batch(self,
outbound_ops,
@ -394,6 +408,10 @@ cdef class _AioCall(GrpcCallWrapper):
await _receive_initial_metadata(self,
self._loop),
)
except ExecuteBatchError as batch_error:
# Core should explain why this batch failed
await status_task
assert self._status.code() != StatusCode.ok
async def stream_unary(self,
tuple outbound_initial_metadata,
@ -404,9 +422,11 @@ cdef class _AioCall(GrpcCallWrapper):
propagate the final status exception, then we have to raise it.
Otherwise, it would end normally and raise `StopAsyncIteration()`.
"""
try:
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
outbound_initial_metadata,
self._send_initial_metadata_flags,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
@ -415,6 +435,14 @@ cdef class _AioCall(GrpcCallWrapper):
self._set_initial_metadata(
await _receive_initial_metadata(self, self._loop)
)
except ExecuteBatchError:
# Core should explain why this batch failed
await self._handle_status_once_received()
assert self._status.code() != StatusCode.ok
# Allow upper layer to proceed only if the status is set
metadata_sent_observer()
return None
cdef tuple inbound_ops
cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS)
@ -452,11 +480,13 @@ cdef class _AioCall(GrpcCallWrapper):
"""
# Peer may prematurely end this RPC at any point. We need a coroutine
# that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received())
status_task = self._loop.create_task(self._handle_status_once_received())
try:
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
outbound_initial_metadata,
self._send_initial_metadata_flags,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
@ -465,3 +495,10 @@ cdef class _AioCall(GrpcCallWrapper):
self._set_initial_metadata(
await _receive_initial_metadata(self, self._loop)
)
except ExecuteBatchError as batch_error:
# Core should explain why this batch failed
await status_task
assert self._status.code() != StatusCode.ok
# Allow upper layer to proceed only if the status is set
metadata_sent_observer()

@ -164,10 +164,11 @@ async def _send_message(GrpcCallWrapper grpc_call_wrapper,
async def _send_initial_metadata(GrpcCallWrapper grpc_call_wrapper,
tuple metadata,
int flags,
object loop):
cdef SendInitialMetadataOperation op = SendInitialMetadataOperation(
metadata,
_EMPTY_FLAG)
flags)
cdef tuple ops = (op,)
await execute_batch(grpc_call_wrapper, ops, loop)

@ -99,7 +99,8 @@ cdef class AioChannel:
def call(self,
bytes method,
object deadline,
object python_call_credentials):
object python_call_credentials,
object wait_for_ready):
"""Assembles a Cython Call object.
Returns:
@ -115,4 +116,4 @@ cdef class AioChannel:
else:
cython_call_credentials = None
return _AioCall(self, deadline, method, cython_call_credentials)
return _AioCall(self, deadline, method, cython_call_credentials, wait_for_ready)

@ -87,7 +87,7 @@ cdef class _AsyncioSocket:
except Exception as e:
error = True
error_msg = "%s: %s" % (type(e), str(e))
_LOGGER.exception(e)
_LOGGER.debug(e)
finally:
self._task_read = None
@ -167,6 +167,11 @@ cdef class _AsyncioSocket:
self._py_socket.close()
def _new_connection_callback(self, object reader, object writer):
# Close the connection if server is not started yet.
if self._grpc_accept_cb == NULL:
writer.close()
return
client_socket = _AsyncioSocket.create(
self._grpc_client_socket,
reader,

@ -15,6 +15,7 @@
import asyncio
from functools import partial
import logging
from typing import AsyncIterable, Awaitable, Dict, Optional
import grpc
@ -43,6 +44,8 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tdebug_error_string = "{}"\n'
'>')
_LOGGER = logging.getLogger(__name__)
class AioRpcError(grpc.RpcError):
"""An implementation of RpcError to be used by the asynchronous API.
@ -168,6 +171,7 @@ class Call:
self._response_deserializer = response_deserializer
def __del__(self) -> None:
if hasattr(self, '_cython_call'):
if not self._cython_call.done():
self._cancel(_GC_CANCELLATION_DETAILS)
@ -345,9 +349,15 @@ class _StreamRequestMixin(Call):
async def _consume_request_iterator(
self, request_async_iterator: AsyncIterable[RequestType]) -> None:
try:
async for request in request_async_iterator:
await self.write(request)
await self.done_writing()
except AioRpcError as rpc_error:
# Rpc status should be exposed through other API. Exceptions raised
# within this Task won't be retrieved by another coroutine. It's
# better to suppress the error than spamming users' screen.
_LOGGER.debug('Exception while consuming of the request_iterator: %s', rpc_error)
async def write(self, request: RequestType) -> None:
if self.done():
@ -356,6 +366,8 @@ class _StreamRequestMixin(Call):
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set():
await self._metadata_sent.wait()
if self.done():
await self._raise_for_status()
serialized_request = _common.serialize(request,
self._request_serializer)
@ -394,11 +406,12 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata,
request_serializer, response_deserializer, loop)
self._request = request
self._init_unary_response_mixin(self._invoke())
@ -436,11 +449,12 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata,
request_serializer, response_deserializer, loop)
self._request = request
self._send_unary_request_task = loop.create_task(
@ -471,11 +485,12 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata,
request_serializer, response_deserializer, loop)
self._init_stream_request_mixin(request_async_iterator)
@ -509,11 +524,12 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
super().__init__(channel.call(method, deadline, credentials, wait_for_ready), metadata,
request_serializer, response_deserializer, loop)
self._initializer = self._loop.create_task(self._prepare_rpc())
self._init_stream_request_mixin(request_async_iterator)

@ -101,9 +101,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
@ -112,12 +109,13 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
if not self._interceptors:
return UnaryUnaryCall(request, _timeout_to_deadline(timeout),
metadata, credentials, self._channel,
metadata, credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
else:
return InterceptedUnaryUnaryCall(self._interceptors, request,
timeout, metadata, credentials,
wait_for_ready,
self._channel, self._method,
self._request_serializer,
self._response_deserializer,
@ -154,10 +152,6 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
Returns:
A Call object instance which is an awaitable object.
"""
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
@ -165,7 +159,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
return UnaryStreamCall(request, deadline, metadata, credentials,
return UnaryStreamCall(request, deadline, metadata, credentials,wait_for_ready,
self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
@ -205,10 +199,6 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
@ -217,7 +207,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamUnaryCall(request_async_iterator, deadline, metadata,
credentials, self._channel, self._method,
credentials, wait_for_ready, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
@ -256,10 +246,6 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
@ -268,7 +254,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamStreamCall(request_async_iterator, deadline, metadata,
credentials, self._channel, self._method,
credentials, wait_for_ready, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)

@ -15,5 +15,6 @@
"unit.interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.metadata_test.TestMetadata",
"unit.server_test.TestServer"
"unit.server_test.TestServer",
"unit.wait_for_ready_test.TestWaitForReady"
]

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import grpc
from grpc.experimental import aio
from grpc.experimental.aio._typing import MetadataType, MetadatumType
@ -22,3 +24,11 @@ def seen_metadata(expected: MetadataType, actual: MetadataType):
def seen_metadatum(expected: MetadatumType, actual: MetadataType):
    """Tell whether the key/value pair `expected` is present in `actual`.

    `actual` is converted with `dict()`, so when a key occurs more than once
    only its last value takes part in the comparison.
    """
    lookup = dict(actual)
    found_value = lookup.get(expected[0])
    return found_value == expected[1]
async def block_until_certain_state(channel: aio.Channel, expected_state: grpc.ChannelConnectivity):
    """Wait until `channel` reports `expected_state`.

    The current state is polled with `channel.get_state()`; while it differs
    from `expected_state` the coroutine suspends on
    `channel.wait_for_state_change()` and then polls again.

    This coroutine has no timeout of its own and keeps waiting for as long as
    the expected state is not observed. Callers that need an upper bound
    should wrap it, e.g. with `asyncio.wait_for`.

    Args:
      channel: The channel whose connectivity state is observed.
      expected_state: The connectivity state to wait for.
    """
    # Imported once at function scope instead of on every loop iteration.
    import logging

    state = channel.get_state()
    while state != expected_state:
        logging.debug('Get %s want %s', state, expected_state)
        await channel.wait_for_state_change(state)
        state = channel.get_state()

@ -87,7 +87,7 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
response_parameters.size))
async def start_test_server(secure=False):
async def start_test_server(port=0, secure=False):
server = aio.server(options=(('grpc.so_reuseport', 0),))
servicer = _TestServiceServicer()
test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)
@ -109,10 +109,11 @@ async def start_test_server(secure=False):
if secure:
server_credentials = grpc.local_server_credentials(
grpc.LocalConnectionType.LOCAL_TCP)
port = server.add_secure_port('[::]:0', server_credentials)
port = server.add_secure_port(f'[::]:{port}', server_credentials)
else:
port = server.add_insecure_port('[::]:0')
port = server.add_insecure_port(f'[::]:{port}')
await server.start()
# NOTE(lidizheng) returning the server to prevent it from deallocation
return 'localhost:%d' % port, server

@ -28,13 +28,6 @@ from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
async def _block_until_certain_state(channel, expected_state):
state = channel.get_state()
while state != expected_state:
await channel.wait_for_state_change(state)
state = channel.get_state()
class TestConnectivityState(AioTestBase):
async def setUp(self):
@ -52,7 +45,7 @@ class TestConnectivityState(AioTestBase):
# Should not time out
await asyncio.wait_for(
_block_until_certain_state(
_common.block_until_certain_state(
channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE),
test_constants.SHORT_TIMEOUT)
@ -63,7 +56,7 @@ class TestConnectivityState(AioTestBase):
# Should not time out
await asyncio.wait_for(
_block_until_certain_state(channel,
_common.block_until_certain_state(channel,
grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT)
@ -75,7 +68,7 @@ class TestConnectivityState(AioTestBase):
# If timed out, the function should return None.
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(
_block_until_certain_state(channel,
_common.block_until_certain_state(channel,
grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT)

@ -13,20 +13,6 @@
# limitations under the License.
"""Testing the done callbacks mechanism."""
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import unittest

@ -0,0 +1,136 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing the wait-for-ready mechanism."""
import asyncio
import logging
import unittest
import time
import gc
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import get_socket
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit import _common
# Number of messages exchanged by the streaming helpers in this module
# (responses requested/read, and requests yielded or written).
_NUM_STREAM_RESPONSES = 5
# Length in bytes of the payload body sent by the stream-unary helper.
_REQUEST_PAYLOAD_SIZE = 7
# Size requested for each response via `ResponseParameters(size=...)`.
_RESPONSE_PAYLOAD_SIZE = 42
async def _perform_unary_unary(stub, wait_for_ready):
    """Issue a single `UnaryCall` with an empty request and await it.

    Args:
      stub: Stub exposing the `UnaryCall` method.
      wait_for_ready: Value forwarded as the `wait_for_ready` call option.
    """
    call = stub.UnaryCall(
        messages_pb2.SimpleRequest(),
        timeout=test_constants.SHORT_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    await call
async def _perform_unary_stream(stub, wait_for_ready):
    """Run a `StreamingOutputCall`, read every response and check the status.

    The request asks for `_NUM_STREAM_RESPONSES` responses of
    `_RESPONSE_PAYLOAD_SIZE` each; the same number of reads is performed and
    the final status code is asserted to be OK.

    Args:
      stub: Stub exposing the `StreamingOutputCall` method.
      wait_for_ready: Value forwarded as the `wait_for_ready` call option.
    """
    streaming_request = messages_pb2.StreamingOutputCallRequest()
    for _ in range(_NUM_STREAM_RESPONSES):
        parameters = messages_pb2.ResponseParameters(
            size=_RESPONSE_PAYLOAD_SIZE)
        streaming_request.response_parameters.append(parameters)

    call = stub.StreamingOutputCall(
        streaming_request,
        timeout=test_constants.SHORT_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )

    for _ in range(_NUM_STREAM_RESPONSES):
        await call.read()

    assert await call.code() == grpc.StatusCode.OK
async def _perform_stream_unary(stub, wait_for_ready):
    """Run a `StreamingInputCall` fed by an async generator and await it.

    The generator yields the same request `_NUM_STREAM_RESPONSES` times; the
    request carries a payload of `_REQUEST_PAYLOAD_SIZE` zero bytes.

    Args:
      stub: Stub exposing the `StreamingInputCall` method.
      wait_for_ready: Value forwarded as the `wait_for_ready` call option.
    """
    request = messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))

    async def request_stream():
        remaining = _NUM_STREAM_RESPONSES
        while remaining > 0:
            yield request
            remaining -= 1

    call = stub.StreamingInputCall(
        request_stream(),
        timeout=test_constants.SHORT_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    await call
async def _perform_stream_stream(stub, wait_for_ready):
    """Run a `FullDuplexCall` as a sequence of write/read round trips.

    Each of the `_NUM_STREAM_RESPONSES` rounds writes one request and reads
    one response, asserting the response body has `_RESPONSE_PAYLOAD_SIZE`
    bytes. Afterwards the request side is closed and the final status code is
    asserted to be OK.

    Args:
      stub: Stub exposing the `FullDuplexCall` method.
      wait_for_ready: Value forwarded as the `wait_for_ready` call option.
    """
    # The call is created before the request is assembled, as in the
    # original ordering.
    call = stub.FullDuplexCall(
        timeout=test_constants.SHORT_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )

    request = messages_pb2.StreamingOutputCallRequest()
    parameters = messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
    request.response_parameters.append(parameters)

    for _ in range(_NUM_STREAM_RESPONSES):
        await call.write(request)
        reply = await call.read()
        assert _RESPONSE_PAYLOAD_SIZE == len(reply.payload.body)

    await call.done_writing()

    assert await call.code() == grpc.StatusCode.OK
# One helper coroutine per RPC arity; the test cases iterate over this tuple
# and call each helper as `action(stub, wait_for_ready)`.
_RPC_ACTIONS = (
_perform_unary_unary,
_perform_unary_stream,
_perform_stream_unary,
_perform_stream_stream,
)
class TestWaitForReady(AioTestBase):
    """Exercises the `wait_for_ready` call option for every RPC arity.

    Each test runs all helpers in `_RPC_ACTIONS` against a channel whose
    target has no server at the time the RPC is started.
    """

    async def setUp(self):
        # NOTE(review): presumably `get_socket(listen=False)` reserves a free
        # port without accepting connections; the socket is closed right after
        # the channel is created, so nothing should be serving on that port
        # when the RPCs start -- confirm against `get_socket`.
        address, self._port, self._socket = get_socket(listen=False)
        self._channel = aio.insecure_channel(f"{address}:{self._port}")
        self._stub = test_pb2_grpc.TestServiceStub(self._channel)
        self._socket.close()

    async def tearDown(self):
        await self._channel.close()

    async def _connection_fails_fast(self, wait_for_ready):
        """Assert that every RPC action fails with UNAVAILABLE.

        Args:
          wait_for_ready: Value forwarded to each action as its
            `wait_for_ready` option.
        """
        for action in _RPC_ACTIONS:
            # Use the function name so sub-test labels are readable and match
            # the labels used by test_call_wait_for_ready_enabled.
            with self.subTest(name=action.__name__):
                with self.assertRaises(aio.AioRpcError) as exception_context:
                    await action(self._stub, wait_for_ready)
                rpc_error = exception_context.exception
                self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code())

    async def test_call_wait_for_ready_default(self):
        await self._connection_fails_fast(None)

    async def test_call_wait_for_ready_disabled(self):
        await self._connection_fails_fast(False)

    async def test_call_wait_for_ready_enabled(self):
        for action in _RPC_ACTIONS:
            with self.subTest(name=action.__name__):
                # Starts the RPC
                action_task = self.loop.create_task(action(self._stub, True))

                # Wait for TRANSIENT_FAILURE, and RPC is not aborting
                await _common.block_until_certain_state(
                    self._channel,
                    grpc.ChannelConnectivity.TRANSIENT_FAILURE)

                # Bound before the try block so the cleanup below never hits
                # an unbound name when starting the server itself fails.
                server = None
                try:
                    # Start the server
                    _, server = await start_test_server(port=self._port)

                    # The RPC should recover itself
                    await action_task
                finally:
                    if server is not None:
                        await server.stop(None)
# When run as a script: enable DEBUG-level logging, then run the unittest
# runner with verbose output.
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save