Merge pull request #21988 from lidizheng/aio-fast-close-2

[Aio] Make client-side graceful shutdown faster
pull/22027/head
Lidi Zheng 5 years ago committed by GitHub
commit 11953d1315
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 132
      src/python/grpcio/grpc/experimental/aio/_channel.py
  3. 1
      src/python/grpcio_tests/tests_aio/tests.json
  4. 70
      src/python/grpcio_tests/tests_aio/unit/close_channel_test.py

@ -15,7 +15,7 @@
cdef class _AioCall(GrpcCallWrapper):
cdef:
AioChannel _channel
readonly AioChannel _channel
list _references
object _deadline
list _done_callbacks

@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet
from weakref import WeakSet
import sys
from typing import Any, AsyncIterable, Iterable, Optional, Sequence
import logging
import grpc
from grpc import _common
from grpc import _common, _compression, _grpcio_metadata
from grpc._cython import cygrpc
from grpc import _compression
from grpc import _grpcio_metadata
from . import _base_call
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
@ -35,6 +33,15 @@ from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple()
_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)
if sys.version_info[1] < 7:
def _all_tasks() -> Iterable[asyncio.Task]:
return asyncio.Task.all_tasks()
else:
def _all_tasks() -> Iterable[asyncio.Task]:
return asyncio.all_tasks()
def _augment_channel_arguments(base_options: ChannelArgumentType,
compression: Optional[grpc.Compression]):
@ -48,50 +55,12 @@ def _augment_channel_arguments(base_options: ChannelArgumentType,
) + compression_channel_argument + user_agent_channel_argument
_LOGGER = logging.getLogger(__name__)
class _OngoingCalls:
"""Internal class used for have visibility of the ongoing calls."""
_calls: AbstractSet[_base_call.RpcContext]
def __init__(self):
self._calls = WeakSet()
def _remove_call(self, call: _base_call.RpcContext):
try:
self._calls.remove(call)
except KeyError:
pass
@property
def calls(self) -> AbstractSet[_base_call.RpcContext]:
"""Returns the set of ongoing calls."""
return self._calls
def size(self) -> int:
"""Returns the number of ongoing calls."""
return len(self._calls)
def trace_call(self, call: _base_call.RpcContext):
"""Adds and manages a new ongoing call."""
self._calls.add(call)
call.add_done_callback(self._remove_call)
class _BaseMultiCallable:
"""Base class of all multi callable objects.
Handles the initialization logic and stores common attributes.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_ongoing_calls: _OngoingCalls
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
@ -103,7 +72,6 @@ class _BaseMultiCallable:
def __init__(
self,
channel: cygrpc.AioChannel,
ongoing_calls: _OngoingCalls,
method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
@ -112,7 +80,6 @@ class _BaseMultiCallable:
) -> None:
self._loop = loop
self._channel = channel
self._ongoing_calls = ongoing_calls
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
@ -170,7 +137,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
self._request_serializer, self._response_deserializer,
self._loop)
self._ongoing_calls.trace_call(call)
return call
@ -213,7 +179,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
wait_for_ready, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
@ -260,7 +226,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
@ -307,7 +273,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
@ -319,7 +285,6 @@ class Channel:
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_ongoing_calls: _OngoingCalls
def __init__(self, target: str, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials],
@ -359,7 +324,6 @@ class Channel:
_common.encode(target),
_augment_channel_arguments(options, compression), credentials,
self._loop)
self._ongoing_calls = _OngoingCalls()
async def __aenter__(self):
"""Starts an asynchronous context manager.
@ -383,22 +347,48 @@ class Channel:
# No new calls will be accepted by the Cython channel.
self._channel.closing()
if grace:
# pylint: disable=unused-variable
_, pending = await asyncio.wait(self._ongoing_calls.calls,
timeout=grace,
loop=self._loop)
if not pending:
return
# A new set is created acting as a shallow copy because
# when cancellation happens the calls are automatically
# removed from the originally set.
calls = WeakSet(data=self._ongoing_calls.calls)
# Iterate through running tasks
tasks = _all_tasks()
calls = []
call_tasks = []
for task in tasks:
stack = task.get_stack(limit=1)
# If the Task is created by a C-extension, the stack will be empty.
if not stack:
continue
# Locate ones created by `aio.Call`.
frame = stack[0]
candidate = frame.f_locals.get('self')
if candidate:
if isinstance(candidate, _base_call.Call):
if hasattr(candidate, '_channel'):
# For intercepted Call object
if candidate._channel is not self._channel:
continue
elif hasattr(candidate, '_cython_call'):
# For normal Call object
if candidate._cython_call._channel is not self._channel:
continue
else:
# Unidentified Call object
raise cygrpc.InternalError(
f'Unrecognized call object: {candidate}')
calls.append(candidate)
call_tasks.append(task)
# If needed, try to wait for them to finish.
# Call objects are not always awaitables.
if grace and call_tasks:
await asyncio.wait(call_tasks, timeout=grace, loop=self._loop)
# Time to cancel existing calls.
for call in calls:
call.cancel()
# Destroy the channel
self._channel.close()
async def close(self, grace: Optional[float] = None):
@ -487,8 +477,7 @@ class Channel:
Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method.
"""
return UnaryUnaryMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer,
self._unary_unary_interceptors,
@ -500,8 +489,7 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
return UnaryStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None, self._loop)
@ -511,8 +499,7 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
return StreamUnaryMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None, self._loop)
@ -522,8 +509,7 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
return StreamStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None,
self._loop)

@ -12,7 +12,6 @@
"unit.channel_ready_test.TestChannelReady",
"unit.channel_test.TestChannel",
"unit.close_channel_test.TestCloseChannel",
"unit.close_channel_test.TestOngoingCalls",
"unit.compression_test.TestCompression",
"unit.connectivity_test.TestConnectivityState",
"unit.done_callback_test.TestDoneCallback",

@ -16,12 +16,10 @@
import asyncio
import logging
import unittest
from weakref import WeakSet
import grpc
from grpc.experimental import aio
from grpc.experimental.aio import _base_call
from grpc.experimental.aio._channel import _OngoingCalls
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_base import AioTestBase
@ -31,47 +29,6 @@ _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_LONG_TIMEOUT_THAT_SHOULD_NOT_EXPIRE = 60
class TestOngoingCalls(unittest.TestCase):
class FakeCall(_base_call.RpcContext):
def add_done_callback(self, callback):
self.callback = callback
def cancel(self):
raise NotImplementedError
def cancelled(self):
raise NotImplementedError
def done(self):
raise NotImplementedError
def time_remaining(self):
raise NotImplementedError
def test_trace_call(self):
ongoing_calls = _OngoingCalls()
self.assertEqual(ongoing_calls.size(), 0)
call = TestOngoingCalls.FakeCall()
ongoing_calls.trace_call(call)
self.assertEqual(ongoing_calls.size(), 1)
self.assertEqual(ongoing_calls.calls, WeakSet([call]))
call.callback(call)
self.assertEqual(ongoing_calls.size(), 0)
self.assertEqual(ongoing_calls.calls, WeakSet())
def test_deleted_call(self):
ongoing_calls = _OngoingCalls()
call = TestOngoingCalls.FakeCall()
ongoing_calls.trace_call(call)
del (call)
self.assertEqual(ongoing_calls.size(), 0)
class TestCloseChannel(AioTestBase):
async def setUp(self):
@ -114,15 +71,11 @@ class TestCloseChannel(AioTestBase):
calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close()
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_unary_stream(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
@ -130,15 +83,11 @@ class TestCloseChannel(AioTestBase):
request = messages_pb2.StreamingOutputCallRequest()
calls = [stub.StreamingOutputCall(request) for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close()
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_stream_unary(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
@ -150,35 +99,38 @@ class TestCloseChannel(AioTestBase):
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_stream_stream(self):
channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel)
calls = [stub.FullDuplexCall() for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close()
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_async_context(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
calls = [
stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)
]
self.assertEqual(channel._ongoing_calls.size(), 2)
for call in calls:
self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_channel_isolation(self):
async with aio.insecure_channel(self._server_target) as channel1:
async with aio.insecure_channel(self._server_target) as channel2:
stub1 = test_pb2_grpc.TestServiceStub(channel1)
stub2 = test_pb2_grpc.TestServiceStub(channel2)
call1 = stub1.UnaryCall(messages_pb2.SimpleRequest())
call2 = stub2.UnaryCall(messages_pb2.SimpleRequest())
self.assertFalse(call1.cancelled())
self.assertTrue(call2.cancelled())
if __name__ == '__main__':

Loading…
Cancel
Save