diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 6caaece82c4..b58201b79d0 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -111,7 +111,7 @@ def _raise_rpc_error(state): def _possibly_finish_call(state, token): state.due.remove(token) - if (state.client is _CANCELLED or state.statused) and not state.due: + if not _is_rpc_state_active(state) and not state.due: callbacks = state.callbacks state.callbacks = None return state, callbacks @@ -218,7 +218,7 @@ class _Context(grpc.ServicerContext): def is_active(self): with self._state.condition: - return self._state.client is not _CANCELLED and not self._state.statused + return _is_rpc_state_active(self._state) def time_remaining(self): return max(self._rpc_event.call_details.deadline - time.time(), 0) @@ -313,7 +313,7 @@ class _RequestIterator(object): def _raise_or_start_receive_message(self): if self._state.client is _CANCELLED: _raise_rpc_error(self._state) - elif self._state.client is _CLOSED or self._state.statused: + elif not _is_rpc_state_active(self._state): raise StopIteration() else: self._call.start_server_batch( @@ -358,7 +358,7 @@ def _unary_request(rpc_event, state, request_deserializer): def unary_request(): with state.condition: - if state.client is _CANCELLED or state.statused: + if not _is_rpc_state_active(state): return None else: rpc_event.call.start_server_batch( @@ -386,10 +386,18 @@ def _unary_request(rpc_event, state, request_deserializer): return unary_request -def _call_behavior(rpc_event, state, behavior, argument, request_deserializer): +def _call_behavior(rpc_event, + state, + behavior, + argument, + request_deserializer, + stream_observer=None): context = _Context(rpc_event, state, request_deserializer) try: - return behavior(argument, context), True + if stream_observer is not None: + return behavior(argument, context, stream_observer), True + else: + return behavior(argument, context), True except Exception as exception: # pylint: disable=broad-except with state.condition: if state.aborted: @@ -434,7 +442,7 @@ def _serialize_response(rpc_event, state, response, response_serializer): def _send_response(rpc_event, state, serialized_response): with state.condition: - if state.client is _CANCELLED or state.statused: + if not _is_rpc_state_active(state): return False else: if state.initial_metadata_allowed: @@ -455,7 +463,7 @@ def _send_response(rpc_event, state, serialized_response): while True: state.condition.wait() if token not in state.due: - return state.client is not _CANCELLED and not state.statused + return _is_rpc_state_active(state) def _status(rpc_event, state, serialized_response): @@ -501,65 +509,102 @@ def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk, def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk, request_deserializer, response_serializer): cygrpc.install_context_from_call(rpc_event.call) + + def on_next(response): + if response is None: + _status(rpc_event, state, None) + else: + serialized_response = _serialize_response( + rpc_event, state, response, response_serializer) + if serialized_response is not None: + _send_response(rpc_event, state, serialized_response) + try: argument = argument_thunk() if argument is not None: - response_iterator, proceed = _call_behavior( - rpc_event, state, behavior, argument, request_deserializer) - if proceed: - while True: - response, proceed = _take_response_from_response_iterator( - rpc_event, state, response_iterator) - if proceed: - if response is None: - _status(rpc_event, state, None) - break - else: - serialized_response = _serialize_response( - rpc_event, state, response, response_serializer) - if serialized_response is not None: - proceed = _send_response( - rpc_event, state, serialized_response) - if not proceed: - break - else: - break - else: - break + if hasattr(behavior, 'experimental_non_blocking' + ) and behavior.experimental_non_blocking: + _call_behavior( + rpc_event, + state, + behavior, + argument, + request_deserializer, + stream_observer=on_next) + else: + response_iterator, proceed = _call_behavior( + rpc_event, state, behavior, argument, request_deserializer) + if proceed: + _stream_response_iterator_adapter(rpc_event, state, on_next, + response_iterator) finally: cygrpc.uninstall_context() -def _handle_unary_unary(rpc_event, state, method_handler, thread_pool): +def _is_rpc_state_active(state): + return state.client is not _CANCELLED and not state.statused + + +def _stream_response_iterator_adapter(rpc_event, state, stream_observer, + response_iterator): + while True: + response, proceed = _take_response_from_response_iterator( + rpc_event, state, response_iterator) + if proceed: + stream_observer(response) + if not _is_rpc_state_active(state): + break + else: + break + + +def _select_thread_pool_for_behavior(behavior, default_thread_pool): + if hasattr(behavior, 'experimental_thread_pool' + ) and behavior.experimental_thread_pool is not None: + return behavior.experimental_thread_pool + else: + return default_thread_pool + + +def _handle_unary_unary(rpc_event, state, method_handler, default_thread_pool): unary_request = _unary_request(rpc_event, state, method_handler.request_deserializer) + thread_pool = _select_thread_pool_for_behavior(method_handler.unary_unary, + default_thread_pool) return thread_pool.submit(_unary_response_in_pool, rpc_event, state, method_handler.unary_unary, unary_request, method_handler.request_deserializer, method_handler.response_serializer) -def _handle_unary_stream(rpc_event, state, method_handler, thread_pool): +def _handle_unary_stream(rpc_event, state, method_handler, default_thread_pool): unary_request = _unary_request(rpc_event, state, method_handler.request_deserializer) + thread_pool = _select_thread_pool_for_behavior(method_handler.unary_stream, + default_thread_pool) return thread_pool.submit(_stream_response_in_pool, rpc_event, state, method_handler.unary_stream, unary_request, method_handler.request_deserializer, method_handler.response_serializer) -def _handle_stream_unary(rpc_event, state, method_handler, thread_pool): +def _handle_stream_unary(rpc_event, state, method_handler, default_thread_pool): request_iterator = _RequestIterator(state, rpc_event.call, method_handler.request_deserializer) + thread_pool = _select_thread_pool_for_behavior(method_handler.stream_unary, + default_thread_pool) return thread_pool.submit( _unary_response_in_pool, rpc_event, state, method_handler.stream_unary, lambda: request_iterator, method_handler.request_deserializer, method_handler.response_serializer) -def _handle_stream_stream(rpc_event, state, method_handler, thread_pool): +def _handle_stream_stream(rpc_event, state, method_handler, + default_thread_pool): request_iterator = _RequestIterator(state, rpc_event.call, method_handler.request_deserializer) + thread_pool = _select_thread_pool_for_behavior(method_handler.stream_stream, + default_thread_pool) return thread_pool.submit( _stream_response_in_pool, rpc_event, state, method_handler.stream_stream, lambda: request_iterator, diff --git a/src/python/grpcio_health_checking/grpc_health/v1/health.py b/src/python/grpcio_health_checking/grpc_health/v1/health.py index 0a5bbb5504c..f135ffbb645 100644 --- a/src/python/grpcio_health_checking/grpc_health/v1/health.py +++ b/src/python/grpcio_health_checking/grpc_health/v1/health.py @@ -13,6 +13,7 @@ # limitations under the License. """Reference implementation for health checking in gRPC Python.""" +import collections import threading import grpc @@ -27,7 +28,7 @@ class _Watcher(): def __init__(self): self._condition = threading.Condition() - self._responses = list() + self._responses = collections.deque() self._open = True def __iter__(self): @@ -38,7 +39,7 @@ class _Watcher(): while not self._responses and self._open: self._condition.wait() if self._responses: - return self._responses.pop(0) + return self._responses.popleft() else: raise StopIteration() @@ -59,20 +60,35 @@ class _Watcher(): self._condition.notify() +def _watcher_to_on_next_adapter(watcher): + + def on_next(response): + if response is None: + watcher.close() + else: + watcher.add(response) + + return on_next + + class HealthServicer(_health_pb2_grpc.HealthServicer): """Servicer handling RPCs for service statuses.""" - def __init__(self): + def __init__(self, + experimental_non_blocking=True, + experimental_thread_pool=None): self._lock = threading.RLock() self._server_status = {} - self._watchers = {} + self._on_next_callbacks = {} + self.Watch.__func__.experimental_non_blocking = experimental_non_blocking + self.Watch.__func__.experimental_thread_pool = experimental_thread_pool - def _on_close_callback(self, watcher, service): + def _on_close_callback(self, on_next, service): def callback(): with self._lock: - self._watchers[service].remove(watcher) - watcher.close() + self._on_next_callbacks[service].remove(on_next) + on_next(None) return callback @@ -85,19 +101,26 @@ class HealthServicer(_health_pb2_grpc.HealthServicer): else: return _health_pb2.HealthCheckResponse(status=status) - def Watch(self, request, context): + # pylint: disable=arguments-differ + def Watch(self, request, context, on_next=None): + blocking_watcher = None + if on_next is None: + # The server does not support the experimental_non_blocking + # parameter. For backwards compatibility, return a blocking response + # generator. + blocking_watcher = _Watcher() + on_next = _watcher_to_on_next_adapter(blocking_watcher) service = request.service with self._lock: status = self._server_status.get(service) if status is None: status = _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN # pylint: disable=no-member - watcher = _Watcher() - watcher.add(_health_pb2.HealthCheckResponse(status=status)) - if service not in self._watchers: - self._watchers[service] = set() - self._watchers[service].add(watcher) - context.add_callback(self._on_close_callback(watcher, service)) - return watcher + on_next(_health_pb2.HealthCheckResponse(status=status)) + if service not in self._on_next_callbacks: + self._on_next_callbacks[service] = set() + self._on_next_callbacks[service].add(on_next) + context.add_callback(self._on_close_callback(on_next, service)) + return blocking_watcher def set(self, service, status): """Sets the status of a service. @@ -109,6 +132,6 @@ class HealthServicer(_health_pb2_grpc.HealthServicer): """ with self._lock: self._server_status[service] = status - if service in self._watchers: - for watcher in self._watchers[service]: - watcher.add(_health_pb2.HealthCheckResponse(status=status)) + if service in self._on_next_callbacks: + for on_next in self._on_next_callbacks[service]: + on_next(_health_pb2.HealthCheckResponse(status=status)) diff --git a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py index 35794987bc8..3b8ee883bbe 100644 --- a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py +++ b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py @@ -38,29 +38,170 @@ def _consume_responses(response_iterator, response_queue): response_queue.put(response) -class HealthServicerTest(unittest.TestCase): +class BaseWatchTests(object): + + class WatchTests(unittest.TestCase): + + def start_server(self, servicer): + self._servicer = servicer + self._servicer.set('', health_pb2.HealthCheckResponse.SERVING) + self._servicer.set(_SERVING_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + self._servicer.set(_UNKNOWN_SERVICE, + health_pb2.HealthCheckResponse.UNKNOWN) + self._servicer.set(_NOT_SERVING_SERVICE, + health_pb2.HealthCheckResponse.NOT_SERVING) + self._server = test_common.test_server() + port = self._server.add_insecure_port('[::]:0') + health_pb2_grpc.add_HealthServicer_to_server( + self._servicer, self._server) + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + self._stub = health_pb2_grpc.HealthStub(self._channel) + + def tearDown(self): + self._server.stop(None) + self._channel.close() + + def test_watch_empty_service(self): + request = health_pb2.HealthCheckRequest(service='') + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response.status) + + rendezvous.cancel() + thread.join() + self.assertTrue(response_queue.empty()) + + def test_watch_new_service(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response.status) + + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response.status) + + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.NOT_SERVING) + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, + response.status) + + rendezvous.cancel() + thread.join() + self.assertTrue(response_queue.empty()) + + def test_watch_service_isolation(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response.status) + + self._servicer.set('some-other-service', + health_pb2.HealthCheckResponse.SERVING) + with self.assertRaises(queue.Empty): + response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + + rendezvous.cancel() + thread.join() + self.assertTrue(response_queue.empty()) + + def test_two_watchers(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue1 = queue.Queue() + response_queue2 = queue.Queue() + rendezvous1 = self._stub.Watch(request) + rendezvous2 = self._stub.Watch(request) + thread1 = threading.Thread( + target=_consume_responses, args=(rendezvous1, response_queue1)) + thread2 = threading.Thread( + target=_consume_responses, args=(rendezvous2, response_queue2)) + thread1.start() + thread2.start() + + response1 = response_queue1.get( + timeout=test_constants.SHORT_TIMEOUT) + response2 = response_queue2.get( + timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response1.status) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response2.status) + + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + response1 = response_queue1.get( + timeout=test_constants.SHORT_TIMEOUT) + response2 = response_queue2.get( + timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response1.status) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response2.status) + + rendezvous1.cancel() + rendezvous2.cancel() + thread1.join() + thread2.join() + self.assertTrue(response_queue1.empty()) + self.assertTrue(response_queue2.empty()) + + def test_cancelled_watch_removed_from_watch_list(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response.status) + + rendezvous.cancel() + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + thread.join() + + # Wait, if necessary, for serving thread to process client cancellation + timeout = time.time() + test_constants.SHORT_TIMEOUT + while time.time( + ) < timeout and self._servicer._on_next_callbacks[_WATCH_SERVICE]: + time.sleep(1) + self.assertFalse(self._servicer._on_next_callbacks[_WATCH_SERVICE], + 'watch set should be empty') + self.assertTrue(response_queue.empty()) + + +class HealthServicerTest(BaseWatchTests.WatchTests): def setUp(self): - self._servicer = health.HealthServicer() - self._servicer.set('', health_pb2.HealthCheckResponse.SERVING) - self._servicer.set(_SERVING_SERVICE, - health_pb2.HealthCheckResponse.SERVING) - self._servicer.set(_UNKNOWN_SERVICE, - health_pb2.HealthCheckResponse.UNKNOWN) - self._servicer.set(_NOT_SERVING_SERVICE, - health_pb2.HealthCheckResponse.NOT_SERVING) - self._server = test_common.test_server() - port = self._server.add_insecure_port('[::]:0') - health_pb2_grpc.add_HealthServicer_to_server(self._servicer, - self._server) - self._server.start() - - self._channel = grpc.insecure_channel('localhost:%d' % port) - self._stub = health_pb2_grpc.HealthStub(self._channel) - - def tearDown(self): - self._server.stop(None) - self._channel.close() + super(HealthServicerTest, self).start_server( + health.HealthServicer( + experimental_non_blocking=False, experimental_thread_pool=None)) def test_check_empty_service(self): request = health_pb2.HealthCheckRequest() @@ -90,135 +231,17 @@ class HealthServicerTest(unittest.TestCase): self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code()) - def test_watch_empty_service(self): - request = health_pb2.HealthCheckRequest(service='') - response_queue = queue.Queue() - rendezvous = self._stub.Watch(request) - thread = threading.Thread( - target=_consume_responses, args=(rendezvous, response_queue)) - thread.start() - - response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - response.status) - - rendezvous.cancel() - thread.join() - self.assertTrue(response_queue.empty()) - - def test_watch_new_service(self): - request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) - response_queue = queue.Queue() - rendezvous = self._stub.Watch(request) - thread = threading.Thread( - target=_consume_responses, args=(rendezvous, response_queue)) - thread.start() - - response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - response.status) - - self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.SERVING) - response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - response.status) - - self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.NOT_SERVING) - response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, - response.status) - - rendezvous.cancel() - thread.join() - self.assertTrue(response_queue.empty()) - - def test_watch_service_isolation(self): - request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) - response_queue = queue.Queue() - rendezvous = self._stub.Watch(request) - thread = threading.Thread( - target=_consume_responses, args=(rendezvous, response_queue)) - thread.start() - - response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - response.status) - - self._servicer.set('some-other-service', - health_pb2.HealthCheckResponse.SERVING) - with self.assertRaises(queue.Empty): - response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - - rendezvous.cancel() - thread.join() - self.assertTrue(response_queue.empty()) - - def test_two_watchers(self): - request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) - response_queue1 = queue.Queue() - response_queue2 = queue.Queue() - rendezvous1 = self._stub.Watch(request) - rendezvous2 = self._stub.Watch(request) - thread1 = threading.Thread( - target=_consume_responses, args=(rendezvous1, response_queue1)) - thread2 = threading.Thread( - target=_consume_responses, args=(rendezvous2, response_queue2)) - thread1.start() - thread2.start() - - response1 = response_queue1.get(timeout=test_constants.SHORT_TIMEOUT) - response2 = response_queue2.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - response1.status) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - response2.status) - - self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.SERVING) - response1 = response_queue1.get(timeout=test_constants.SHORT_TIMEOUT) - response2 = response_queue2.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - response1.status) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - response2.status) - - rendezvous1.cancel() - rendezvous2.cancel() - thread1.join() - thread2.join() - self.assertTrue(response_queue1.empty()) - self.assertTrue(response_queue2.empty()) - - def test_cancelled_watch_removed_from_watch_list(self): - request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) - response_queue = queue.Queue() - rendezvous = self._stub.Watch(request) - thread = threading.Thread( - target=_consume_responses, args=(rendezvous, response_queue)) - thread.start() - - response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - response.status) - - rendezvous.cancel() - self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.SERVING) - thread.join() - - # Wait, if necessary, for serving thread to process client cancellation - timeout = time.time() + test_constants.SHORT_TIMEOUT - while time.time() < timeout and self._servicer._watchers[_WATCH_SERVICE]: - time.sleep(1) - self.assertFalse(self._servicer._watchers[_WATCH_SERVICE], - 'watch set should be empty') - self.assertTrue(response_queue.empty()) - def test_health_service_name(self): self.assertEqual(health.SERVICE_NAME, 'grpc.health.v1.Health') +class HealthServicerBackwardsCompatibleWatchTest(BaseWatchTests.WatchTests): + + def setUp(self): + super(HealthServicerBackwardsCompatibleWatchTest, self).start_server( + health.HealthServicer( + experimental_non_blocking=False, experimental_thread_pool=None)) + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_rpc_test.py b/src/python/grpcio_tests/tests/unit/_rpc_test.py index a99121cee57..20ef66671a0 100644 --- a/src/python/grpcio_tests/tests/unit/_rpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_rpc_test.py @@ -23,6 +23,7 @@ import grpc from grpc.framework.foundation import logging_pool from tests.unit import test_common +from tests.unit import _thread_pool from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_control @@ -33,8 +34,10 @@ _DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] _UNARY_UNARY = '/test/UnaryUnary' _UNARY_STREAM = '/test/UnaryStream' +_UNARY_STREAM_NON_BLOCKING = '/test/UnaryStreamNonBlocking' _STREAM_UNARY = '/test/StreamUnary' _STREAM_STREAM = '/test/StreamStream' +_STREAM_STREAM_NON_BLOCKING = '/test/StreamStreamNonBlocking' class _Callback(object): @@ -59,8 +62,14 @@ class _Callback(object): class _Handler(object): - def __init__(self, control): + def __init__(self, control, thread_pool): self._control = control + self._thread_pool = thread_pool + non_blocking_functions = (self.handle_unary_stream_non_blocking, + self.handle_stream_stream_non_blocking) + for non_blocking_function in non_blocking_functions: + non_blocking_function.__func__.experimental_non_blocking = True + non_blocking_function.__func__.experimental_thread_pool = self._thread_pool def handle_unary_unary(self, request, servicer_context): self._control.control() @@ -87,6 +96,20 @@ class _Handler(object): 'testvalue', ),)) + def handle_unary_stream_non_blocking(self, request, servicer_context, + on_next): + for _ in range(test_constants.STREAM_LENGTH): + self._control.control() + on_next(request) + # yield request + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + on_next(None) + def handle_stream_unary(self, request_iterator, servicer_context): if servicer_context is not None: servicer_context.invocation_metadata() @@ -115,6 +138,20 @@ class _Handler(object): yield request self._control.control() + def handle_stream_stream_non_blocking(self, request_iterator, + servicer_context, on_next): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + for request in request_iterator: + self._control.control() + on_next(request) + self._control.control() + on_next(None) + class _MethodHandler(grpc.RpcMethodHandler): @@ -145,6 +182,10 @@ class _GenericHandler(grpc.GenericRpcHandler): return _MethodHandler(False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, self._handler.handle_unary_stream, None, None) + elif handler_call_details.method == _UNARY_STREAM_NON_BLOCKING: + return _MethodHandler( + False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, + self._handler.handle_unary_stream_non_blocking, None, None) elif handler_call_details.method == _STREAM_UNARY: return _MethodHandler(True, False, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, None, @@ -152,6 +193,10 @@ class _GenericHandler(grpc.GenericRpcHandler): elif handler_call_details.method == _STREAM_STREAM: return _MethodHandler(True, True, None, None, None, None, None, self._handler.handle_stream_stream) + elif handler_call_details.method == _STREAM_STREAM_NON_BLOCKING: + return _MethodHandler( + True, True, None, None, None, None, None, + self._handler.handle_stream_stream_non_blocking) else: return None @@ -167,6 +212,13 @@ def _unary_stream_multi_callable(channel): response_deserializer=_DESERIALIZE_RESPONSE) +def _unary_stream_non_blocking_multi_callable(channel): + return channel.unary_stream( + _UNARY_STREAM_NON_BLOCKING, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + def _stream_unary_multi_callable(channel): return channel.stream_unary( _STREAM_UNARY, @@ -178,11 +230,16 @@ def _stream_stream_multi_callable(channel): return channel.stream_stream(_STREAM_STREAM) +def _stream_stream_non_blocking_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM_NON_BLOCKING) + + class RPCTest(unittest.TestCase): def setUp(self): self._control = test_control.PauseFailControl() - self._handler = _Handler(self._control) + self._thread_pool = _thread_pool.RecordingThreadPool(max_workers=None) + self._handler = _Handler(self._control, self._thread_pool) self._server = test_common.test_server() port = self._server.add_insecure_port('[::]:0') @@ -195,6 +252,16 @@ class RPCTest(unittest.TestCase): self._server.stop(None) self._channel.close() + def testDefaultThreadPoolIsUsed(self): + self._consume_one_stream_response_unary_request( + _unary_stream_multi_callable(self._channel)) + self.assertFalse(self._thread_pool.was_used()) + + def testExperimentalThreadPoolIsUsed(self): + self._consume_one_stream_response_unary_request( + _unary_stream_non_blocking_multi_callable(self._channel)) + self.assertTrue(self._thread_pool.was_used()) + def testUnrecognizedMethod(self): request = b'abc' @@ -227,7 +294,7 @@ class RPCTest(unittest.TestCase): self.assertEqual(expected_response, response) self.assertIs(grpc.StatusCode.OK, call.code()) - self.assertEqual("", call.debug_error_string()) + self.assertEqual('', call.debug_error_string()) def testSuccessfulUnaryRequestFutureUnaryResponse(self): request = b'\x07\x08' @@ -310,6 +377,7 @@ class RPCTest(unittest.TestCase): def testSuccessfulStreamRequestStreamResponse(self): requests = tuple( b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH)) + expected_responses = tuple( self._handler.handle_stream_stream(iter(requests), None)) request_iterator = iter(requests) @@ -425,58 +493,36 @@ class RPCTest(unittest.TestCase): test_is_running_cell[0] = False def testConsumingOneStreamResponseUnaryRequest(self): - request = b'\x57\x38' + self._consume_one_stream_response_unary_request( + _unary_stream_multi_callable(self._channel)) - multi_callable = _unary_stream_multi_callable(self._channel) - response_iterator = multi_callable( - request, - metadata=(('test', 'ConsumingOneStreamResponseUnaryRequest'),)) - next(response_iterator) + def testConsumingOneStreamResponseUnaryRequestNonBlocking(self): + self._consume_one_stream_response_unary_request( + _unary_stream_non_blocking_multi_callable(self._channel)) def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self): - request = b'\x57\x38' + self._consume_some_but_not_all_stream_responses_unary_request( + _unary_stream_multi_callable(self._channel)) - multi_callable = _unary_stream_multi_callable(self._channel) - response_iterator = multi_callable( - request, - metadata=(('test', - 'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),)) - for _ in range(test_constants.STREAM_LENGTH // 2): - next(response_iterator) + def testConsumingSomeButNotAllStreamResponsesUnaryRequestNonBlocking(self): + self._consume_some_but_not_all_stream_responses_unary_request( + _unary_stream_non_blocking_multi_callable(self._channel)) def testConsumingSomeButNotAllStreamResponsesStreamRequest(self): - requests = tuple( - b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) - request_iterator = iter(requests) + self._consume_some_but_not_all_stream_responses_stream_request( + _stream_stream_multi_callable(self._channel)) - multi_callable = _stream_stream_multi_callable(self._channel) - response_iterator = multi_callable( - request_iterator, - metadata=(('test', - 'ConsumingSomeButNotAllStreamResponsesStreamRequest'),)) - for _ in range(test_constants.STREAM_LENGTH // 2): - next(response_iterator) + def testConsumingSomeButNotAllStreamResponsesStreamRequestNonBlocking(self): + self._consume_some_but_not_all_stream_responses_stream_request( + _stream_stream_non_blocking_multi_callable(self._channel)) def testConsumingTooManyStreamResponsesStreamRequest(self): - requests = tuple( - b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) - request_iterator = iter(requests) + self._consume_too_many_stream_responses_stream_request( + _stream_stream_multi_callable(self._channel)) - multi_callable = _stream_stream_multi_callable(self._channel) - response_iterator = multi_callable( - request_iterator, - metadata=(('test', - 'ConsumingTooManyStreamResponsesStreamRequest'),)) - for _ in range(test_constants.STREAM_LENGTH): - next(response_iterator) - for _ in range(test_constants.STREAM_LENGTH): - with self.assertRaises(StopIteration): - next(response_iterator) - - self.assertIsNotNone(response_iterator.initial_metadata()) - self.assertIs(grpc.StatusCode.OK, response_iterator.code()) - self.assertIsNotNone(response_iterator.details()) - self.assertIsNotNone(response_iterator.trailing_metadata()) + def testConsumingTooManyStreamResponsesStreamRequestNonBlocking(self): + self._consume_too_many_stream_responses_stream_request( + _stream_stream_non_blocking_multi_callable(self._channel)) def testCancelledUnaryRequestUnaryResponse(self): request = b'\x07\x17' @@ -498,24 +544,12 @@ class RPCTest(unittest.TestCase): self.assertIs(grpc.StatusCode.CANCELLED, response_future.code()) def testCancelledUnaryRequestStreamResponse(self): - request = b'\x07\x19' - - multi_callable = _unary_stream_multi_callable(self._channel) - with self._control.pause(): - response_iterator = multi_callable( - request, - metadata=(('test', 'CancelledUnaryRequestStreamResponse'),)) - self._control.block_until_paused() - response_iterator.cancel() + self._cancelled_unary_request_stream_response( + _unary_stream_multi_callable(self._channel)) - with self.assertRaises(grpc.RpcError) as exception_context: - next(response_iterator) - self.assertIs(grpc.StatusCode.CANCELLED, - exception_context.exception.code()) - self.assertIsNotNone(response_iterator.initial_metadata()) - self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code()) - self.assertIsNotNone(response_iterator.details()) - self.assertIsNotNone(response_iterator.trailing_metadata()) + def testCancelledUnaryRequestStreamResponseNonBlocking(self): + self._cancelled_unary_request_stream_response( + _unary_stream_non_blocking_multi_callable(self._channel)) def testCancelledStreamRequestUnaryResponse(self): requests = tuple( @@ -543,23 +577,12 @@ class RPCTest(unittest.TestCase): self.assertIsNotNone(response_future.trailing_metadata()) def testCancelledStreamRequestStreamResponse(self): - requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) - request_iterator = iter(requests) + self._cancelled_stream_request_stream_response( + _stream_stream_multi_callable(self._channel)) - multi_callable = _stream_stream_multi_callable(self._channel) - with self._control.pause(): - response_iterator = multi_callable( - request_iterator, - metadata=(('test', 'CancelledStreamRequestStreamResponse'),)) - response_iterator.cancel() - - with self.assertRaises(grpc.RpcError): - next(response_iterator) - self.assertIsNotNone(response_iterator.initial_metadata()) - self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code()) - self.assertIsNotNone(response_iterator.details()) - self.assertIsNotNone(response_iterator.trailing_metadata()) + def testCancelledStreamRequestStreamResponseNonBlocking(self): + self._cancelled_stream_request_stream_response( + _stream_stream_non_blocking_multi_callable(self._channel)) def testExpiredUnaryRequestBlockingUnaryResponse(self): request = b'\x07\x17' @@ -608,21 +631,12 @@ class RPCTest(unittest.TestCase): response_future.exception().code()) def testExpiredUnaryRequestStreamResponse(self): - request = b'\x07\x19' + self._expired_unary_request_stream_response( + _unary_stream_multi_callable(self._channel)) - multi_callable = _unary_stream_multi_callable(self._channel) - with self._control.pause(): - with self.assertRaises(grpc.RpcError) as exception_context: - response_iterator = multi_callable( - request, - timeout=test_constants.SHORT_TIMEOUT, - metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),)) - next(response_iterator) - - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - exception_context.exception.code()) - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - response_iterator.code()) + def testExpiredUnaryRequestStreamResponseNonBlocking(self): + self._expired_unary_request_stream_response( + _unary_stream_non_blocking_multi_callable(self._channel)) def testExpiredStreamRequestBlockingUnaryResponse(self): requests = tuple( @@ -678,23 +692,12 @@ class RPCTest(unittest.TestCase): self.assertIsNotNone(response_future.trailing_metadata()) def testExpiredStreamRequestStreamResponse(self): - requests = tuple( - b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH)) - request_iterator = iter(requests) - - multi_callable = _stream_stream_multi_callable(self._channel) - with self._control.pause(): - with self.assertRaises(grpc.RpcError) as exception_context: - response_iterator = multi_callable( - request_iterator, - timeout=test_constants.SHORT_TIMEOUT, - metadata=(('test', 'ExpiredStreamRequestStreamResponse'),)) - next(response_iterator) + self._expired_stream_request_stream_response( + _stream_stream_multi_callable(self._channel)) - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - exception_context.exception.code()) - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - response_iterator.code()) + def testExpiredStreamRequestStreamResponseNonBlocking(self): + self._expired_stream_request_stream_response( + _stream_stream_non_blocking_multi_callable(self._channel)) def testFailedUnaryRequestBlockingUnaryResponse(self): request = b'\x37\x17' @@ -712,10 +715,10 @@ class RPCTest(unittest.TestCase): # sanity checks on to make sure returned string contains default members # of the error debug_error_string = exception_context.exception.debug_error_string() - self.assertIn("created", debug_error_string) - self.assertIn("description", debug_error_string) - self.assertIn("file", debug_error_string) - self.assertIn("file_line", debug_error_string) + self.assertIn('created', debug_error_string) + self.assertIn('description', debug_error_string) + self.assertIn('file', debug_error_string) + self.assertIn('file_line', debug_error_string) def testFailedUnaryRequestFutureUnaryResponse(self): request = b'\x37\x17' @@ -742,18 +745,12 @@ class RPCTest(unittest.TestCase): self.assertIs(response_future, value_passed_to_callback) def testFailedUnaryRequestStreamResponse(self): - request = b'\x37\x17' + self._failed_unary_request_stream_response( + _unary_stream_multi_callable(self._channel)) - multi_callable = _unary_stream_multi_callable(self._channel) - with self.assertRaises(grpc.RpcError) as exception_context: - with self._control.fail(): - response_iterator = multi_callable( - request, - metadata=(('test', 'FailedUnaryRequestStreamResponse'),)) - next(response_iterator) - - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + def testFailedUnaryRequestStreamResponseNonBlocking(self): + self._failed_unary_request_stream_response( + _unary_stream_non_blocking_multi_callable(self._channel)) def testFailedStreamRequestBlockingUnaryResponse(self): requests = tuple( @@ -795,21 +792,12 @@ class RPCTest(unittest.TestCase): self.assertIs(response_future, value_passed_to_callback) def testFailedStreamRequestStreamResponse(self): - requests = tuple( - b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) - request_iterator = iter(requests) + self._failed_stream_request_stream_response( + _stream_stream_multi_callable(self._channel)) - multi_callable = _stream_stream_multi_callable(self._channel) - with self._control.fail(): - with self.assertRaises(grpc.RpcError) as exception_context: - response_iterator = multi_callable( - request_iterator, - metadata=(('test', 'FailedStreamRequestStreamResponse'),)) - tuple(response_iterator) - - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) - self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code()) + def testFailedStreamRequestStreamResponseNonBlocking(self): + self._failed_stream_request_stream_response( + _stream_stream_non_blocking_multi_callable(self._channel)) def testIgnoredUnaryRequestFutureUnaryResponse(self): request = b'\x37\x17' @@ -820,11 +808,12 @@ class RPCTest(unittest.TestCase): metadata=(('test', 'IgnoredUnaryRequestFutureUnaryResponse'),)) def testIgnoredUnaryRequestStreamResponse(self): - request = b'\x37\x17' + self._ignored_unary_stream_request_future_unary_response( + _unary_stream_multi_callable(self._channel)) - multi_callable = _unary_stream_multi_callable(self._channel) - multi_callable( - request, metadata=(('test', 'IgnoredUnaryRequestStreamResponse'),)) + def testIgnoredUnaryRequestStreamResponseNonBlocking(self): + self._ignored_unary_stream_request_future_unary_response( + _unary_stream_non_blocking_multi_callable(self._channel)) def testIgnoredStreamRequestFutureUnaryResponse(self): requests = tuple( @@ -837,11 +826,177 @@ class RPCTest(unittest.TestCase): metadata=(('test', 'IgnoredStreamRequestFutureUnaryResponse'),)) def testIgnoredStreamRequestStreamResponse(self): + self._ignored_stream_request_stream_response( + _stream_stream_multi_callable(self._channel)) + + def testIgnoredStreamRequestStreamResponseNonBlocking(self): + self._ignored_stream_request_stream_response( + _stream_stream_non_blocking_multi_callable(self._channel)) + + def _consume_one_stream_response_unary_request(self, multi_callable): + request = b'\x57\x38' + + response_iterator = multi_callable( + request, + metadata=(('test', 'ConsumingOneStreamResponseUnaryRequest'),)) + next(response_iterator) + + def _consume_some_but_not_all_stream_responses_unary_request( + self, multi_callable): + request = b'\x57\x38' + + response_iterator = multi_callable( + request, + metadata=(('test', + 'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),)) + for _ in range(test_constants.STREAM_LENGTH // 2): + next(response_iterator) + + def _consume_some_but_not_all_stream_responses_stream_request( + self, multi_callable): + requests = tuple( + b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + response_iterator = multi_callable( + request_iterator, + metadata=(('test', + 'ConsumingSomeButNotAllStreamResponsesStreamRequest'),)) + for _ in range(test_constants.STREAM_LENGTH // 2): + next(response_iterator) + + def _consume_too_many_stream_responses_stream_request(self, multi_callable): + requests = tuple( + b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + response_iterator = multi_callable( + request_iterator, + metadata=(('test', + 'ConsumingTooManyStreamResponsesStreamRequest'),)) + for _ in range(test_constants.STREAM_LENGTH): + next(response_iterator) + for _ in range(test_constants.STREAM_LENGTH): + with self.assertRaises(StopIteration): + next(response_iterator) + + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.OK, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def _cancelled_unary_request_stream_response(self, multi_callable): + request = b'\x07\x19' + + with self._control.pause(): + response_iterator = multi_callable( + request, + metadata=(('test', 'CancelledUnaryRequestStreamResponse'),)) + self._control.block_until_paused() + response_iterator.cancel() + + with self.assertRaises(grpc.RpcError) as exception_context: + next(response_iterator) + self.assertIs(grpc.StatusCode.CANCELLED, + exception_context.exception.code()) + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def _cancelled_stream_request_stream_response(self, multi_callable): + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + with self._control.pause(): + response_iterator = multi_callable( + request_iterator, + metadata=(('test', 'CancelledStreamRequestStreamResponse'),)) + response_iterator.cancel() + + with self.assertRaises(grpc.RpcError): + next(response_iterator) + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def _expired_unary_request_stream_response(self, multi_callable): + request = b'\x07\x19' + + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request, + timeout=test_constants.SHORT_TIMEOUT, + metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + response_iterator.code()) + + def _expired_stream_request_stream_response(self, multi_callable): + requests = tuple( + b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request_iterator, + timeout=test_constants.SHORT_TIMEOUT, + metadata=(('test', 'ExpiredStreamRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + response_iterator.code()) + + def _failed_unary_request_stream_response(self, multi_callable): + request = b'\x37\x17' + + with self.assertRaises(grpc.RpcError) as exception_context: + with self._control.fail(): + response_iterator = multi_callable( + request, + metadata=(('test', 'FailedUnaryRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + + def _failed_stream_request_stream_response(self, multi_callable): + requests = tuple( + b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + with self._control.fail(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request_iterator, + metadata=(('test', 'FailedStreamRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code()) + + def _ignored_unary_stream_request_future_unary_response( + self, multi_callable): + request = b'\x37\x17' + + multi_callable( + request, metadata=(('test', 'IgnoredUnaryRequestStreamResponse'),)) + + def _ignored_stream_request_stream_response(self, multi_callable): requests = tuple( b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) request_iterator = iter(requests) - multi_callable = _stream_stream_multi_callable(self._channel) multi_callable( request_iterator, metadata=(('test', 'IgnoredStreamRequestStreamResponse'),))