Make client-side graceful shutdown faster

pull/21988/head
Lidi Zheng 5 years ago
parent 3ff3c0d8da
commit fbd213d04b
  1. 110
      src/python/grpcio/grpc/experimental/aio/_channel.py
  2. 59
      src/python/grpcio_tests/tests_aio/unit/close_channel_test.py

@ -13,15 +13,12 @@
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet
from weakref import WeakSet
import sys
from typing import AbstractSet, Any, AsyncIterable, Optional, Sequence
import logging
import grpc
from grpc import _common
from grpc import _common, _compression, _grpcio_metadata
from grpc._cython import cygrpc
from grpc import _compression
from grpc import _grpcio_metadata
from . import _base_call
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
@ -35,6 +32,15 @@ from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple()
_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)
if sys.version_info[1] < 7:
def _all_tasks() -> Sequence[asyncio.Task]:
return asyncio.Task.all_tasks()
else:
def _all_tasks() -> Sequence[asyncio.Task]:
return asyncio.all_tasks()
def _augment_channel_arguments(base_options: ChannelArgumentType,
compression: Optional[grpc.Compression]):
@ -48,38 +54,6 @@ def _augment_channel_arguments(base_options: ChannelArgumentType,
) + compression_channel_argument + user_agent_channel_argument
_LOGGER = logging.getLogger(__name__)
class _OngoingCalls:
"""Internal class used for have visibility of the ongoing calls."""
_calls: AbstractSet[_base_call.RpcContext]
def __init__(self):
self._calls = WeakSet()
def _remove_call(self, call: _base_call.RpcContext):
try:
self._calls.remove(call)
except KeyError:
pass
@property
def calls(self) -> AbstractSet[_base_call.RpcContext]:
"""Returns the set of ongoing calls."""
return self._calls
def size(self) -> int:
"""Returns the number of ongoing calls."""
return len(self._calls)
def trace_call(self, call: _base_call.RpcContext):
"""Adds and manages a new ongoing call."""
self._calls.add(call)
call.add_done_callback(self._remove_call)
class _BaseMultiCallable:
"""Base class of all multi callable objects.
@ -87,7 +61,6 @@ class _BaseMultiCallable:
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_ongoing_calls: _OngoingCalls
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
@ -103,7 +76,6 @@ class _BaseMultiCallable:
def __init__(
self,
channel: cygrpc.AioChannel,
ongoing_calls: _OngoingCalls,
method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
@ -112,7 +84,6 @@ class _BaseMultiCallable:
) -> None:
self._loop = loop
self._channel = channel
self._ongoing_calls = ongoing_calls
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
@ -170,7 +141,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
self._request_serializer, self._response_deserializer,
self._loop)
self._ongoing_calls.trace_call(call)
return call
@ -213,7 +183,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
wait_for_ready, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
@ -260,7 +230,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
@ -307,7 +277,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
@ -319,7 +289,6 @@ class Channel:
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_ongoing_calls: _OngoingCalls
def __init__(self, target: str, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials],
@ -359,7 +328,6 @@ class Channel:
_common.encode(target),
_augment_channel_arguments(options, compression), credentials,
self._loop)
self._ongoing_calls = _OngoingCalls()
async def __aenter__(self):
"""Starts an asynchronous context manager.
@ -383,22 +351,32 @@ class Channel:
# No new calls will be accepted by the Cython channel.
self._channel.closing()
if grace:
# pylint: disable=unused-variable
_, pending = await asyncio.wait(self._ongoing_calls.calls,
timeout=grace,
loop=self._loop)
if not pending:
return
# A new set is created acting as a shallow copy because
# when cancellation happens the calls are automatically
# removed from the originally set.
calls = WeakSet(data=self._ongoing_calls.calls)
# Iterate through running tasks
tasks = _all_tasks()
calls = []
call_tasks = []
for task in tasks:
stack = task.get_stack(limit=1)
if not stack:
continue
# Locate ones created by `aio.Call`.
frame = stack[0]
if 'self' in frame.f_locals:
if isinstance(frame.f_locals['self'], _base_call.Call):
calls.append(frame.f_locals['self'])
call_tasks.append(task)
# If needed, try to wait for them to finish.
# Call objects are not always awaitables.
if grace and call_tasks:
await asyncio.wait(call_tasks, timeout=grace, loop=self._loop)
# Time to cancel existing calls.
for call in calls:
call.cancel()
# Destroy the channel
self._channel.close()
async def close(self, grace: Optional[float] = None):
@ -487,8 +465,7 @@ class Channel:
Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method.
"""
return UnaryUnaryMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer,
self._unary_unary_interceptors,
@ -500,8 +477,7 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
return UnaryStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None, self._loop)
@ -511,8 +487,7 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
return StreamUnaryMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None, self._loop)
@ -522,8 +497,7 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
return StreamStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None,
self._loop)

@ -21,7 +21,6 @@ from weakref import WeakSet
import grpc
from grpc.experimental import aio
from grpc.experimental.aio import _base_call
from grpc.experimental.aio._channel import _OngoingCalls
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_base import AioTestBase
@ -31,47 +30,6 @@ _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_LONG_TIMEOUT_THAT_SHOULD_NOT_EXPIRE = 60
class TestOngoingCalls(unittest.TestCase):
class FakeCall(_base_call.RpcContext):
def add_done_callback(self, callback):
self.callback = callback
def cancel(self):
raise NotImplementedError
def cancelled(self):
raise NotImplementedError
def done(self):
raise NotImplementedError
def time_remaining(self):
raise NotImplementedError
def test_trace_call(self):
ongoing_calls = _OngoingCalls()
self.assertEqual(ongoing_calls.size(), 0)
call = TestOngoingCalls.FakeCall()
ongoing_calls.trace_call(call)
self.assertEqual(ongoing_calls.size(), 1)
self.assertEqual(ongoing_calls.calls, WeakSet([call]))
call.callback(call)
self.assertEqual(ongoing_calls.size(), 0)
self.assertEqual(ongoing_calls.calls, WeakSet())
def test_deleted_call(self):
ongoing_calls = _OngoingCalls()
call = TestOngoingCalls.FakeCall()
ongoing_calls.trace_call(call)
del (call)
self.assertEqual(ongoing_calls.size(), 0)
class TestCloseChannel(AioTestBase):
async def setUp(self):
@ -114,15 +72,11 @@ class TestCloseChannel(AioTestBase):
calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close()
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_unary_stream(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
@ -130,15 +84,11 @@ class TestCloseChannel(AioTestBase):
request = messages_pb2.StreamingOutputCallRequest()
calls = [stub.StreamingOutputCall(request) for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close()
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_stream_unary(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
@ -150,36 +100,27 @@ class TestCloseChannel(AioTestBase):
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_stream_stream(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
calls = [stub.FullDuplexCall() for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close()
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_async_context(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
calls = [
stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)
]
self.assertEqual(channel._ongoing_calls.size(), 2)
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)

Loading…
Cancel
Save