Add max_requests argument to server

If the server is already serving max requests, return RESOURCE_EXHAUSTED
pull/10022/head
Ken Payson 8 years ago
parent eb064ec7b8
commit 39a5932097
  1. 11
      src/python/grpcio/grpc/__init__.py
  2. 78
      src/python/grpcio/grpc/_server.py
  3. 1
      src/python/grpcio_tests/tests/tests.json
  4. 270
      src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py

@ -1273,7 +1273,10 @@ def secure_channel(target, credentials, options=None):
credentials._credentials)
def server(thread_pool, handlers=None, options=None):
def server(thread_pool,
handlers=None,
options=None,
maximum_concurrent_rpcs=None):
"""Creates a Server with which RPCs can be serviced.
Args:
@ -1286,13 +1289,17 @@ def server(thread_pool, handlers=None, options=None):
returned Server is started.
options: A sequence of string-value pairs according to which to configure
the created server.
maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server
will service before returning status RESOURCE_EXHAUSTED, or None to
indicate no limit.
Returns:
A Server with which RPCs can be serviced.
"""
from grpc import _server # pylint: disable=cyclic-import
return _server.Server(thread_pool, () if handlers is None else handlers, ()
if options is None else options)
if options is None else options,
maximum_concurrent_rpcs)
################################### __all__ #################################

@ -504,7 +504,7 @@ def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
def _handle_unary_unary(rpc_event, state, method_handler, thread_pool):
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
thread_pool.submit(_unary_response_in_pool, rpc_event, state,
return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
method_handler.unary_unary, unary_request,
method_handler.request_deserializer,
method_handler.response_serializer)
@ -513,7 +513,7 @@ def _handle_unary_unary(rpc_event, state, method_handler, thread_pool):
def _handle_unary_stream(rpc_event, state, method_handler, thread_pool):
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
thread_pool.submit(_stream_response_in_pool, rpc_event, state,
return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
method_handler.unary_stream, unary_request,
method_handler.request_deserializer,
method_handler.response_serializer)
@ -522,19 +522,19 @@ def _handle_unary_stream(rpc_event, state, method_handler, thread_pool):
def _handle_stream_unary(rpc_event, state, method_handler, thread_pool):
request_iterator = _RequestIterator(state, rpc_event.operation_call,
method_handler.request_deserializer)
thread_pool.submit(_unary_response_in_pool, rpc_event, state,
method_handler.stream_unary, lambda: request_iterator,
method_handler.request_deserializer,
return thread_pool.submit(
_unary_response_in_pool, rpc_event, state, method_handler.stream_unary,
lambda: request_iterator, method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_stream_stream(rpc_event, state, method_handler, thread_pool):
request_iterator = _RequestIterator(state, rpc_event.operation_call,
method_handler.request_deserializer)
thread_pool.submit(_stream_response_in_pool, rpc_event, state,
return thread_pool.submit(
_stream_response_in_pool, rpc_event, state,
method_handler.stream_stream, lambda: request_iterator,
method_handler.request_deserializer,
method_handler.response_serializer)
method_handler.request_deserializer, method_handler.response_serializer)
def _find_method_handler(rpc_event, generic_handlers):
@ -549,13 +549,12 @@ def _find_method_handler(rpc_event, generic_handlers):
return None
def _handle_unrecognized_method(rpc_event):
def _reject_rpc(rpc_event, status, details):
operations = (cygrpc.operation_send_initial_metadata(_common.EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
_common.EMPTY_METADATA, cygrpc.StatusCode.unimplemented,
b'Method not found!', _EMPTY_FLAGS),)
_common.EMPTY_METADATA, status, details, _EMPTY_FLAGS),)
rpc_state = _RPCState()
rpc_event.operation_call.start_server_batch(
operations, lambda ignored_event: (rpc_state, (),))
@ -572,33 +571,37 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
if method_handler.request_streaming:
if method_handler.response_streaming:
_handle_stream_stream(rpc_event, state, method_handler,
thread_pool)
return state, _handle_stream_stream(rpc_event, state,
method_handler, thread_pool)
else:
_handle_stream_unary(rpc_event, state, method_handler,
thread_pool)
return state, _handle_stream_unary(rpc_event, state,
method_handler, thread_pool)
else:
if method_handler.response_streaming:
_handle_unary_stream(rpc_event, state, method_handler,
thread_pool)
return state, _handle_unary_stream(rpc_event, state,
method_handler, thread_pool)
else:
_handle_unary_unary(rpc_event, state, method_handler,
thread_pool)
return state
return state, _handle_unary_unary(rpc_event, state,
method_handler, thread_pool)
def _handle_call(rpc_event, generic_handlers, thread_pool):
def _handle_call(rpc_event, generic_handlers, thread_pool,
concurrency_exceeded):
if not rpc_event.success:
return None
return None, None
if rpc_event.request_call_details.method is not None:
method_handler = _find_method_handler(rpc_event, generic_handlers)
if method_handler is None:
return _handle_unrecognized_method(rpc_event)
return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented,
b'Method not found!'), None
elif concurrency_exceeded:
return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted,
b'Concurrent RPC limit exceeded!'), None
else:
return _handle_with_method_handler(rpc_event, method_handler,
thread_pool)
else:
return None
return None, None
@enum.unique
@ -610,7 +613,8 @@ class _ServerStage(enum.Enum):
class _ServerState(object):
def __init__(self, completion_queue, server, generic_handlers, thread_pool):
def __init__(self, completion_queue, server, generic_handlers, thread_pool,
maximum_concurrent_rpcs):
self.lock = threading.Lock()
self.completion_queue = completion_queue
self.server = server
@ -618,6 +622,8 @@ class _ServerState(object):
self.thread_pool = thread_pool
self.stage = _ServerStage.STOPPED
self.shutdown_events = None
self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
self.active_rpc_count = 0
# TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
self.rpc_states = set()
@ -657,6 +663,11 @@ def _stop_serving(state):
return False
def _on_call_completed(state):
with state.lock:
state.active_rpc_count -= 1
def _serve(state):
while True:
event = state.completion_queue.poll()
@ -668,10 +679,18 @@ def _serve(state):
elif event.tag is _REQUEST_CALL_TAG:
with state.lock:
state.due.remove(_REQUEST_CALL_TAG)
rpc_state = _handle_call(event, state.generic_handlers,
state.thread_pool)
concurrency_exceeded = (
state.maximum_concurrent_rpcs is not None and
state.active_rpc_count >= state.maximum_concurrent_rpcs)
rpc_state, rpc_future = _handle_call(
event, state.generic_handlers, state.thread_pool,
concurrency_exceeded)
if rpc_state is not None:
state.rpc_states.add(rpc_state)
if rpc_future is not None:
state.active_rpc_count += 1
rpc_future.add_done_callback(
lambda unused_future: _on_call_completed(state))
if state.stage is _ServerStage.STARTED:
_request_call(state)
elif _stop_serving(state):
@ -749,12 +768,13 @@ def _start(state):
class Server(grpc.Server):
def __init__(self, thread_pool, generic_handlers, options):
def __init__(self, thread_pool, generic_handlers, options,
maximum_concurrent_rpcs):
completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server(_common.channel_args(options))
server.register_completion_queue(completion_queue)
self._state = _ServerState(completion_queue, server, generic_handlers,
thread_pool)
thread_pool, maximum_concurrent_rpcs)
def add_generic_rpc_handlers(self, generic_rpc_handlers):
_add_generic_handlers(self._state, generic_rpc_handlers)

@ -31,6 +31,7 @@
"unit._invocation_defects_test.InvocationDefectsTest",
"unit._metadata_code_details_test.MetadataCodeDetailsTest",
"unit._metadata_test.MetadataTest",
"unit._resource_exhausted_test.ResourceExhaustedTest",
"unit._rpc_test.RPCTest",
"unit._sanity._sanity_test.Sanity",
"unit._thread_cleanup_test.CleanupThreadTest",

@ -0,0 +1,270 @@
# Copyright 2017, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests server responding with RESOURCE_EXHAUSTED."""
import threading
import unittest
import grpc
from grpc import _channel
from grpc.framework.foundation import logging_pool
from tests.unit import test_common
from tests.unit.framework.common import test_constants
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'
class _TestTrigger(object):
def __init__(self, total_call_count):
self._total_call_count = total_call_count
self._pending_calls = 0
self._triggered = False
self._finish_condition = threading.Condition()
self._start_condition = threading.Condition()
# Wait for all calls be be blocked in their handler
def await_calls(self):
with self._start_condition:
while self._pending_calls < self._total_call_count:
self._start_condition.wait()
# Block in a response handler and wait for a trigger
def await_trigger(self):
with self._start_condition:
self._pending_calls += 1
self._start_condition.notify()
with self._finish_condition:
if not self._triggered:
self._finish_condition.wait()
# Finish all response handlers
def trigger(self):
with self._finish_condition:
self._triggered = True
self._finish_condition.notify_all()
def handle_unary_unary(trigger, request, servicer_context):
trigger.await_trigger()
return _RESPONSE
def handle_unary_stream(trigger, request, servicer_context):
trigger.await_trigger()
for _ in range(test_constants.STREAM_LENGTH):
yield _RESPONSE
def handle_stream_unary(trigger, request_iterator, servicer_context):
trigger.await_trigger()
# TODO(issue:#6891) We should be able to remove this loop
for request in request_iterator:
pass
return _RESPONSE
def handle_stream_stream(trigger, request_iterator, servicer_context):
trigger.await_trigger()
# TODO(issue:#6891) We should be able to remove this loop,
# and replace with return; yield
for request in request_iterator:
yield _RESPONSE
class _MethodHandler(grpc.RpcMethodHandler):
def __init__(self, trigger, request_streaming, response_streaming):
self.request_streaming = request_streaming
self.response_streaming = response_streaming
self.request_deserializer = None
self.response_serializer = None
self.unary_unary = None
self.unary_stream = None
self.stream_unary = None
self.stream_stream = None
if self.request_streaming and self.response_streaming:
self.stream_stream = (
lambda x, y: handle_stream_stream(trigger, x, y))
elif self.request_streaming:
self.stream_unary = lambda x, y: handle_stream_unary(trigger, x, y)
elif self.response_streaming:
self.unary_stream = lambda x, y: handle_unary_stream(trigger, x, y)
else:
self.unary_unary = lambda x, y: handle_unary_unary(trigger, x, y)
class _GenericHandler(grpc.GenericRpcHandler):
def __init__(self, trigger):
self._trigger = trigger
def service(self, handler_call_details):
if handler_call_details.method == _UNARY_UNARY:
return _MethodHandler(self._trigger, False, False)
elif handler_call_details.method == _UNARY_STREAM:
return _MethodHandler(self._trigger, False, True)
elif handler_call_details.method == _STREAM_UNARY:
return _MethodHandler(self._trigger, True, False)
elif handler_call_details.method == _STREAM_STREAM:
return _MethodHandler(self._trigger, True, True)
else:
return None
class ResourceExhaustedTest(unittest.TestCase):
def setUp(self):
self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
self._trigger = _TestTrigger(test_constants.THREAD_CONCURRENCY)
self._server = grpc.server(
self._server_pool,
handlers=(_GenericHandler(self._trigger),),
maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY)
port = self._server.add_insecure_port('[::]:0')
self._server.start()
self._channel = grpc.insecure_channel('localhost:%d' % port)
def tearDown(self):
self._server.stop(0)
def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY)
futures = []
for _ in range(test_constants.THREAD_CONCURRENCY):
futures.append(multi_callable.future(_REQUEST))
self._trigger.await_calls()
with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable(_REQUEST)
self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
exception_context.exception.code())
future_exception = multi_callable.future(_REQUEST)
self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
future_exception.exception().code())
self._trigger.trigger()
for future in futures:
self.assertEqual(_RESPONSE, future.result())
# Ensure a new request can be handled
self.assertEqual(_RESPONSE, multi_callable(_REQUEST))
def testUnaryStream(self):
multi_callable = self._channel.unary_stream(_UNARY_STREAM)
calls = []
for _ in range(test_constants.THREAD_CONCURRENCY):
calls.append(multi_callable(_REQUEST))
self._trigger.await_calls()
with self.assertRaises(grpc.RpcError) as exception_context:
next(multi_callable(_REQUEST))
self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
exception_context.exception.code())
self._trigger.trigger()
for call in calls:
for response in call:
self.assertEqual(_RESPONSE, response)
# Ensure a new request can be handled
new_call = multi_callable(_REQUEST)
for response in new_call:
self.assertEqual(_RESPONSE, response)
def testStreamUnary(self):
multi_callable = self._channel.stream_unary(_STREAM_UNARY)
futures = []
request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
for _ in range(test_constants.THREAD_CONCURRENCY):
futures.append(multi_callable.future(request))
self._trigger.await_calls()
with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable(request)
self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
exception_context.exception.code())
future_exception = multi_callable.future(request)
self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
future_exception.exception().code())
self._trigger.trigger()
for future in futures:
self.assertEqual(_RESPONSE, future.result())
# Ensure a new request can be handled
self.assertEqual(_RESPONSE, multi_callable(request))
def testStreamStream(self):
multi_callable = self._channel.stream_stream(_STREAM_STREAM)
calls = []
request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
for _ in range(test_constants.THREAD_CONCURRENCY):
calls.append(multi_callable(request))
self._trigger.await_calls()
with self.assertRaises(grpc.RpcError) as exception_context:
next(multi_callable(request))
self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
exception_context.exception.code())
self._trigger.trigger()
for call in calls:
for response in call:
self.assertEqual(_RESPONSE, response)
# Ensure a new request can be handled
new_call = multi_callable(request)
for response in new_call:
self.assertEqual(_RESPONSE, response)
if __name__ == '__main__':
unittest.main(verbosity=2)
Loading…
Cancel
Save