non-blocking server streaming for health service

pull/18095/head
Eric Gribkoff 6 years ago
parent b3b5d63423
commit 9345eac211
  1. 113
      src/python/grpcio/grpc/_server.py
  2. 59
      src/python/grpcio_health_checking/grpc_health/v1/health.py
  3. 317
      src/python/grpcio_tests/tests/health_check/_health_servicer_test.py
  4. 439
      src/python/grpcio_tests/tests/unit/_rpc_test.py

@ -111,7 +111,7 @@ def _raise_rpc_error(state):
def _possibly_finish_call(state, token):
state.due.remove(token)
if (state.client is _CANCELLED or state.statused) and not state.due:
if not _is_rpc_state_active(state) and not state.due:
callbacks = state.callbacks
state.callbacks = None
return state, callbacks
@ -218,7 +218,7 @@ class _Context(grpc.ServicerContext):
def is_active(self):
with self._state.condition:
return self._state.client is not _CANCELLED and not self._state.statused
return _is_rpc_state_active(self._state)
def time_remaining(self):
return max(self._rpc_event.call_details.deadline - time.time(), 0)
@ -313,7 +313,7 @@ class _RequestIterator(object):
def _raise_or_start_receive_message(self):
if self._state.client is _CANCELLED:
_raise_rpc_error(self._state)
elif self._state.client is _CLOSED or self._state.statused:
elif not _is_rpc_state_active(self._state):
raise StopIteration()
else:
self._call.start_server_batch(
@ -358,7 +358,7 @@ def _unary_request(rpc_event, state, request_deserializer):
def unary_request():
with state.condition:
if state.client is _CANCELLED or state.statused:
if not _is_rpc_state_active(state):
return None
else:
rpc_event.call.start_server_batch(
@ -386,10 +386,18 @@ def _unary_request(rpc_event, state, request_deserializer):
return unary_request
def _call_behavior(rpc_event, state, behavior, argument, request_deserializer):
def _call_behavior(rpc_event,
state,
behavior,
argument,
request_deserializer,
stream_observer=None):
context = _Context(rpc_event, state, request_deserializer)
try:
return behavior(argument, context), True
if stream_observer is not None:
return behavior(argument, context, stream_observer), True
else:
return behavior(argument, context), True
except Exception as exception: # pylint: disable=broad-except
with state.condition:
if state.aborted:
@ -434,7 +442,7 @@ def _serialize_response(rpc_event, state, response, response_serializer):
def _send_response(rpc_event, state, serialized_response):
with state.condition:
if state.client is _CANCELLED or state.statused:
if not _is_rpc_state_active(state):
return False
else:
if state.initial_metadata_allowed:
@ -455,7 +463,7 @@ def _send_response(rpc_event, state, serialized_response):
while True:
state.condition.wait()
if token not in state.due:
return state.client is not _CANCELLED and not state.statused
return _is_rpc_state_active(state)
def _status(rpc_event, state, serialized_response):
@ -501,65 +509,102 @@ def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk,
def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
request_deserializer, response_serializer):
cygrpc.install_context_from_call(rpc_event.call)
def on_next(response):
if response is None:
_status(rpc_event, state, None)
else:
serialized_response = _serialize_response(
rpc_event, state, response, response_serializer)
if serialized_response is not None:
_send_response(rpc_event, state, serialized_response)
try:
argument = argument_thunk()
if argument is not None:
response_iterator, proceed = _call_behavior(
rpc_event, state, behavior, argument, request_deserializer)
if proceed:
while True:
response, proceed = _take_response_from_response_iterator(
rpc_event, state, response_iterator)
if proceed:
if response is None:
_status(rpc_event, state, None)
break
else:
serialized_response = _serialize_response(
rpc_event, state, response, response_serializer)
if serialized_response is not None:
proceed = _send_response(
rpc_event, state, serialized_response)
if not proceed:
break
else:
break
else:
break
if hasattr(behavior, 'experimental_non_blocking'
) and behavior.experimental_non_blocking:
_call_behavior(
rpc_event,
state,
behavior,
argument,
request_deserializer,
stream_observer=on_next)
else:
response_iterator, proceed = _call_behavior(
rpc_event, state, behavior, argument, request_deserializer)
if proceed:
_stream_response_iterator_adapter(rpc_event, state, on_next,
response_iterator)
finally:
cygrpc.uninstall_context()
def _handle_unary_unary(rpc_event, state, method_handler, thread_pool):
def _is_rpc_state_active(state):
return state.client is not _CANCELLED and not state.statused
def _stream_response_iterator_adapter(rpc_event, state, stream_observer,
response_iterator):
while True:
response, proceed = _take_response_from_response_iterator(
rpc_event, state, response_iterator)
if proceed:
stream_observer(response)
if not _is_rpc_state_active(state):
break
else:
break
def _select_thread_pool_for_behavior(behavior, default_thread_pool):
if hasattr(behavior, 'experimental_thread_pool'
) and behavior.experimental_thread_pool is not None:
return behavior.experimental_thread_pool
else:
return default_thread_pool
def _handle_unary_unary(rpc_event, state, method_handler, default_thread_pool):
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
thread_pool = _select_thread_pool_for_behavior(method_handler.unary_unary,
default_thread_pool)
return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
method_handler.unary_unary, unary_request,
method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_unary_stream(rpc_event, state, method_handler, thread_pool):
def _handle_unary_stream(rpc_event, state, method_handler, default_thread_pool):
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
thread_pool = _select_thread_pool_for_behavior(method_handler.unary_stream,
default_thread_pool)
return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
method_handler.unary_stream, unary_request,
method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_stream_unary(rpc_event, state, method_handler, thread_pool):
def _handle_stream_unary(rpc_event, state, method_handler, default_thread_pool):
request_iterator = _RequestIterator(state, rpc_event.call,
method_handler.request_deserializer)
thread_pool = _select_thread_pool_for_behavior(method_handler.stream_unary,
default_thread_pool)
return thread_pool.submit(
_unary_response_in_pool, rpc_event, state, method_handler.stream_unary,
lambda: request_iterator, method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_stream_stream(rpc_event, state, method_handler, thread_pool):
def _handle_stream_stream(rpc_event, state, method_handler,
default_thread_pool):
request_iterator = _RequestIterator(state, rpc_event.call,
method_handler.request_deserializer)
thread_pool = _select_thread_pool_for_behavior(method_handler.stream_stream,
default_thread_pool)
return thread_pool.submit(
_stream_response_in_pool, rpc_event, state,
method_handler.stream_stream, lambda: request_iterator,

@ -13,6 +13,7 @@
# limitations under the License.
"""Reference implementation for health checking in gRPC Python."""
import collections
import threading
import grpc
@ -27,7 +28,7 @@ class _Watcher():
def __init__(self):
self._condition = threading.Condition()
self._responses = list()
self._responses = collections.deque()
self._open = True
def __iter__(self):
@ -38,7 +39,7 @@ class _Watcher():
while not self._responses and self._open:
self._condition.wait()
if self._responses:
return self._responses.pop(0)
return self._responses.popleft()
else:
raise StopIteration()
@ -59,20 +60,35 @@ class _Watcher():
self._condition.notify()
def _watcher_to_on_next_adapter(watcher):
def on_next(response):
if response is None:
watcher.close()
else:
watcher.add(response)
return on_next
class HealthServicer(_health_pb2_grpc.HealthServicer):
"""Servicer handling RPCs for service statuses."""
def __init__(self):
def __init__(self,
experimental_non_blocking=True,
experimental_thread_pool=None):
self._lock = threading.RLock()
self._server_status = {}
self._watchers = {}
self._on_next_callbacks = {}
self.Watch.__func__.experimental_non_blocking = experimental_non_blocking
self.Watch.__func__.experimental_thread_pool = experimental_thread_pool
def _on_close_callback(self, watcher, service):
def _on_close_callback(self, on_next, service):
def callback():
with self._lock:
self._watchers[service].remove(watcher)
watcher.close()
self._on_next_callbacks[service].remove(on_next)
on_next(None)
return callback
@ -85,19 +101,26 @@ class HealthServicer(_health_pb2_grpc.HealthServicer):
else:
return _health_pb2.HealthCheckResponse(status=status)
def Watch(self, request, context):
# pylint: disable=arguments-differ
def Watch(self, request, context, on_next=None):
blocking_watcher = None
if on_next is None:
# The server does not support the experimental_non_blocking
# parameter. For backwards compatibility, return a blocking response
# generator.
blocking_watcher = _Watcher()
on_next = _watcher_to_on_next_adapter(blocking_watcher)
service = request.service
with self._lock:
status = self._server_status.get(service)
if status is None:
status = _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN # pylint: disable=no-member
watcher = _Watcher()
watcher.add(_health_pb2.HealthCheckResponse(status=status))
if service not in self._watchers:
self._watchers[service] = set()
self._watchers[service].add(watcher)
context.add_callback(self._on_close_callback(watcher, service))
return watcher
on_next(_health_pb2.HealthCheckResponse(status=status))
if service not in self._on_next_callbacks:
self._on_next_callbacks[service] = set()
self._on_next_callbacks[service].add(on_next)
context.add_callback(self._on_close_callback(on_next, service))
return blocking_watcher
def set(self, service, status):
"""Sets the status of a service.
@ -109,6 +132,6 @@ class HealthServicer(_health_pb2_grpc.HealthServicer):
"""
with self._lock:
self._server_status[service] = status
if service in self._watchers:
for watcher in self._watchers[service]:
watcher.add(_health_pb2.HealthCheckResponse(status=status))
if service in self._on_next_callbacks:
for on_next in self._on_next_callbacks[service]:
on_next(_health_pb2.HealthCheckResponse(status=status))

@ -38,29 +38,170 @@ def _consume_responses(response_iterator, response_queue):
response_queue.put(response)
class HealthServicerTest(unittest.TestCase):
class BaseWatchTests(object):
class WatchTests(unittest.TestCase):
def start_server(self, servicer):
self._servicer = servicer
self._servicer.set('', health_pb2.HealthCheckResponse.SERVING)
self._servicer.set(_SERVING_SERVICE,
health_pb2.HealthCheckResponse.SERVING)
self._servicer.set(_UNKNOWN_SERVICE,
health_pb2.HealthCheckResponse.UNKNOWN)
self._servicer.set(_NOT_SERVING_SERVICE,
health_pb2.HealthCheckResponse.NOT_SERVING)
self._server = test_common.test_server()
port = self._server.add_insecure_port('[::]:0')
health_pb2_grpc.add_HealthServicer_to_server(
self._servicer, self._server)
self._server.start()
self._channel = grpc.insecure_channel('localhost:%d' % port)
self._stub = health_pb2_grpc.HealthStub(self._channel)
def tearDown(self):
self._server.stop(None)
self._channel.close()
def test_watch_empty_service(self):
request = health_pb2.HealthCheckRequest(service='')
response_queue = queue.Queue()
rendezvous = self._stub.Watch(request)
thread = threading.Thread(
target=_consume_responses, args=(rendezvous, response_queue))
thread.start()
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
response.status)
rendezvous.cancel()
thread.join()
self.assertTrue(response_queue.empty())
def test_watch_new_service(self):
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
response_queue = queue.Queue()
rendezvous = self._stub.Watch(request)
thread = threading.Thread(
target=_consume_responses, args=(rendezvous, response_queue))
thread.start()
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
response.status)
self._servicer.set(_WATCH_SERVICE,
health_pb2.HealthCheckResponse.SERVING)
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
response.status)
self._servicer.set(_WATCH_SERVICE,
health_pb2.HealthCheckResponse.NOT_SERVING)
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING,
response.status)
rendezvous.cancel()
thread.join()
self.assertTrue(response_queue.empty())
def test_watch_service_isolation(self):
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
response_queue = queue.Queue()
rendezvous = self._stub.Watch(request)
thread = threading.Thread(
target=_consume_responses, args=(rendezvous, response_queue))
thread.start()
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
response.status)
self._servicer.set('some-other-service',
health_pb2.HealthCheckResponse.SERVING)
with self.assertRaises(queue.Empty):
response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
rendezvous.cancel()
thread.join()
self.assertTrue(response_queue.empty())
def test_two_watchers(self):
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
response_queue1 = queue.Queue()
response_queue2 = queue.Queue()
rendezvous1 = self._stub.Watch(request)
rendezvous2 = self._stub.Watch(request)
thread1 = threading.Thread(
target=_consume_responses, args=(rendezvous1, response_queue1))
thread2 = threading.Thread(
target=_consume_responses, args=(rendezvous2, response_queue2))
thread1.start()
thread2.start()
response1 = response_queue1.get(
timeout=test_constants.SHORT_TIMEOUT)
response2 = response_queue2.get(
timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
response1.status)
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
response2.status)
self._servicer.set(_WATCH_SERVICE,
health_pb2.HealthCheckResponse.SERVING)
response1 = response_queue1.get(
timeout=test_constants.SHORT_TIMEOUT)
response2 = response_queue2.get(
timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
response1.status)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
response2.status)
rendezvous1.cancel()
rendezvous2.cancel()
thread1.join()
thread2.join()
self.assertTrue(response_queue1.empty())
self.assertTrue(response_queue2.empty())
def test_cancelled_watch_removed_from_watch_list(self):
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
response_queue = queue.Queue()
rendezvous = self._stub.Watch(request)
thread = threading.Thread(
target=_consume_responses, args=(rendezvous, response_queue))
thread.start()
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
response.status)
rendezvous.cancel()
self._servicer.set(_WATCH_SERVICE,
health_pb2.HealthCheckResponse.SERVING)
thread.join()
# Wait, if necessary, for serving thread to process client cancellation
timeout = time.time() + test_constants.SHORT_TIMEOUT
while time.time(
) < timeout and self._servicer._on_next_callbacks[_WATCH_SERVICE]:
time.sleep(1)
self.assertFalse(self._servicer._on_next_callbacks[_WATCH_SERVICE],
'watch set should be empty')
self.assertTrue(response_queue.empty())
class HealthServicerTest(BaseWatchTests.WatchTests):
def setUp(self):
self._servicer = health.HealthServicer()
self._servicer.set('', health_pb2.HealthCheckResponse.SERVING)
self._servicer.set(_SERVING_SERVICE,
health_pb2.HealthCheckResponse.SERVING)
self._servicer.set(_UNKNOWN_SERVICE,
health_pb2.HealthCheckResponse.UNKNOWN)
self._servicer.set(_NOT_SERVING_SERVICE,
health_pb2.HealthCheckResponse.NOT_SERVING)
self._server = test_common.test_server()
port = self._server.add_insecure_port('[::]:0')
health_pb2_grpc.add_HealthServicer_to_server(self._servicer,
self._server)
self._server.start()
self._channel = grpc.insecure_channel('localhost:%d' % port)
self._stub = health_pb2_grpc.HealthStub(self._channel)
def tearDown(self):
self._server.stop(None)
self._channel.close()
super(HealthServicerTest, self).start_server(
health.HealthServicer(
experimental_non_blocking=False, experimental_thread_pool=None))
def test_check_empty_service(self):
request = health_pb2.HealthCheckRequest()
@ -90,135 +231,17 @@ class HealthServicerTest(unittest.TestCase):
self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code())
def test_watch_empty_service(self):
request = health_pb2.HealthCheckRequest(service='')
response_queue = queue.Queue()
rendezvous = self._stub.Watch(request)
thread = threading.Thread(
target=_consume_responses, args=(rendezvous, response_queue))
thread.start()
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
response.status)
rendezvous.cancel()
thread.join()
self.assertTrue(response_queue.empty())
def test_watch_new_service(self):
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
response_queue = queue.Queue()
rendezvous = self._stub.Watch(request)
thread = threading.Thread(
target=_consume_responses, args=(rendezvous, response_queue))
thread.start()
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
response.status)
self._servicer.set(_WATCH_SERVICE,
health_pb2.HealthCheckResponse.SERVING)
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
response.status)
self._servicer.set(_WATCH_SERVICE,
health_pb2.HealthCheckResponse.NOT_SERVING)
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING,
response.status)
rendezvous.cancel()
thread.join()
self.assertTrue(response_queue.empty())
def test_watch_service_isolation(self):
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
response_queue = queue.Queue()
rendezvous = self._stub.Watch(request)
thread = threading.Thread(
target=_consume_responses, args=(rendezvous, response_queue))
thread.start()
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
response.status)
self._servicer.set('some-other-service',
health_pb2.HealthCheckResponse.SERVING)
with self.assertRaises(queue.Empty):
response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
rendezvous.cancel()
thread.join()
self.assertTrue(response_queue.empty())
def test_two_watchers(self):
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
response_queue1 = queue.Queue()
response_queue2 = queue.Queue()
rendezvous1 = self._stub.Watch(request)
rendezvous2 = self._stub.Watch(request)
thread1 = threading.Thread(
target=_consume_responses, args=(rendezvous1, response_queue1))
thread2 = threading.Thread(
target=_consume_responses, args=(rendezvous2, response_queue2))
thread1.start()
thread2.start()
response1 = response_queue1.get(timeout=test_constants.SHORT_TIMEOUT)
response2 = response_queue2.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
response1.status)
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
response2.status)
self._servicer.set(_WATCH_SERVICE,
health_pb2.HealthCheckResponse.SERVING)
response1 = response_queue1.get(timeout=test_constants.SHORT_TIMEOUT)
response2 = response_queue2.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
response1.status)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
response2.status)
rendezvous1.cancel()
rendezvous2.cancel()
thread1.join()
thread2.join()
self.assertTrue(response_queue1.empty())
self.assertTrue(response_queue2.empty())
def test_cancelled_watch_removed_from_watch_list(self):
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
response_queue = queue.Queue()
rendezvous = self._stub.Watch(request)
thread = threading.Thread(
target=_consume_responses, args=(rendezvous, response_queue))
thread.start()
response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
response.status)
rendezvous.cancel()
self._servicer.set(_WATCH_SERVICE,
health_pb2.HealthCheckResponse.SERVING)
thread.join()
# Wait, if necessary, for serving thread to process client cancellation
timeout = time.time() + test_constants.SHORT_TIMEOUT
while time.time() < timeout and self._servicer._watchers[_WATCH_SERVICE]:
time.sleep(1)
self.assertFalse(self._servicer._watchers[_WATCH_SERVICE],
'watch set should be empty')
self.assertTrue(response_queue.empty())
def test_health_service_name(self):
self.assertEqual(health.SERVICE_NAME, 'grpc.health.v1.Health')
class HealthServicerBackwardsCompatibleWatchTest(BaseWatchTests.WatchTests):
def setUp(self):
super(HealthServicerBackwardsCompatibleWatchTest, self).start_server(
health.HealthServicer(
experimental_non_blocking=False, experimental_thread_pool=None))
if __name__ == '__main__':
unittest.main(verbosity=2)

@ -23,6 +23,7 @@ import grpc
from grpc.framework.foundation import logging_pool
from tests.unit import test_common
from tests.unit import _thread_pool
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control
@ -33,8 +34,10 @@ _DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_UNARY_STREAM_NON_BLOCKING = '/test/UnaryStreamNonBlocking'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'
_STREAM_STREAM_NON_BLOCKING = '/test/StreamStreamNonBlocking'
class _Callback(object):
@ -59,8 +62,14 @@ class _Callback(object):
class _Handler(object):
def __init__(self, control):
def __init__(self, control, thread_pool):
self._control = control
self._thread_pool = thread_pool
non_blocking_functions = (self.handle_unary_stream_non_blocking,
self.handle_stream_stream_non_blocking)
for non_blocking_function in non_blocking_functions:
non_blocking_function.__func__.experimental_non_blocking = True
non_blocking_function.__func__.experimental_thread_pool = self._thread_pool
def handle_unary_unary(self, request, servicer_context):
self._control.control()
@ -87,6 +96,20 @@ class _Handler(object):
'testvalue',
),))
def handle_unary_stream_non_blocking(self, request, servicer_context,
on_next):
for _ in range(test_constants.STREAM_LENGTH):
self._control.control()
on_next(request)
# yield request
self._control.control()
if servicer_context is not None:
servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
on_next(None)
def handle_stream_unary(self, request_iterator, servicer_context):
if servicer_context is not None:
servicer_context.invocation_metadata()
@ -115,6 +138,20 @@ class _Handler(object):
yield request
self._control.control()
def handle_stream_stream_non_blocking(self, request_iterator,
servicer_context, on_next):
self._control.control()
if servicer_context is not None:
servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
for request in request_iterator:
self._control.control()
on_next(request)
self._control.control()
on_next(None)
class _MethodHandler(grpc.RpcMethodHandler):
@ -145,6 +182,10 @@ class _GenericHandler(grpc.GenericRpcHandler):
return _MethodHandler(False, True, _DESERIALIZE_REQUEST,
_SERIALIZE_RESPONSE, None,
self._handler.handle_unary_stream, None, None)
elif handler_call_details.method == _UNARY_STREAM_NON_BLOCKING:
return _MethodHandler(
False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None,
self._handler.handle_unary_stream_non_blocking, None, None)
elif handler_call_details.method == _STREAM_UNARY:
return _MethodHandler(True, False, _DESERIALIZE_REQUEST,
_SERIALIZE_RESPONSE, None, None,
@ -152,6 +193,10 @@ class _GenericHandler(grpc.GenericRpcHandler):
elif handler_call_details.method == _STREAM_STREAM:
return _MethodHandler(True, True, None, None, None, None, None,
self._handler.handle_stream_stream)
elif handler_call_details.method == _STREAM_STREAM_NON_BLOCKING:
return _MethodHandler(
True, True, None, None, None, None, None,
self._handler.handle_stream_stream_non_blocking)
else:
return None
@ -167,6 +212,13 @@ def _unary_stream_multi_callable(channel):
response_deserializer=_DESERIALIZE_RESPONSE)
def _unary_stream_non_blocking_multi_callable(channel):
return channel.unary_stream(
_UNARY_STREAM_NON_BLOCKING,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE)
def _stream_unary_multi_callable(channel):
return channel.stream_unary(
_STREAM_UNARY,
@ -178,11 +230,16 @@ def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM)
def _stream_stream_non_blocking_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM_NON_BLOCKING)
class RPCTest(unittest.TestCase):
def setUp(self):
self._control = test_control.PauseFailControl()
self._handler = _Handler(self._control)
self._thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
self._handler = _Handler(self._control, self._thread_pool)
self._server = test_common.test_server()
port = self._server.add_insecure_port('[::]:0')
@ -195,6 +252,16 @@ class RPCTest(unittest.TestCase):
self._server.stop(None)
self._channel.close()
def testDefaultThreadPoolIsUsed(self):
self._consume_one_stream_response_unary_request(
_unary_stream_multi_callable(self._channel))
self.assertFalse(self._thread_pool.was_used())
def testExperimentalThreadPoolIsUsed(self):
self._consume_one_stream_response_unary_request(
_unary_stream_non_blocking_multi_callable(self._channel))
self.assertTrue(self._thread_pool.was_used())
def testUnrecognizedMethod(self):
request = b'abc'
@ -227,7 +294,7 @@ class RPCTest(unittest.TestCase):
self.assertEqual(expected_response, response)
self.assertIs(grpc.StatusCode.OK, call.code())
self.assertEqual("", call.debug_error_string())
self.assertEqual('', call.debug_error_string())
def testSuccessfulUnaryRequestFutureUnaryResponse(self):
request = b'\x07\x08'
@ -310,6 +377,7 @@ class RPCTest(unittest.TestCase):
def testSuccessfulStreamRequestStreamResponse(self):
requests = tuple(
b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH))
expected_responses = tuple(
self._handler.handle_stream_stream(iter(requests), None))
request_iterator = iter(requests)
@ -425,58 +493,36 @@ class RPCTest(unittest.TestCase):
test_is_running_cell[0] = False
def testConsumingOneStreamResponseUnaryRequest(self):
request = b'\x57\x38'
self._consume_one_stream_response_unary_request(
_unary_stream_multi_callable(self._channel))
multi_callable = _unary_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request,
metadata=(('test', 'ConsumingOneStreamResponseUnaryRequest'),))
next(response_iterator)
def testConsumingOneStreamResponseUnaryRequestNonBlocking(self):
self._consume_one_stream_response_unary_request(
_unary_stream_non_blocking_multi_callable(self._channel))
def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self):
request = b'\x57\x38'
self._consume_some_but_not_all_stream_responses_unary_request(
_unary_stream_multi_callable(self._channel))
multi_callable = _unary_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request,
metadata=(('test',
'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),))
for _ in range(test_constants.STREAM_LENGTH // 2):
next(response_iterator)
def testConsumingSomeButNotAllStreamResponsesUnaryRequestNonBlocking(self):
self._consume_some_but_not_all_stream_responses_unary_request(
_unary_stream_non_blocking_multi_callable(self._channel))
def testConsumingSomeButNotAllStreamResponsesStreamRequest(self):
requests = tuple(
b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
self._consume_some_but_not_all_stream_responses_stream_request(
_stream_stream_multi_callable(self._channel))
multi_callable = _stream_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request_iterator,
metadata=(('test',
'ConsumingSomeButNotAllStreamResponsesStreamRequest'),))
for _ in range(test_constants.STREAM_LENGTH // 2):
next(response_iterator)
def testConsumingSomeButNotAllStreamResponsesStreamRequestNonBlocking(self):
self._consume_some_but_not_all_stream_responses_stream_request(
_stream_stream_non_blocking_multi_callable(self._channel))
def testConsumingTooManyStreamResponsesStreamRequest(self):
requests = tuple(
b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
self._consume_too_many_stream_responses_stream_request(
_stream_stream_multi_callable(self._channel))
multi_callable = _stream_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request_iterator,
metadata=(('test',
'ConsumingTooManyStreamResponsesStreamRequest'),))
for _ in range(test_constants.STREAM_LENGTH):
next(response_iterator)
for _ in range(test_constants.STREAM_LENGTH):
with self.assertRaises(StopIteration):
next(response_iterator)
self.assertIsNotNone(response_iterator.initial_metadata())
self.assertIs(grpc.StatusCode.OK, response_iterator.code())
self.assertIsNotNone(response_iterator.details())
self.assertIsNotNone(response_iterator.trailing_metadata())
def testConsumingTooManyStreamResponsesStreamRequestNonBlocking(self):
self._consume_too_many_stream_responses_stream_request(
_stream_stream_non_blocking_multi_callable(self._channel))
def testCancelledUnaryRequestUnaryResponse(self):
request = b'\x07\x17'
@ -498,24 +544,12 @@ class RPCTest(unittest.TestCase):
self.assertIs(grpc.StatusCode.CANCELLED, response_future.code())
def testCancelledUnaryRequestStreamResponse(self):
request = b'\x07\x19'
multi_callable = _unary_stream_multi_callable(self._channel)
with self._control.pause():
response_iterator = multi_callable(
request,
metadata=(('test', 'CancelledUnaryRequestStreamResponse'),))
self._control.block_until_paused()
response_iterator.cancel()
self._cancelled_unary_request_stream_response(
_unary_stream_multi_callable(self._channel))
with self.assertRaises(grpc.RpcError) as exception_context:
next(response_iterator)
self.assertIs(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
self.assertIsNotNone(response_iterator.initial_metadata())
self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
self.assertIsNotNone(response_iterator.details())
self.assertIsNotNone(response_iterator.trailing_metadata())
def testCancelledUnaryRequestStreamResponseNonBlocking(self):
self._cancelled_unary_request_stream_response(
_unary_stream_non_blocking_multi_callable(self._channel))
def testCancelledStreamRequestUnaryResponse(self):
requests = tuple(
@ -543,23 +577,12 @@ class RPCTest(unittest.TestCase):
self.assertIsNotNone(response_future.trailing_metadata())
def testCancelledStreamRequestStreamResponse(self):
requests = tuple(
b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
self._cancelled_stream_request_stream_response(
_stream_stream_multi_callable(self._channel))
multi_callable = _stream_stream_multi_callable(self._channel)
with self._control.pause():
response_iterator = multi_callable(
request_iterator,
metadata=(('test', 'CancelledStreamRequestStreamResponse'),))
response_iterator.cancel()
with self.assertRaises(grpc.RpcError):
next(response_iterator)
self.assertIsNotNone(response_iterator.initial_metadata())
self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
self.assertIsNotNone(response_iterator.details())
self.assertIsNotNone(response_iterator.trailing_metadata())
def testCancelledStreamRequestStreamResponseNonBlocking(self):
self._cancelled_stream_request_stream_response(
_stream_stream_non_blocking_multi_callable(self._channel))
def testExpiredUnaryRequestBlockingUnaryResponse(self):
request = b'\x07\x17'
@ -608,21 +631,12 @@ class RPCTest(unittest.TestCase):
response_future.exception().code())
def testExpiredUnaryRequestStreamResponse(self):
request = b'\x07\x19'
self._expired_unary_request_stream_response(
_unary_stream_multi_callable(self._channel))
multi_callable = _unary_stream_multi_callable(self._channel)
with self._control.pause():
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request,
timeout=test_constants.SHORT_TIMEOUT,
metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),))
next(response_iterator)
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
exception_context.exception.code())
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
response_iterator.code())
def testExpiredUnaryRequestStreamResponseNonBlocking(self):
self._expired_unary_request_stream_response(
_unary_stream_non_blocking_multi_callable(self._channel))
def testExpiredStreamRequestBlockingUnaryResponse(self):
requests = tuple(
@ -678,23 +692,12 @@ class RPCTest(unittest.TestCase):
self.assertIsNotNone(response_future.trailing_metadata())
def testExpiredStreamRequestStreamResponse(self):
requests = tuple(
b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel)
with self._control.pause():
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request_iterator,
timeout=test_constants.SHORT_TIMEOUT,
metadata=(('test', 'ExpiredStreamRequestStreamResponse'),))
next(response_iterator)
self._expired_stream_request_stream_response(
_stream_stream_multi_callable(self._channel))
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
exception_context.exception.code())
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
response_iterator.code())
def testExpiredStreamRequestStreamResponseNonBlocking(self):
self._expired_stream_request_stream_response(
_stream_stream_non_blocking_multi_callable(self._channel))
def testFailedUnaryRequestBlockingUnaryResponse(self):
request = b'\x37\x17'
@ -712,10 +715,10 @@ class RPCTest(unittest.TestCase):
# sanity checks on to make sure returned string contains default members
# of the error
debug_error_string = exception_context.exception.debug_error_string()
self.assertIn("created", debug_error_string)
self.assertIn("description", debug_error_string)
self.assertIn("file", debug_error_string)
self.assertIn("file_line", debug_error_string)
self.assertIn('created', debug_error_string)
self.assertIn('description', debug_error_string)
self.assertIn('file', debug_error_string)
self.assertIn('file_line', debug_error_string)
def testFailedUnaryRequestFutureUnaryResponse(self):
request = b'\x37\x17'
@ -742,18 +745,12 @@ class RPCTest(unittest.TestCase):
self.assertIs(response_future, value_passed_to_callback)
def testFailedUnaryRequestStreamResponse(self):
request = b'\x37\x17'
self._failed_unary_request_stream_response(
_unary_stream_multi_callable(self._channel))
multi_callable = _unary_stream_multi_callable(self._channel)
with self.assertRaises(grpc.RpcError) as exception_context:
with self._control.fail():
response_iterator = multi_callable(
request,
metadata=(('test', 'FailedUnaryRequestStreamResponse'),))
next(response_iterator)
self.assertIs(grpc.StatusCode.UNKNOWN,
exception_context.exception.code())
def testFailedUnaryRequestStreamResponseNonBlocking(self):
self._failed_unary_request_stream_response(
_unary_stream_non_blocking_multi_callable(self._channel))
def testFailedStreamRequestBlockingUnaryResponse(self):
requests = tuple(
@ -795,21 +792,12 @@ class RPCTest(unittest.TestCase):
self.assertIs(response_future, value_passed_to_callback)
def testFailedStreamRequestStreamResponse(self):
requests = tuple(
b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
self._failed_stream_request_stream_response(
_stream_stream_multi_callable(self._channel))
multi_callable = _stream_stream_multi_callable(self._channel)
with self._control.fail():
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request_iterator,
metadata=(('test', 'FailedStreamRequestStreamResponse'),))
tuple(response_iterator)
self.assertIs(grpc.StatusCode.UNKNOWN,
exception_context.exception.code())
self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())
def testFailedStreamRequestStreamResponseNonBlocking(self):
self._failed_stream_request_stream_response(
_stream_stream_non_blocking_multi_callable(self._channel))
def testIgnoredUnaryRequestFutureUnaryResponse(self):
request = b'\x37\x17'
@ -820,11 +808,12 @@ class RPCTest(unittest.TestCase):
metadata=(('test', 'IgnoredUnaryRequestFutureUnaryResponse'),))
def testIgnoredUnaryRequestStreamResponse(self):
request = b'\x37\x17'
self._ignored_unary_stream_request_future_unary_response(
_unary_stream_multi_callable(self._channel))
multi_callable = _unary_stream_multi_callable(self._channel)
multi_callable(
request, metadata=(('test', 'IgnoredUnaryRequestStreamResponse'),))
def testIgnoredUnaryRequestStreamResponseNonBlocking(self):
self._ignored_unary_stream_request_future_unary_response(
_unary_stream_non_blocking_multi_callable(self._channel))
def testIgnoredStreamRequestFutureUnaryResponse(self):
requests = tuple(
@ -837,11 +826,177 @@ class RPCTest(unittest.TestCase):
metadata=(('test', 'IgnoredStreamRequestFutureUnaryResponse'),))
def testIgnoredStreamRequestStreamResponse(self):
self._ignored_stream_request_stream_response(
_stream_stream_multi_callable(self._channel))
def testIgnoredStreamRequestStreamResponseNonBlocking(self):
self._ignored_stream_request_stream_response(
_stream_stream_non_blocking_multi_callable(self._channel))
def _consume_one_stream_response_unary_request(self, multi_callable):
request = b'\x57\x38'
response_iterator = multi_callable(
request,
metadata=(('test', 'ConsumingOneStreamResponseUnaryRequest'),))
next(response_iterator)
def _consume_some_but_not_all_stream_responses_unary_request(
self, multi_callable):
request = b'\x57\x38'
response_iterator = multi_callable(
request,
metadata=(('test',
'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),))
for _ in range(test_constants.STREAM_LENGTH // 2):
next(response_iterator)
def _consume_some_but_not_all_stream_responses_stream_request(
self, multi_callable):
requests = tuple(
b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
response_iterator = multi_callable(
request_iterator,
metadata=(('test',
'ConsumingSomeButNotAllStreamResponsesStreamRequest'),))
for _ in range(test_constants.STREAM_LENGTH // 2):
next(response_iterator)
def _consume_too_many_stream_responses_stream_request(self, multi_callable):
requests = tuple(
b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
response_iterator = multi_callable(
request_iterator,
metadata=(('test',
'ConsumingTooManyStreamResponsesStreamRequest'),))
for _ in range(test_constants.STREAM_LENGTH):
next(response_iterator)
for _ in range(test_constants.STREAM_LENGTH):
with self.assertRaises(StopIteration):
next(response_iterator)
self.assertIsNotNone(response_iterator.initial_metadata())
self.assertIs(grpc.StatusCode.OK, response_iterator.code())
self.assertIsNotNone(response_iterator.details())
self.assertIsNotNone(response_iterator.trailing_metadata())
def _cancelled_unary_request_stream_response(self, multi_callable):
request = b'\x07\x19'
with self._control.pause():
response_iterator = multi_callable(
request,
metadata=(('test', 'CancelledUnaryRequestStreamResponse'),))
self._control.block_until_paused()
response_iterator.cancel()
with self.assertRaises(grpc.RpcError) as exception_context:
next(response_iterator)
self.assertIs(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
self.assertIsNotNone(response_iterator.initial_metadata())
self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
self.assertIsNotNone(response_iterator.details())
self.assertIsNotNone(response_iterator.trailing_metadata())
def _cancelled_stream_request_stream_response(self, multi_callable):
requests = tuple(
b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
with self._control.pause():
response_iterator = multi_callable(
request_iterator,
metadata=(('test', 'CancelledStreamRequestStreamResponse'),))
response_iterator.cancel()
with self.assertRaises(grpc.RpcError):
next(response_iterator)
self.assertIsNotNone(response_iterator.initial_metadata())
self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
self.assertIsNotNone(response_iterator.details())
self.assertIsNotNone(response_iterator.trailing_metadata())
def _expired_unary_request_stream_response(self, multi_callable):
request = b'\x07\x19'
with self._control.pause():
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request,
timeout=test_constants.SHORT_TIMEOUT,
metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),))
next(response_iterator)
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
exception_context.exception.code())
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
response_iterator.code())
def _expired_stream_request_stream_response(self, multi_callable):
requests = tuple(
b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
with self._control.pause():
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request_iterator,
timeout=test_constants.SHORT_TIMEOUT,
metadata=(('test', 'ExpiredStreamRequestStreamResponse'),))
next(response_iterator)
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
exception_context.exception.code())
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
response_iterator.code())
def _failed_unary_request_stream_response(self, multi_callable):
request = b'\x37\x17'
with self.assertRaises(grpc.RpcError) as exception_context:
with self._control.fail():
response_iterator = multi_callable(
request,
metadata=(('test', 'FailedUnaryRequestStreamResponse'),))
next(response_iterator)
self.assertIs(grpc.StatusCode.UNKNOWN,
exception_context.exception.code())
def _failed_stream_request_stream_response(self, multi_callable):
requests = tuple(
b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
with self._control.fail():
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request_iterator,
metadata=(('test', 'FailedStreamRequestStreamResponse'),))
tuple(response_iterator)
self.assertIs(grpc.StatusCode.UNKNOWN,
exception_context.exception.code())
self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())
def _ignored_unary_stream_request_future_unary_response(
self, multi_callable):
request = b'\x37\x17'
multi_callable(
request, metadata=(('test', 'IgnoredUnaryRequestStreamResponse'),))
def _ignored_stream_request_stream_response(self, multi_callable):
requests = tuple(
b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel)
multi_callable(
request_iterator,
metadata=(('test', 'IgnoredStreamRequestStreamResponse'),))

Loading…
Cancel
Save