Merge pull request #21988 from lidizheng/aio-fast-close-2

[Aio] Make client-side graceful shutdown faster
pull/22027/head
Lidi Zheng 5 years ago committed by GitHub
commit 11953d1315
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 128
      src/python/grpcio/grpc/experimental/aio/_channel.py
  3. 1
      src/python/grpcio_tests/tests_aio/tests.json
  4. 70
      src/python/grpcio_tests/tests_aio/unit/close_channel_test.py

@ -15,7 +15,7 @@
cdef class _AioCall(GrpcCallWrapper): cdef class _AioCall(GrpcCallWrapper):
cdef: cdef:
AioChannel _channel readonly AioChannel _channel
list _references list _references
object _deadline object _deadline
list _done_callbacks list _done_callbacks

@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python.""" """Invocation-side implementation of gRPC Asyncio Python."""
import asyncio import asyncio
from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet import sys
from weakref import WeakSet from typing import Any, AsyncIterable, Iterable, Optional, Sequence
import logging
import grpc import grpc
from grpc import _common from grpc import _common, _compression, _grpcio_metadata
from grpc._cython import cygrpc from grpc._cython import cygrpc
from grpc import _compression
from grpc import _grpcio_metadata
from . import _base_call from . import _base_call
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
@ -35,6 +33,15 @@ from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple() _IMMUTABLE_EMPTY_TUPLE = tuple()
_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) _USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)
if sys.version_info[1] < 7:
def _all_tasks() -> Iterable[asyncio.Task]:
return asyncio.Task.all_tasks()
else:
def _all_tasks() -> Iterable[asyncio.Task]:
return asyncio.all_tasks()
def _augment_channel_arguments(base_options: ChannelArgumentType, def _augment_channel_arguments(base_options: ChannelArgumentType,
compression: Optional[grpc.Compression]): compression: Optional[grpc.Compression]):
@ -48,50 +55,12 @@ def _augment_channel_arguments(base_options: ChannelArgumentType,
) + compression_channel_argument + user_agent_channel_argument ) + compression_channel_argument + user_agent_channel_argument
_LOGGER = logging.getLogger(__name__)
class _OngoingCalls:
"""Internal class used for have visibility of the ongoing calls."""
_calls: AbstractSet[_base_call.RpcContext]
def __init__(self):
self._calls = WeakSet()
def _remove_call(self, call: _base_call.RpcContext):
try:
self._calls.remove(call)
except KeyError:
pass
@property
def calls(self) -> AbstractSet[_base_call.RpcContext]:
"""Returns the set of ongoing calls."""
return self._calls
def size(self) -> int:
"""Returns the number of ongoing calls."""
return len(self._calls)
def trace_call(self, call: _base_call.RpcContext):
"""Adds and manages a new ongoing call."""
self._calls.add(call)
call.add_done_callback(self._remove_call)
class _BaseMultiCallable: class _BaseMultiCallable:
"""Base class of all multi callable objects. """Base class of all multi callable objects.
Handles the initialization logic and stores common attributes. Handles the initialization logic and stores common attributes.
""" """
_loop: asyncio.AbstractEventLoop _loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_ongoing_calls: _OngoingCalls
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_channel: cygrpc.AioChannel _channel: cygrpc.AioChannel
_method: bytes _method: bytes
_request_serializer: SerializingFunction _request_serializer: SerializingFunction
@ -103,7 +72,6 @@ class _BaseMultiCallable:
def __init__( def __init__(
self, self,
channel: cygrpc.AioChannel, channel: cygrpc.AioChannel,
ongoing_calls: _OngoingCalls,
method: bytes, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction, response_deserializer: DeserializingFunction,
@ -112,7 +80,6 @@ class _BaseMultiCallable:
) -> None: ) -> None:
self._loop = loop self._loop = loop
self._channel = channel self._channel = channel
self._ongoing_calls = ongoing_calls
self._method = method self._method = method
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
@ -170,7 +137,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
self._request_serializer, self._response_deserializer, self._request_serializer, self._response_deserializer,
self._loop) self._loop)
self._ongoing_calls.trace_call(call)
return call return call
@ -213,7 +179,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
wait_for_ready, self._channel, self._method, wait_for_ready, self._channel, self._method,
self._request_serializer, self._request_serializer,
self._response_deserializer, self._loop) self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call return call
@ -260,7 +226,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
credentials, wait_for_ready, self._channel, credentials, wait_for_ready, self._channel,
self._method, self._request_serializer, self._method, self._request_serializer,
self._response_deserializer, self._loop) self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call return call
@ -307,7 +273,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
credentials, wait_for_ready, self._channel, credentials, wait_for_ready, self._channel,
self._method, self._request_serializer, self._method, self._request_serializer,
self._response_deserializer, self._loop) self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call return call
@ -319,7 +285,6 @@ class Channel:
_loop: asyncio.AbstractEventLoop _loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel _channel: cygrpc.AioChannel
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_ongoing_calls: _OngoingCalls
def __init__(self, target: str, options: ChannelArgumentType, def __init__(self, target: str, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials], credentials: Optional[grpc.ChannelCredentials],
@ -359,7 +324,6 @@ class Channel:
_common.encode(target), _common.encode(target),
_augment_channel_arguments(options, compression), credentials, _augment_channel_arguments(options, compression), credentials,
self._loop) self._loop)
self._ongoing_calls = _OngoingCalls()
async def __aenter__(self): async def __aenter__(self):
"""Starts an asynchronous context manager. """Starts an asynchronous context manager.
@ -383,22 +347,48 @@ class Channel:
# No new calls will be accepted by the Cython channel. # No new calls will be accepted by the Cython channel.
self._channel.closing() self._channel.closing()
if grace: # Iterate through running tasks
# pylint: disable=unused-variable tasks = _all_tasks()
_, pending = await asyncio.wait(self._ongoing_calls.calls, calls = []
timeout=grace, call_tasks = []
loop=self._loop) for task in tasks:
stack = task.get_stack(limit=1)
# If the Task is created by a C-extension, the stack will be empty.
if not stack:
continue
# Locate ones created by `aio.Call`.
frame = stack[0]
candidate = frame.f_locals.get('self')
if candidate:
if isinstance(candidate, _base_call.Call):
if hasattr(candidate, '_channel'):
# For intercepted Call object
if candidate._channel is not self._channel:
continue
elif hasattr(candidate, '_cython_call'):
# For normal Call object
if candidate._cython_call._channel is not self._channel:
continue
else:
# Unidentified Call object
raise cygrpc.InternalError(
f'Unrecognized call object: {candidate}')
if not pending: calls.append(candidate)
return call_tasks.append(task)
# If needed, try to wait for them to finish.
# Call objects are not always awaitables.
if grace and call_tasks:
await asyncio.wait(call_tasks, timeout=grace, loop=self._loop)
# A new set is created acting as a shallow copy because # Time to cancel existing calls.
# when cancellation happens the calls are automatically
# removed from the originally set.
calls = WeakSet(data=self._ongoing_calls.calls)
for call in calls: for call in calls:
call.cancel() call.cancel()
# Destroy the channel
self._channel.close() self._channel.close()
async def close(self, grace: Optional[float] = None): async def close(self, grace: Optional[float] = None):
@ -487,8 +477,7 @@ class Channel:
Returns: Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method. A UnaryUnaryMultiCallable value for the named unary-unary method.
""" """
return UnaryUnaryMultiCallable(self._channel, self._ongoing_calls, return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
_common.encode(method),
request_serializer, request_serializer,
response_deserializer, response_deserializer,
self._unary_unary_interceptors, self._unary_unary_interceptors,
@ -500,8 +489,7 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryStreamMultiCallable: ) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, self._ongoing_calls, return UnaryStreamMultiCallable(self._channel, _common.encode(method),
_common.encode(method),
request_serializer, request_serializer,
response_deserializer, None, self._loop) response_deserializer, None, self._loop)
@ -511,8 +499,7 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None response_deserializer: Optional[DeserializingFunction] = None
) -> StreamUnaryMultiCallable: ) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(self._channel, self._ongoing_calls, return StreamUnaryMultiCallable(self._channel, _common.encode(method),
_common.encode(method),
request_serializer, request_serializer,
response_deserializer, None, self._loop) response_deserializer, None, self._loop)
@ -522,8 +509,7 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None response_deserializer: Optional[DeserializingFunction] = None
) -> StreamStreamMultiCallable: ) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(self._channel, self._ongoing_calls, return StreamStreamMultiCallable(self._channel, _common.encode(method),
_common.encode(method),
request_serializer, request_serializer,
response_deserializer, None, response_deserializer, None,
self._loop) self._loop)

@ -12,7 +12,6 @@
"unit.channel_ready_test.TestChannelReady", "unit.channel_ready_test.TestChannelReady",
"unit.channel_test.TestChannel", "unit.channel_test.TestChannel",
"unit.close_channel_test.TestCloseChannel", "unit.close_channel_test.TestCloseChannel",
"unit.close_channel_test.TestOngoingCalls",
"unit.compression_test.TestCompression", "unit.compression_test.TestCompression",
"unit.connectivity_test.TestConnectivityState", "unit.connectivity_test.TestConnectivityState",
"unit.done_callback_test.TestDoneCallback", "unit.done_callback_test.TestDoneCallback",

@ -16,12 +16,10 @@
import asyncio import asyncio
import logging import logging
import unittest import unittest
from weakref import WeakSet
import grpc import grpc
from grpc.experimental import aio from grpc.experimental import aio
from grpc.experimental.aio import _base_call from grpc.experimental.aio import _base_call
from grpc.experimental.aio._channel import _OngoingCalls
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
@ -31,47 +29,6 @@ _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_LONG_TIMEOUT_THAT_SHOULD_NOT_EXPIRE = 60 _LONG_TIMEOUT_THAT_SHOULD_NOT_EXPIRE = 60
class TestOngoingCalls(unittest.TestCase):
class FakeCall(_base_call.RpcContext):
def add_done_callback(self, callback):
self.callback = callback
def cancel(self):
raise NotImplementedError
def cancelled(self):
raise NotImplementedError
def done(self):
raise NotImplementedError
def time_remaining(self):
raise NotImplementedError
def test_trace_call(self):
ongoing_calls = _OngoingCalls()
self.assertEqual(ongoing_calls.size(), 0)
call = TestOngoingCalls.FakeCall()
ongoing_calls.trace_call(call)
self.assertEqual(ongoing_calls.size(), 1)
self.assertEqual(ongoing_calls.calls, WeakSet([call]))
call.callback(call)
self.assertEqual(ongoing_calls.size(), 0)
self.assertEqual(ongoing_calls.calls, WeakSet())
def test_deleted_call(self):
ongoing_calls = _OngoingCalls()
call = TestOngoingCalls.FakeCall()
ongoing_calls.trace_call(call)
del (call)
self.assertEqual(ongoing_calls.size(), 0)
class TestCloseChannel(AioTestBase): class TestCloseChannel(AioTestBase):
async def setUp(self): async def setUp(self):
@ -114,15 +71,11 @@ class TestCloseChannel(AioTestBase):
calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)] calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close() await channel.close()
for call in calls: for call in calls:
self.assertTrue(call.cancelled()) self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_unary_stream(self): async def test_close_unary_stream(self):
channel = aio.insecure_channel(self._server_target) channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)
@ -130,15 +83,11 @@ class TestCloseChannel(AioTestBase):
request = messages_pb2.StreamingOutputCallRequest() request = messages_pb2.StreamingOutputCallRequest()
calls = [stub.StreamingOutputCall(request) for _ in range(2)] calls = [stub.StreamingOutputCall(request) for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close() await channel.close()
for call in calls: for call in calls:
self.assertTrue(call.cancelled()) self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_stream_unary(self): async def test_close_stream_unary(self):
channel = aio.insecure_channel(self._server_target) channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)
@ -150,35 +99,38 @@ class TestCloseChannel(AioTestBase):
for call in calls: for call in calls:
self.assertTrue(call.cancelled()) self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_stream_stream(self): async def test_close_stream_stream(self):
channel = aio.insecure_channel(self._server_target) channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)
calls = [stub.FullDuplexCall() for _ in range(2)] calls = [stub.FullDuplexCall() for _ in range(2)]
self.assertEqual(channel._ongoing_calls.size(), 2)
await channel.close() await channel.close()
for call in calls: for call in calls:
self.assertTrue(call.cancelled()) self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0)
async def test_close_async_context(self): async def test_close_async_context(self):
async with aio.insecure_channel(self._server_target) as channel: async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)
calls = [ calls = [
stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2) stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)
] ]
self.assertEqual(channel._ongoing_calls.size(), 2)
for call in calls: for call in calls:
self.assertTrue(call.cancelled()) self.assertTrue(call.cancelled())
self.assertEqual(channel._ongoing_calls.size(), 0) async def test_channel_isolation(self):
async with aio.insecure_channel(self._server_target) as channel1:
async with aio.insecure_channel(self._server_target) as channel2:
stub1 = test_pb2_grpc.TestServiceStub(channel1)
stub2 = test_pb2_grpc.TestServiceStub(channel2)
call1 = stub1.UnaryCall(messages_pb2.SimpleRequest())
call2 = stub2.UnaryCall(messages_pb2.SimpleRequest())
self.assertFalse(call1.cancelled())
self.assertTrue(call2.cancelled())
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save