Make server parallel-able

pull/21904/head
Lidi Zheng 5 years ago
parent 94525e5831
commit 7cb055b035
1 changed file with 131 changed lines:
  src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py

@@ -33,11 +33,16 @@ from tests.unit.framework.common import get_socket

 _NUM_CORES = multiprocessing.cpu_count()
 _NUM_CORE_PYTHON_CAN_USE = 1
+_WORKER_ENTRY_FILE = os.path.split(os.path.abspath(__file__))[0] + '/worker.py'
+
+_SubWorker = collections.namedtuple(
+    '_SubWorker', ['process', 'port', 'channel', 'stub'])

 _LOGGER = logging.getLogger(__name__)


-def _get_server_status(start_time: float, end_time: float,
+def _get_server_status(start_time: float,
+                       end_time: float,
                        port: int) -> control_pb2.ServerStatus:
     end_time = time.time()
     elapsed_time = end_time - start_time
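Note: the additions above give each spawned helper a small record type. _SubWorker bundles the subprocess handle, the port it serves on, the channel to it, and the WorkerService stub, while _WORKER_ENTRY_FILE points at the worker.py entry script each helper runs. A rough sketch of how such a record is torn down later in this diff (the helper name and objects are placeholders, not the actual variables):

    # Illustrative only: shutting down one _SubWorker, mirroring the cleanup
    # loop added to RunServer further down in this diff.
    async def _shutdown_sub_worker(sub_worker):
        await sub_worker.stub.QuitWorker(control_pb2.Void())  # ask the child to exit
        await sub_worker.channel.close()                       # drop its channel
        await sub_worker.process.wait()                        # reap the subprocess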
@@ -46,11 +51,13 @@ def _get_server_status(start_time: float, end_time: float,
                                   time_system=elapsed_time)
     return control_pb2.ServerStatus(stats=stats,
                                     port=port,
-                                    cores=_NUM_CORE_PYTHON_CAN_USE)
+                                    cores=_NUM_CORES)


 def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]:
-    server = aio.server()
+    server = aio.server(options=(
+        ('grpc.so_reuseport', 1),
+    ))
     if config.server_type == control_pb2.ASYNC_SERVER:
         servicer = benchmark_servicer.BenchmarkServicer()
         benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
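Note: the grpc.so_reuseport channel argument is what makes the multi-process fan-out below possible. With SO_REUSEPORT, several server processes can bind the same TCP port and the kernel spreads incoming connections across them, so the driver keeps talking to one port while the work lands on many event loops. A minimal, self-contained sketch of the idea, assuming the grpc.experimental.aio alias used by this file (the port value and the commented-out servicer registration are placeholders):

    # Hypothetical sketch: two asyncio gRPC servers sharing one port via
    # SO_REUSEPORT. In the benchmark, the second server actually lives in a
    # separate worker subprocess started by _create_sub_worker().
    import asyncio
    from grpc.experimental import aio

    async def serve_on_shared_port(port: int) -> None:
        servers = []
        for _ in range(2):
            server = aio.server(options=(('grpc.so_reuseport', 1),))
            # benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(...)
            server.add_insecure_port(f'[::]:{port}')
            await server.start()
            servers.append(server)
        await asyncio.gather(*(s.wait_for_termination() for s in servers))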
@@ -84,7 +91,7 @@ def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]:

 def _get_client_status(start_time: float, end_time: float,
                        qps_data: histogram.Histogram
                       ) -> control_pb2.ClientStatus:
     latencies = qps_data.get_data()
     end_time = time.time()
     elapsed_time = end_time - start_time
@@ -97,7 +104,7 @@ def _get_client_status(start_time: float, end_time: float,

 def _create_client(server: str, config: control_pb2.ClientConfig,
                    qps_data: histogram.Histogram
                   ) -> benchmark_client.BenchmarkClient:
     if config.load_params.WhichOneof('load') != 'closed_loop':
         raise NotImplementedError(
             f'Unsupported load parameter {config.load_params}')
@@ -117,25 +124,28 @@ def _create_client(server: str, config: control_pb2.ClientConfig,
     return client_type(server, config, qps_data)


-WORKER_ENTRY_FILE = os.path.split(os.path.abspath(__file__))[0] + '/worker.py'
-
-SubWorker = collections.namedtuple('SubWorker', ['process', 'port', 'channel', 'stub'])
+def _pick_an_unused_port() -> int:
+    _, port, sock = get_socket()
+    sock.close()
+    return port


-async def _create_sub_worker() -> SubWorker:
-    address, port, sock = get_socket()
-    sock.close()
+async def _create_sub_worker() -> _SubWorker:
+    port = _pick_an_unused_port()
     _LOGGER.info('Creating sub worker at port [%d]...', port)
     process = await asyncio.create_subprocess_exec(
         sys.executable,
-        WORKER_ENTRY_FILE,
+        _WORKER_ENTRY_FILE,
         '--driver_port', str(port)
     )
-    _LOGGER.info('Created sub worker process for port [%d] at pid [%d]', port, process.pid)
-    channel = aio.insecure_channel(f'{address}:{port}')
+    _LOGGER.info(
+        'Created sub worker process for port [%d] at pid [%d]', port, process.pid)
+    channel = aio.insecure_channel(f'localhost:{port}')
     _LOGGER.info('Waiting for sub worker at port [%d]', port)
     await aio.channel_ready(channel)
     stub = worker_service_pb2_grpc.WorkerServiceStub(channel)
-    return SubWorker(
+    return _SubWorker(
         process=process,
         port=port,
         channel=channel,
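Note: _pick_an_unused_port relies on the test helper get_socket, which binds port 0 so the OS chooses a free port, then closes the socket and hands the number to the worker subprocess, which rebinds it thanks to SO_REUSEPORT. A plain-stdlib approximation of what that helper is assumed to do (not its actual implementation):

    # Rough stdlib equivalent of get_socket()/_pick_an_unused_port(): bind
    # port 0 so the kernel picks a free port, enable SO_REUSEPORT so the
    # sub worker can bind the same port right afterwards, then release it.
    import socket

    def pick_an_unused_port() -> int:
        sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        if hasattr(socket, 'SO_REUSEPORT'):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(('', 0))
        port = sock.getsockname()[1]
        sock.close()
        return port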
@@ -150,34 +160,89 @@ class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
         self._loop = asyncio.get_event_loop()
         self._quit_event = asyncio.Event()

-    # async def _run_single_server(self, config, request_iterator, context):
-    #     server, port = _create_server(config)
-    #     await server.start()
-
-    async def RunServer(self, request_iterator, context):
-        config = (await context.read()).setup
-        _LOGGER.info('Received ServerConfig: %s', config)
-
-        if config.async_server_threads <= 0:
-            _LOGGER.info('async_server_threads can\'t be [%d]', config.async_server_threads)
-            _LOGGER.info('Using async_server_threads == [%d]', _NUM_CORES)
-            config.async_server_threads = _NUM_CORES
-
+    async def _run_single_server(self, config, request_iterator, context):
         server, port = _create_server(config)
         await server.start()
         _LOGGER.info('Server started at port [%d]', port)

         start_time = time.time()
-        yield _get_server_status(start_time, start_time, port)
+        await context.write(_get_server_status(start_time, start_time, port))

         async for request in request_iterator:
             end_time = time.time()
             status = _get_server_status(start_time, end_time, port)
             if request.mark.reset:
                 start_time = end_time
-            yield status
+            await context.write(status)
         await server.stop(None)

+    async def RunServer(self, request_iterator, context):
+        config_request = await context.read()
+        config = config_request.setup
+        _LOGGER.info('Received ServerConfig: %s', config)
+
+        if config.async_server_threads <= 0:
+            _LOGGER.info(
+                'async_server_threads can\'t be [%d]', config.async_server_threads)
+            _LOGGER.info('Using async_server_threads == [%d]', _NUM_CORES)
+            config.async_server_threads = _NUM_CORES
+
+        if config.port == 0:
+            config.port = _pick_an_unused_port()
+        _LOGGER.info('Port picked [%d]', config.port)
+
+        if config.async_server_threads == 1:
+            await self._run_single_server(config, request_iterator, context)
+        else:
+            sub_workers = await asyncio.gather(*(
+                _create_sub_worker()
+                for _ in range(config.async_server_threads)
+            ))
+
+            calls = [worker.stub.RunServer() for worker in sub_workers]
+
+            config_request.setup.async_server_threads = 1
+
+            for call in calls:
+                await call.write(config_request)
+                # An empty status indicates the peer is ready
+                await call.read()
+
+            start_time = time.time()
+            await context.write(_get_server_status(
+                start_time,
+                start_time,
+                config.port,
+            ))
+
+            async for request in request_iterator:
+                end_time = time.time()
+                for call in calls:
+                    _LOGGER.debug('Fetching status...')
+                    await call.write(request)
+                    # Reports from sub workers doesn't matter
+                    await call.read()
+
+                status = _get_server_status(
+                    start_time,
+                    end_time,
+                    config.port,
+                )
+                if request.mark.reset:
+                    start_time = end_time
+                await context.write(status)
+
+            for call in calls:
+                await call.done_writing()
+
+            for worker in sub_workers:
+                await worker.stub.QuitWorker(control_pb2.Void())
+                await worker.channel.close()
+                _LOGGER.info('Waiting for sub worker [%s] to quit...', worker)
+                await worker.process.wait()
+                _LOGGER.info('Sub worker [%s] quit', worker)
+
     async def _run_single_client(self, config, request_iterator, context):
         running_tasks = []
         qps_data = histogram.Histogram(config.histogram_params.resolution,
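Note: the rewritten RunServer keeps the driver-facing protocol unchanged while fanning the work out. It spawns one sub worker per requested thread, opens a RunServer stream to each, sends every stream the same ServerConfig with async_server_threads forced to 1, waits for each readiness status, and then relays every mark from the driver to all sub workers while answering the driver itself. A trimmed sketch of that relay step, reusing _get_server_status and time from this module (the helper name is illustrative, not part of the change):

    # Reduced illustration of the fan-out loop inside RunServer: `calls` are
    # the bidirectional RunServer streams to the sub workers.
    async def _relay_mark(calls, request, start_time, port):
        for call in calls:
            await call.write(request)  # forward the driver's mark request
            await call.read()          # the per-child status is discarded
        # The parent reports a single status over its own elapsed window.
        return _get_server_status(start_time, time.time(), port)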
@@ -213,7 +278,8 @@ class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
         _LOGGER.info('Received ClientConfig: %s', config)

         if config.async_client_threads <= 0:
-            _LOGGER.info('async_client_threads can\'t be [%d]', config.async_client_threads)
+            _LOGGER.info(
+                'async_client_threads can\'t be [%d]', config.async_client_threads)
             _LOGGER.info('Using async_client_threads == [%d]', _NUM_CORES)
             config.async_client_threads = _NUM_CORES
@@ -231,7 +297,7 @@ class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
             for call in calls:
                 await call.write(config_request)
-                # An empty status
+                # An empty status indicates the peer is ready
                 await call.read()

             start_time = time.time()
@@ -255,7 +321,8 @@ class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
                 if request.mark.reset:
                     result.reset()
                     start_time = time.time()
-                _LOGGER.debug('Reporting count=[%d]', status.stats.latencies.count)
+                _LOGGER.debug(
+                    'Reporting count=[%d]', status.stats.latencies.count)
                 await context.write(status)

             for call in calls:
