diff --git a/src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py b/src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py index 4e2ce6a6c0b..8c438cc8d9b 100644 --- a/src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py +++ b/src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py @@ -33,11 +33,16 @@ from tests.unit.framework.common import get_socket _NUM_CORES = multiprocessing.cpu_count() _NUM_CORE_PYTHON_CAN_USE = 1 +_WORKER_ENTRY_FILE = os.path.split(os.path.abspath(__file__))[0] + '/worker.py' +_SubWorker = collections.namedtuple( + '_SubWorker', ['process', 'port', 'channel', 'stub']) + _LOGGER = logging.getLogger(__name__) -def _get_server_status(start_time: float, end_time: float, +def _get_server_status(start_time: float, + end_time: float, port: int) -> control_pb2.ServerStatus: end_time = time.time() elapsed_time = end_time - start_time @@ -46,11 +51,13 @@ def _get_server_status(start_time: float, end_time: float, time_system=elapsed_time) return control_pb2.ServerStatus(stats=stats, port=port, - cores=_NUM_CORE_PYTHON_CAN_USE) + cores=_NUM_CORES) def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]: - server = aio.server() + server = aio.server(options=( + ('grpc.so_reuseport', 1), + )) if config.server_type == control_pb2.ASYNC_SERVER: servicer = benchmark_servicer.BenchmarkServicer() benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server( @@ -84,7 +91,7 @@ def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]: def _get_client_status(start_time: float, end_time: float, qps_data: histogram.Histogram - ) -> control_pb2.ClientStatus: + ) -> control_pb2.ClientStatus: latencies = qps_data.get_data() end_time = time.time() elapsed_time = end_time - start_time @@ -97,7 +104,7 @@ def _get_client_status(start_time: float, end_time: float, def _create_client(server: str, config: control_pb2.ClientConfig, qps_data: histogram.Histogram - ) -> benchmark_client.BenchmarkClient: + ) -> benchmark_client.BenchmarkClient: if config.load_params.WhichOneof('load') != 'closed_loop': raise NotImplementedError( f'Unsupported load parameter {config.load_params}') @@ -117,25 +124,28 @@ def _create_client(server: str, config: control_pb2.ClientConfig, return client_type(server, config, qps_data) -WORKER_ENTRY_FILE = os.path.split(os.path.abspath(__file__))[0] + '/worker.py' -SubWorker = collections.namedtuple('SubWorker', ['process', 'port', 'channel', 'stub']) +def _pick_an_unused_port() -> int: + _, port, sock = get_socket() + sock.close() + return port + +async def _create_sub_worker() -> _SubWorker: + port = _pick_an_unused_port() -async def _create_sub_worker() -> SubWorker: - address, port, sock = get_socket() - sock.close() _LOGGER.info('Creating sub worker at port [%d]...', port) process = await asyncio.create_subprocess_exec( sys.executable, - WORKER_ENTRY_FILE, + _WORKER_ENTRY_FILE, '--driver_port', str(port) ) - _LOGGER.info('Created sub worker process for port [%d] at pid [%d]', port, process.pid) - channel = aio.insecure_channel(f'{address}:{port}') + _LOGGER.info( + 'Created sub worker process for port [%d] at pid [%d]', port, process.pid) + channel = aio.insecure_channel(f'localhost:{port}') _LOGGER.info('Waiting for sub worker at port [%d]', port) await aio.channel_ready(channel) stub = worker_service_pb2_grpc.WorkerServiceStub(channel) - return SubWorker( + return _SubWorker( process=process, port=port, channel=channel, @@ -150,34 +160,89 @@ class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer): self._loop = asyncio.get_event_loop() self._quit_event = asyncio.Event() - # async def _run_single_server(self, config, request_iterator, context): - # server, port = _create_server(config) - # await server.start() - - async def RunServer(self, request_iterator, context): - config = (await context.read()).setup - _LOGGER.info('Received ServerConfig: %s', config) - - if config.async_server_threads <= 0: - _LOGGER.info('async_server_threads can\'t be [%d]', config.async_server_threads) - _LOGGER.info('Using async_server_threads == [%d]', _NUM_CORES) - config.async_server_threads = _NUM_CORES - + async def _run_single_server(self, config, request_iterator, context): server, port = _create_server(config) await server.start() _LOGGER.info('Server started at port [%d]', port) start_time = time.time() - yield _get_server_status(start_time, start_time, port) + await context.write(_get_server_status(start_time, start_time, port)) async for request in request_iterator: end_time = time.time() status = _get_server_status(start_time, end_time, port) if request.mark.reset: start_time = end_time - yield status + await context.write(status) await server.stop(None) + async def RunServer(self, request_iterator, context): + config_request = await context.read() + config = config_request.setup + _LOGGER.info('Received ServerConfig: %s', config) + + if config.async_server_threads <= 0: + _LOGGER.info( + 'async_server_threads can\'t be [%d]', config.async_server_threads) + _LOGGER.info('Using async_server_threads == [%d]', _NUM_CORES) + config.async_server_threads = _NUM_CORES + + if config.port == 0: + config.port = _pick_an_unused_port() + _LOGGER.info('Port picked [%d]', config.port) + + if config.async_server_threads == 1: + await self._run_single_server(config, request_iterator, context) + else: + sub_workers = await asyncio.gather(*( + _create_sub_worker() + for _ in range(config.async_server_threads) + )) + + calls = [worker.stub.RunServer() for worker in sub_workers] + + config_request.setup.async_server_threads = 1 + + for call in calls: + await call.write(config_request) + # An empty status indicates the peer is ready + await call.read() + + start_time = time.time() + await context.write(_get_server_status( + start_time, + start_time, + config.port, + )) + + async for request in request_iterator: + end_time = time.time() + + for call in calls: + _LOGGER.debug('Fetching status...') + await call.write(request) + # Reports from sub workers doesn't matter + await call.read() + + status = _get_server_status( + start_time, + end_time, + config.port, + ) + if request.mark.reset: + start_time = end_time + await context.write(status) + + for call in calls: + await call.done_writing() + + for worker in sub_workers: + await worker.stub.QuitWorker(control_pb2.Void()) + await worker.channel.close() + _LOGGER.info('Waiting for sub worker [%s] to quit...', worker) + await worker.process.wait() + _LOGGER.info('Sub worker [%s] quit', worker) + async def _run_single_client(self, config, request_iterator, context): running_tasks = [] qps_data = histogram.Histogram(config.histogram_params.resolution, @@ -213,7 +278,8 @@ class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer): _LOGGER.info('Received ClientConfig: %s', config) if config.async_client_threads <= 0: - _LOGGER.info('async_client_threads can\'t be [%d]', config.async_client_threads) + _LOGGER.info( + 'async_client_threads can\'t be [%d]', config.async_client_threads) _LOGGER.info('Using async_client_threads == [%d]', _NUM_CORES) config.async_client_threads = _NUM_CORES @@ -231,7 +297,7 @@ class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer): for call in calls: await call.write(config_request) - # An empty status + # An empty status indicates the peer is ready await call.read() start_time = time.time() @@ -255,7 +321,8 @@ class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer): if request.mark.reset: result.reset() start_time = time.time() - _LOGGER.debug('Reporting count=[%d]', status.stats.latencies.count) + _LOGGER.debug( + 'Reporting count=[%d]', status.stats.latencies.count) await context.write(status) for call in calls: