[Aio] Close ongoing calls when the channel is closed

When the channel is closed, either by calling explicitly the `close()`
method or by leaving an asyncrhonous channel context all ongoing RPCs will be
cancelled.
pull/21819/head
Pau Freixes 5 years ago
parent 214fd8822b
commit c2b3e00068
  1. 9
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
  3. 9
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  4. 139
      src/python/grpcio/grpc/experimental/aio/_channel.py
  5. 37
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  6. 1
      src/python/grpcio_tests/tests_aio/tests.json
  7. 100
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  8. 94
      src/python/grpcio_tests/tests_aio/unit/interceptor_test.py

@ -119,14 +119,14 @@ cdef class _AioCall(GrpcCallWrapper):
cdef void _set_status(self, AioRpcStatus status) except *:
cdef list waiters
self._status = status
if self._initial_metadata is None:
self._set_initial_metadata(_IMMUTABLE_EMPTY_METADATA)
self._status = status
waiters = self._waiters_status
# No more waiters should be expected since status
# has been set.
waiters = self._waiters_status
self._waiters_status = None
for waiter in waiters:
@ -141,10 +141,9 @@ cdef class _AioCall(GrpcCallWrapper):
self._initial_metadata = initial_metadata
waiters = self._waiters_initial_metadata
# No more waiters should be expected since initial metadata
# has been set.
waiters = self._waiters_initial_metadata
self._waiters_initial_metadata = None
for waiter in waiters:

@ -15,6 +15,7 @@
cdef enum AioChannelStatus:
AIO_CHANNEL_STATUS_UNKNOWN
AIO_CHANNEL_STATUS_READY
AIO_CHANNEL_STATUS_CLOSING
AIO_CHANNEL_STATUS_DESTROYED
cdef class AioChannel:

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class _WatchConnectivityFailed(Exception):
@ -69,9 +70,10 @@ cdef class AioChannel:
Keeps mirroring the behavior from Core, so we can easily switch to
other design of API if necessary.
"""
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
if self._status in (AIO_CHANNEL_STATUS_DESTROYED, AIO_CHANNEL_STATUS_CLOSING):
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
cdef object future = self.loop.create_future()
@ -92,6 +94,9 @@ cdef class AioChannel:
else:
return True
def closing(self):
self._status = AIO_CHANNEL_STATUS_CLOSING
def close(self):
self._status = AIO_CHANNEL_STATUS_DESTROYED
grpc_channel_destroy(self.channel)
@ -105,7 +110,7 @@ cdef class AioChannel:
Returns:
The _AioCall object.
"""
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
if self._status in (AIO_CHANNEL_STATUS_CLOSING, AIO_CHANNEL_STATUS_DESTROYED):
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')

@ -15,6 +15,7 @@
import asyncio
from typing import Any, AsyncIterable, Optional, Sequence, Text
import logging
import grpc
from grpc import _common
from grpc._cython import cygrpc
@ -28,8 +29,37 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction)
from ._utils import _timeout_to_deadline
_TIMEOUT_WAIT_FOR_CLOSE_ONGOING_CALLS_SEC = 0.1
_IMMUTABLE_EMPTY_TUPLE = tuple()
_LOGGER = logging.getLogger(__name__)
class _OngoingCalls:
"""Internal class used for have visibility of the ongoing calls."""
_calls: Sequence[_base_call.RpcContext]
def __init__(self):
self._calls = []
def _remove_call(self, call: _base_call.RpcContext):
self._calls.remove(call)
@property
def calls(self) -> Sequence[_base_call.RpcContext]:
"""Returns a shallow copy of the ongoing calls sequence."""
return self._calls[:]
def size(self) -> int:
"""Returns the number of ongoing calls."""
return len(self._calls)
def trace_call(self, call: _base_call.RpcContext):
"""Adds and manages a new ongoing call."""
self._calls.append(call)
call.add_done_callback(self._remove_call)
class _BaseMultiCallable:
"""Base class of all multi callable objects.
@ -38,6 +68,7 @@ class _BaseMultiCallable:
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_ongoing_calls: _OngoingCalls
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
@ -49,9 +80,11 @@ class _BaseMultiCallable:
_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_loop: asyncio.AbstractEventLoop
# pylint: disable=too-many-arguments
def __init__(
self,
channel: cygrpc.AioChannel,
ongoing_calls: _OngoingCalls,
method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
@ -60,6 +93,7 @@ class _BaseMultiCallable:
) -> None:
self._loop = loop
self._channel = channel
self._ongoing_calls = ongoing_calls
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
@ -111,18 +145,21 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
metadata = _IMMUTABLE_EMPTY_TUPLE
if not self._interceptors:
return UnaryUnaryCall(request, _timeout_to_deadline(timeout),
call = UnaryUnaryCall(request, _timeout_to_deadline(timeout),
metadata, credentials, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
else:
return InterceptedUnaryUnaryCall(self._interceptors, request,
call = InterceptedUnaryUnaryCall(self._interceptors, request,
timeout, metadata, credentials,
self._channel, self._method,
self._request_serializer,
self._response_deserializer,
self._loop)
self._ongoing_calls.trace_call(call)
return call
class UnaryStreamMultiCallable(_BaseMultiCallable):
"""Affords invoking a unary-stream RPC from client-side in an asynchronous way."""
@ -165,10 +202,12 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
return UnaryStreamCall(request, deadline, metadata, credentials,
call = UnaryStreamCall(request, deadline, metadata, credentials,
self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
class StreamUnaryMultiCallable(_BaseMultiCallable):
@ -216,10 +255,12 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamUnaryCall(request_async_iterator, deadline, metadata,
call = StreamUnaryCall(request_async_iterator, deadline, metadata,
credentials, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
class StreamStreamMultiCallable(_BaseMultiCallable):
@ -267,10 +308,12 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamStreamCall(request_async_iterator, deadline, metadata,
call = StreamStreamCall(request_async_iterator, deadline, metadata,
credentials, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
class Channel:
@ -281,6 +324,7 @@ class Channel:
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_ongoing_calls: _OngoingCalls
def __init__(self, target: Text, options: Optional[ChannelArgumentType],
credentials: Optional[grpc.ChannelCredentials],
@ -322,6 +366,53 @@ class Channel:
self._loop = asyncio.get_event_loop()
self._channel = cygrpc.AioChannel(_common.encode(target), options,
credentials, self._loop)
self._ongoing_calls = _OngoingCalls()
async def __aenter__(self):
"""Starts an asynchronous context manager.
Returns:
Channel the channel that was instantiated.
"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Finishes the asynchronous context manager by closing gracefully the channel."""
await self._close()
async def _wait_for_close_ongoing_calls(self):
sleep_iterations_sec = 0.001
while self._ongoing_calls.size() > 0:
await asyncio.sleep(sleep_iterations_sec)
async def _close(self):
# No new calls will be accepted by the Cython channel.
self._channel.closing()
calls = self._ongoing_calls.calls
for call in calls:
call.cancel()
try:
await asyncio.wait_for(self._wait_for_close_ongoing_calls(),
_TIMEOUT_WAIT_FOR_CLOSE_ONGOING_CALLS_SEC,
loop=self._loop)
except asyncio.TimeoutError:
_LOGGER.warning("Closing channel %s, closing RPCs timed out",
str(self))
self._channel.close()
async def close(self):
"""Closes this Channel and releases all resources held by it.
Closing the Channel will proactively terminate all RPCs active with the
Channel and it is not valid to invoke new RPCs with the Channel.
This method is idempotent.
"""
await self._close()
def get_state(self,
try_to_connect: bool = False) -> grpc.ChannelConnectivity:
@ -387,7 +478,8 @@ class Channel:
Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method.
"""
return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
return UnaryUnaryMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
request_serializer,
response_deserializer,
self._unary_unary_interceptors,
@ -399,7 +491,8 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, _common.encode(method),
return UnaryStreamMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
request_serializer,
response_deserializer, None, self._loop)
@ -409,7 +502,8 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(self._channel, _common.encode(method),
return StreamUnaryMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
request_serializer,
response_deserializer, None, self._loop)
@ -419,33 +513,8 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(self._channel, _common.encode(method),
return StreamStreamMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
request_serializer,
response_deserializer, None,
self._loop)
async def _close(self):
# TODO: Send cancellation status
self._channel.close()
async def __aenter__(self):
"""Starts an asynchronous context manager.
Returns:
Channel the channel that was instantiated.
"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Finishes the asynchronous context manager by closing gracefully the channel."""
await self._close()
async def close(self):
"""Closes this Channel and releases all resources held by it.
Closing the Channel will proactively terminate all RPCs active with the
Channel and it is not valid to invoke new RPCs with the Channel.
This method is idempotent.
"""
await self._close()

@ -25,7 +25,7 @@ from . import _base_call
from ._call import UnaryUnaryCall, AioRpcError
from ._utils import _timeout_to_deadline
from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
MetadataType, ResponseType)
MetadataType, ResponseType, DoneCallbackType)
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
@ -102,6 +102,7 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
_intercepted_call: Optional[_base_call.UnaryUnaryCall]
_intercepted_call_created: asyncio.Event
_interceptors_task: asyncio.Task
_pending_add_done_callbacks: Sequence[DoneCallbackType]
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
@ -118,6 +119,9 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
interceptors, method, timeout, metadata, credentials, request,
request_serializer, response_deserializer),
loop=loop)
self._pending_add_done_callbacks = []
self._interceptors_task.add_done_callback(
self._fire_pending_add_done_callbacks)
def __del__(self):
self.cancel()
@ -163,6 +167,17 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
return await _run_interceptor(iter(interceptors), client_call_details,
request)
def _fire_pending_add_done_callbacks(self,
unused_task: asyncio.Task) -> None:
for callback in self._pending_add_done_callbacks:
callback(self)
self._pending_add_done_callbacks = []
def _wrap_add_done_callback(self, callback: DoneCallbackType,
unused_task: asyncio.Task) -> None:
callback(self)
def cancel(self) -> bool:
if self._interceptors_task.done():
return False
@ -186,15 +201,21 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
if not self._interceptors_task.done():
return False
try:
call = self._interceptors_task.result()
except (AioRpcError, asyncio.CancelledError):
return True
call = self._interceptors_task.result()
return call.done()
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def add_done_callback(self, callback: DoneCallbackType) -> None:
if not self._interceptors_task.done():
self._pending_add_done_callbacks.append(callback)
return
call = self._interceptors_task.result()
if call.done():
callback(self)
else:
callback = functools.partial(self._wrap_add_done_callback, callback)
call.add_done_callback(self._wrap_add_done_callback)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()

@ -8,6 +8,7 @@
"unit.call_test.TestUnaryUnaryCall",
"unit.channel_argument_test.TestChannelArgument",
"unit.channel_test.TestChannel",
"unit.channel_test.Test_OngoingCalls",
"unit.connectivity_test.TestConnectivityState",
"unit.done_callback_test.TestDoneCallback",
"unit.init_test.TestInsecureChannel",

@ -20,6 +20,8 @@ import unittest
import grpc
from grpc.experimental import aio
from grpc.experimental.aio import _base_call
from grpc.experimental.aio._channel import _OngoingCalls
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants
@ -42,6 +44,43 @@ _REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
class Test_OngoingCalls(unittest.TestCase):
def test_trace_call(self):
class FakeCall(_base_call.RpcContext):
def __init__(self):
self.callback = None
def add_done_callback(self, callback):
self.callback = callback
def cancel(self):
raise NotImplementedError
def cancelled(self):
raise NotImplementedError
def done(self):
raise NotImplementedError
def time_remaining(self):
raise NotImplementedError
ongoing_calls = _OngoingCalls()
self.assertEqual(ongoing_calls.size(), 0)
call = FakeCall()
ongoing_calls.trace_call(call)
self.assertEqual(ongoing_calls.size(), 1)
self.assertEqual(ongoing_calls.calls, [call])
call.callback(call)
self.assertEqual(ongoing_calls.size(), 0)
self.assertEqual(ongoing_calls.calls, [])
class TestChannel(AioTestBase):
async def setUp(self):
@ -225,7 +264,66 @@ class TestChannel(AioTestBase):
self.assertEqual(grpc.StatusCode.OK, await call.code())
await channel.close()
async def test_close_unary_unary(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close()
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_unary_stream(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
request = messages_pb2.StreamingOutputCallRequest()
calls = [stub.StreamingOutputCall(request) for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close()
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_stream_stream(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
calls = [stub.FullDuplexCall() for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close()
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_async_context(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
calls = [
stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)
]
self.assertEqual(channel._ongoing_calls.size(), 2)
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)
unittest.main(verbosity=2)

@ -573,6 +573,100 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_add_done_callback_before_finishes(self):
called = False
interceptor_can_continue = asyncio.Event()
def callback(call):
nonlocal called
called = True
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
await interceptor_can_continue.wait()
call = await continuation(client_call_details, request)
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
call.add_done_callback(callback)
interceptor_can_continue.set()
await call
self.assertTrue(called)
async def test_add_done_callback_after_finishes(self):
called = False
def callback(call):
nonlocal called
called = True
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
await call
call.add_done_callback(callback)
self.assertTrue(called)
async def test_add_done_callback_after_finishes_before_await(self):
called = False
def callback(call):
nonlocal called
called = True
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
call.add_done_callback(callback)
await call
self.assertTrue(called)
if __name__ == '__main__':
logging.basicConfig()

Loading…
Cancel
Save