From bccd32dafa1ed60e745c958a55960cf75c56d7d2 Mon Sep 17 00:00:00 2001 From: Nathaniel Manista Date: Wed, 2 May 2018 23:35:00 +0000 Subject: [PATCH] Add grpc.Channel.close --- src/python/grpcio/grpc/__init__.py | 17 +- src/python/grpcio/grpc/_channel.py | 23 +++ src/python/grpcio/grpc/_interceptor.py | 13 ++ .../grpc_testing/_channel/_channel.py | 15 ++ src/python/grpcio_tests/tests/tests.json | 1 + .../tests/unit/_channel_close_test.py | 185 ++++++++++++++++++ 6 files changed, 253 insertions(+), 1 deletion(-) create mode 100644 src/python/grpcio_tests/tests/unit/_channel_close_test.py diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 7fa73036914..b7ed0c85635 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -813,7 +813,11 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): class Channel(six.with_metaclass(abc.ABCMeta)): - """Affords RPC invocation via generic methods on client-side.""" + """Affords RPC invocation via generic methods on client-side. + + Channel objects implement the Context Manager type, although they need not + support being entered and exited multiple times. + """ @abc.abstractmethod def subscribe(self, callback, try_to_connect=False): @@ -926,6 +930,17 @@ class Channel(six.with_metaclass(abc.ABCMeta)): """ raise NotImplementedError() + @abc.abstractmethod + def close(self): + """Closes this Channel and releases all resources held by it. + + Closing the Channel will immediately terminate all RPCs active with the + Channel and it is not valid to invoke new RPCs with the Channel. + + This method is idempotent. + """ + raise NotImplementedError() + ########################## Service-Side Context ############################## diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 6604f8f35c0..3a4585a5115 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -909,5 +909,28 @@ class Channel(grpc.Channel): self._channel, _channel_managed_call_management(self._call_state), _common.encode(method), request_serializer, response_deserializer) + def _close(self): + self._channel.close(cygrpc.StatusCode.cancelled, 'Channel closed!') + _moot(self._connectivity_state) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._close() + return False + + def close(self): + self._close() + def __del__(self): + # TODO(https://github.com/grpc/grpc/issues/12531): Several releases + # after 1.12 (1.16 or thereabouts?) add a "self._channel.close" call + # here (or more likely, call self._close() here). We don't do this today + # because many valid use cases today allow the channel to be deleted + # immediately after stubs are created. After a sufficient period of time + # has passed for all users to be trusted to hang out to their channels + # for as long as they are in use and to close them after using them, + # then deletion of this grpc._channel.Channel instance can be made to + # effect closure of the underlying cygrpc.Channel instance. _moot(self._connectivity_state) diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py index d029472c687..f465e35a9c3 100644 --- a/src/python/grpcio/grpc/_interceptor.py +++ b/src/python/grpcio/grpc/_interceptor.py @@ -334,6 +334,19 @@ class _Channel(grpc.Channel): else: return thunk(method) + def _close(self): + self._channel.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._close() + return False + + def close(self): + self._channel.close() + def intercept_channel(channel, *interceptors): for interceptor in reversed(list(interceptors)): diff --git a/src/python/grpcio_testing/grpc_testing/_channel/_channel.py b/src/python/grpcio_testing/grpc_testing/_channel/_channel.py index b015b8d7388..0c1941e6bea 100644 --- a/src/python/grpcio_testing/grpc_testing/_channel/_channel.py +++ b/src/python/grpcio_testing/grpc_testing/_channel/_channel.py @@ -56,6 +56,21 @@ class TestingChannel(grpc_testing.Channel): response_deserializer=None): return _multi_callable.StreamStream(method, self._state) + def _close(self): + # TODO(https://github.com/grpc/grpc/issues/12531): Decide what + # action to take here, if any? + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._close() + return False + + def close(self): + self._close() + def take_unary_unary(self, method_descriptor): return _channel_rpc.unary_unary(self._state, method_descriptor) diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index d38ee517f09..2fae27a220c 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -25,6 +25,7 @@ "unit._auth_test.AccessTokenAuthMetadataPluginTest", "unit._auth_test.GoogleCallCredentialsTest", "unit._channel_args_test.ChannelArgsTest", + "unit._channel_close_test.ChannelCloseTest", "unit._channel_connectivity_test.ChannelConnectivityTest", "unit._channel_ready_future_test.ChannelReadyFutureTest", "unit._compression_test.CompressionTest", diff --git a/src/python/grpcio_tests/tests/unit/_channel_close_test.py b/src/python/grpcio_tests/tests/unit/_channel_close_test.py new file mode 100644 index 00000000000..af3a9ee1ee1 --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_channel_close_test.py @@ -0,0 +1,185 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests server and client side compression.""" + +import threading +import time +import unittest + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_BEAT = 0.5 +_SOME_TIME = 5 +_MORE_TIME = 10 + + +class _MethodHandler(grpc.RpcMethodHandler): + + request_streaming = True + response_streaming = True + request_deserializer = None + response_serializer = None + + def stream_stream(self, request_iterator, servicer_context): + for request in request_iterator: + yield request * 2 + + +_METHOD_HANDLER = _MethodHandler() + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + return _METHOD_HANDLER + + +_GENERIC_HANDLER = _GenericHandler() + + +class _Pipe(object): + + def __init__(self, values): + self._condition = threading.Condition() + self._values = list(values) + self._open = True + + def __iter__(self): + return self + + def _next(self): + with self._condition: + while not self._values and self._open: + self._condition.wait() + if self._values: + return self._values.pop(0) + else: + raise StopIteration() + + def next(self): + return self._next() + + def __next__(self): + return self._next() + + def add(self, value): + with self._condition: + self._values.append(value) + self._condition.notify() + + def close(self): + with self._condition: + self._open = False + self._condition.notify() + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + + +class ChannelCloseTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server( + max_workers=test_constants.THREAD_CONCURRENCY) + self._server.add_generic_rpc_handlers((_GENERIC_HANDLER,)) + self._port = self._server.add_insecure_port('[::]:0') + self._server.start() + + def tearDown(self): + self._server.stop(None) + + def test_close_immediately_after_call_invocation(self): + channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + multi_callable = channel.stream_stream('Meffod') + request_iterator = _Pipe(()) + response_iterator = multi_callable(request_iterator) + channel.close() + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_close_while_call_active(self): + channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + multi_callable = channel.stream_stream('Meffod') + request_iterator = _Pipe((b'abc',)) + response_iterator = multi_callable(request_iterator) + next(response_iterator) + channel.close() + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_context_manager_close_while_call_active(self): + with grpc.insecure_channel('localhost:{}'.format( + self._port)) as channel: # pylint: disable=bad-continuation + multi_callable = channel.stream_stream('Meffod') + request_iterator = _Pipe((b'abc',)) + response_iterator = multi_callable(request_iterator) + next(response_iterator) + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_context_manager_close_while_many_calls_active(self): + with grpc.insecure_channel('localhost:{}'.format( + self._port)) as channel: # pylint: disable=bad-continuation + multi_callable = channel.stream_stream('Meffod') + request_iterators = tuple( + _Pipe((b'abc',)) + for _ in range(test_constants.THREAD_CONCURRENCY)) + response_iterators = [] + for request_iterator in request_iterators: + response_iterator = multi_callable(request_iterator) + next(response_iterator) + response_iterators.append(response_iterator) + for request_iterator in request_iterators: + request_iterator.close() + + for response_iterator in response_iterators: + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_many_concurrent_closes(self): + channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + multi_callable = channel.stream_stream('Meffod') + request_iterator = _Pipe((b'abc',)) + response_iterator = multi_callable(request_iterator) + next(response_iterator) + start = time.time() + end = start + _MORE_TIME + + def sleep_some_time_then_close(): + time.sleep(_SOME_TIME) + channel.close() + + for _ in range(test_constants.THREAD_CONCURRENCY): + close_thread = threading.Thread(target=sleep_some_time_then_close) + close_thread.start() + while True: + request_iterator.add(b'def') + time.sleep(_BEAT) + if end < time.time(): + break + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + +if __name__ == '__main__': + unittest.main(verbosity=2)