pull/18095/head
Eric Gribkoff 6 years ago
parent 0346ec2f45
commit a5c96cf765
  1. 31
      src/python/grpcio_health_checking/grpc_health/v1/health.py
  2. 18
      src/python/grpcio_tests/tests/health_check/_health_servicer_test.py

@ -60,15 +60,15 @@ class _Watcher():
self._condition.notify() self._condition.notify()
def _watcher_to_on_next_adapter(watcher): def _watcher_to_on_next_callback_adapter(watcher):
def on_next(response): def on_next_callback(response):
if response is None: if response is None:
watcher.close() watcher.close()
else: else:
watcher.add(response) watcher.add(response)
return on_next return on_next_callback
class HealthServicer(_health_pb2_grpc.HealthServicer): class HealthServicer(_health_pb2_grpc.HealthServicer):
@ -83,12 +83,12 @@ class HealthServicer(_health_pb2_grpc.HealthServicer):
self.Watch.__func__.experimental_non_blocking = experimental_non_blocking self.Watch.__func__.experimental_non_blocking = experimental_non_blocking
self.Watch.__func__.experimental_thread_pool = experimental_thread_pool self.Watch.__func__.experimental_thread_pool = experimental_thread_pool
def _on_close_callback(self, on_next, service): def _on_close_callback(self, on_next_callback, service):
def callback(): def callback():
with self._lock: with self._lock:
self._on_next_callbacks[service].remove(on_next) self._on_next_callbacks[service].remove(on_next_callback)
on_next(None) on_next_callback(None)
return callback return callback
@ -102,24 +102,26 @@ class HealthServicer(_health_pb2_grpc.HealthServicer):
return _health_pb2.HealthCheckResponse(status=status) return _health_pb2.HealthCheckResponse(status=status)
# pylint: disable=arguments-differ # pylint: disable=arguments-differ
def Watch(self, request, context, on_next=None): def Watch(self, request, context, on_next_callback=None):
blocking_watcher = None blocking_watcher = None
if on_next is None: if on_next_callback is None:
# The server does not support the experimental_non_blocking # The server does not support the experimental_non_blocking
# parameter. For backwards compatibility, return a blocking response # parameter. For backwards compatibility, return a blocking response
# generator. # generator.
blocking_watcher = _Watcher() blocking_watcher = _Watcher()
on_next = _watcher_to_on_next_adapter(blocking_watcher) on_next_callback = _watcher_to_on_next_callback_adapter(
blocking_watcher)
service = request.service service = request.service
with self._lock: with self._lock:
status = self._server_status.get(service) status = self._server_status.get(service)
if status is None: if status is None:
status = _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN # pylint: disable=no-member status = _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN # pylint: disable=no-member
on_next(_health_pb2.HealthCheckResponse(status=status)) on_next_callback(_health_pb2.HealthCheckResponse(status=status))
if service not in self._on_next_callbacks: if service not in self._on_next_callbacks:
self._on_next_callbacks[service] = set() self._on_next_callbacks[service] = set()
self._on_next_callbacks[service].add(on_next) self._on_next_callbacks[service].add(on_next_callback)
context.add_callback(self._on_close_callback(on_next, service)) context.add_callback(
self._on_close_callback(on_next_callback, service))
return blocking_watcher return blocking_watcher
def set(self, service, status): def set(self, service, status):
@ -133,5 +135,6 @@ class HealthServicer(_health_pb2_grpc.HealthServicer):
with self._lock: with self._lock:
self._server_status[service] = status self._server_status[service] = status
if service in self._on_next_callbacks: if service in self._on_next_callbacks:
for on_next in self._on_next_callbacks[service]: for on_next_callback in self._on_next_callbacks[service]:
on_next(_health_pb2.HealthCheckResponse(status=status)) on_next_callback(
_health_pb2.HealthCheckResponse(status=status))

@ -23,6 +23,7 @@ from grpc_health.v1 import health_pb2
from grpc_health.v1 import health_pb2_grpc from grpc_health.v1 import health_pb2_grpc
from tests.unit import test_common from tests.unit import test_common
from tests.unit import _thread_pool
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from six.moves import queue from six.moves import queue
@ -42,8 +43,11 @@ class BaseWatchTests(object):
class WatchTests(unittest.TestCase): class WatchTests(unittest.TestCase):
def start_server(self, servicer): def start_server(self, non_blocking=False, thread_pool=None):
self._servicer = servicer self._thread_pool = thread_pool
self._servicer = health.HealthServicer(
experimental_non_blocking=non_blocking,
experimental_thread_pool=thread_pool)
self._servicer.set('', health_pb2.HealthCheckResponse.SERVING) self._servicer.set('', health_pb2.HealthCheckResponse.SERVING)
self._servicer.set(_SERVING_SERVICE, self._servicer.set(_SERVING_SERVICE,
health_pb2.HealthCheckResponse.SERVING) health_pb2.HealthCheckResponse.SERVING)
@ -80,6 +84,9 @@ class BaseWatchTests(object):
thread.join() thread.join()
self.assertTrue(response_queue.empty()) self.assertTrue(response_queue.empty())
if self._thread_pool is not None:
self.assertTrue(self._thread_pool.was_used())
def test_watch_new_service(self): def test_watch_new_service(self):
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
response_queue = queue.Queue() response_queue = queue.Queue()
@ -199,9 +206,9 @@ class BaseWatchTests(object):
class HealthServicerTest(BaseWatchTests.WatchTests): class HealthServicerTest(BaseWatchTests.WatchTests):
def setUp(self): def setUp(self):
self._thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
super(HealthServicerTest, self).start_server( super(HealthServicerTest, self).start_server(
health.HealthServicer( non_blocking=True, thread_pool=self._thread_pool)
experimental_non_blocking=False, experimental_thread_pool=None))
def test_check_empty_service(self): def test_check_empty_service(self):
request = health_pb2.HealthCheckRequest() request = health_pb2.HealthCheckRequest()
@ -239,8 +246,7 @@ class HealthServicerBackwardsCompatibleWatchTest(BaseWatchTests.WatchTests):
def setUp(self): def setUp(self):
super(HealthServicerBackwardsCompatibleWatchTest, self).start_server( super(HealthServicerBackwardsCompatibleWatchTest, self).start_server(
health.HealthServicer( non_blocking=False, thread_pool=None)
experimental_non_blocking=False, experimental_thread_pool=None))
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save