diff --git a/src/python/grpcio_health_checking/grpc_health/v1/health.py b/src/python/grpcio_health_checking/grpc_health/v1/health.py index f135ffbb645..c1bb998df90 100644 --- a/src/python/grpcio_health_checking/grpc_health/v1/health.py +++ b/src/python/grpcio_health_checking/grpc_health/v1/health.py @@ -60,15 +60,15 @@ class _Watcher(): self._condition.notify() -def _watcher_to_on_next_adapter(watcher): +def _watcher_to_on_next_callback_adapter(watcher): - def on_next(response): + def on_next_callback(response): if response is None: watcher.close() else: watcher.add(response) - return on_next + return on_next_callback class HealthServicer(_health_pb2_grpc.HealthServicer): @@ -83,12 +83,12 @@ class HealthServicer(_health_pb2_grpc.HealthServicer): self.Watch.__func__.experimental_non_blocking = experimental_non_blocking self.Watch.__func__.experimental_thread_pool = experimental_thread_pool - def _on_close_callback(self, on_next, service): + def _on_close_callback(self, on_next_callback, service): def callback(): with self._lock: - self._on_next_callbacks[service].remove(on_next) - on_next(None) + self._on_next_callbacks[service].remove(on_next_callback) + on_next_callback(None) return callback @@ -102,24 +102,26 @@ class HealthServicer(_health_pb2_grpc.HealthServicer): return _health_pb2.HealthCheckResponse(status=status) # pylint: disable=arguments-differ - def Watch(self, request, context, on_next=None): + def Watch(self, request, context, on_next_callback=None): blocking_watcher = None - if on_next is None: + if on_next_callback is None: # The server does not support the experimental_non_blocking # parameter. For backwards compatibility, return a blocking response # generator. blocking_watcher = _Watcher() - on_next = _watcher_to_on_next_adapter(blocking_watcher) + on_next_callback = _watcher_to_on_next_callback_adapter( + blocking_watcher) service = request.service with self._lock: status = self._server_status.get(service) if status is None: status = _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN # pylint: disable=no-member - on_next(_health_pb2.HealthCheckResponse(status=status)) + on_next_callback(_health_pb2.HealthCheckResponse(status=status)) if service not in self._on_next_callbacks: self._on_next_callbacks[service] = set() - self._on_next_callbacks[service].add(on_next) - context.add_callback(self._on_close_callback(on_next, service)) + self._on_next_callbacks[service].add(on_next_callback) + context.add_callback( + self._on_close_callback(on_next_callback, service)) return blocking_watcher def set(self, service, status): @@ -133,5 +135,6 @@ class HealthServicer(_health_pb2_grpc.HealthServicer): with self._lock: self._server_status[service] = status if service in self._on_next_callbacks: - for on_next in self._on_next_callbacks[service]: - on_next(_health_pb2.HealthCheckResponse(status=status)) + for on_next_callback in self._on_next_callbacks[service]: + on_next_callback( + _health_pb2.HealthCheckResponse(status=status)) diff --git a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py index 3b8ee883bbe..2b1d17adfb0 100644 --- a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py +++ b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py @@ -23,6 +23,7 @@ from grpc_health.v1 import health_pb2 from grpc_health.v1 import health_pb2_grpc from tests.unit import test_common +from tests.unit import _thread_pool from tests.unit.framework.common import test_constants from six.moves import queue @@ -42,8 +43,11 @@ class BaseWatchTests(object): class WatchTests(unittest.TestCase): - def start_server(self, servicer): - self._servicer = servicer + def start_server(self, non_blocking=False, thread_pool=None): + self._thread_pool = thread_pool + self._servicer = health.HealthServicer( + experimental_non_blocking=non_blocking, + experimental_thread_pool=thread_pool) self._servicer.set('', health_pb2.HealthCheckResponse.SERVING) self._servicer.set(_SERVING_SERVICE, health_pb2.HealthCheckResponse.SERVING) @@ -80,6 +84,9 @@ class BaseWatchTests(object): thread.join() self.assertTrue(response_queue.empty()) + if self._thread_pool is not None: + self.assertTrue(self._thread_pool.was_used()) + def test_watch_new_service(self): request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) response_queue = queue.Queue() @@ -199,9 +206,9 @@ class BaseWatchTests(object): class HealthServicerTest(BaseWatchTests.WatchTests): def setUp(self): + self._thread_pool = _thread_pool.RecordingThreadPool(max_workers=None) super(HealthServicerTest, self).start_server( - health.HealthServicer( - experimental_non_blocking=False, experimental_thread_pool=None)) + non_blocking=True, thread_pool=self._thread_pool) def test_check_empty_service(self): request = health_pb2.HealthCheckRequest() @@ -239,8 +246,7 @@ class HealthServicerBackwardsCompatibleWatchTest(BaseWatchTests.WatchTests): def setUp(self): super(HealthServicerBackwardsCompatibleWatchTest, self).start_server( - health.HealthServicer( - experimental_non_blocking=False, experimental_thread_pool=None)) + non_blocking=False, thread_pool=None) if __name__ == '__main__':