Adopt reviewer's suggestions

pull/20598/head
Lidi Zheng 5 years ago
parent 49c7f1ddf6
commit 2ced359d78
  1. 22
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/iomgr.pyx.pxi
  2. 12
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi
  3. 13
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  4. 95
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  5. 26
      src/python/grpcio/grpc/experimental/aio/_server.py
  6. 2
      src/python/grpcio_tests/tests_aio/unit/server_test.py

@ -17,6 +17,7 @@ from cpython cimport Py_INCREF, Py_DECREF
from libc cimport string
import socket as native_socket
import ipaddress # CPython 3.3 and above
cdef grpc_socket_vtable asyncio_socket_vtable
cdef grpc_custom_resolver_vtable asyncio_resolver_vtable
@ -87,6 +88,7 @@ cdef grpc_error* asyncio_socket_getpeername(
cdef grpc_resolved_address c_addr
hostname = str_to_bytes(peer[0])
grpc_string_to_sockaddr(&c_addr, hostname, peer[1])
# TODO(https://github.com/grpc/grpc/issues/20684) Remove the memcpy
string.memcpy(<void*>addr, <void*>c_addr.addr, c_addr.len)
length[0] = c_addr.len
return grpc_error_none()
@ -105,6 +107,7 @@ cdef grpc_error* asyncio_socket_getsockname(
peer = socket.sockname()
hostname = str_to_bytes(peer[0])
grpc_string_to_sockaddr(&c_addr, hostname, peer[1])
# TODO(https://github.com/grpc/grpc/issues/20684) Remove the memcpy
string.memcpy(<void*>addr, <void*>c_addr.addr, c_addr.len)
length[0] = c_addr.len
return grpc_error_none()
@ -128,19 +131,20 @@ cdef grpc_error* asyncio_socket_bind(
size_t len, int flags) with gil:
host, port = sockaddr_to_tuple(addr, len)
try:
try:
socket = native_socket.socket(family=native_socket.AF_INET6)
_asyncio_apply_socket_options(socket)
socket.bind((host, port))
except native_socket.gaierror:
socket = native_socket.socket(family=native_socket.AF_INET)
_asyncio_apply_socket_options(socket)
socket.bind((host, port))
ip = ipaddress.ip_address(host)
if isinstance(ip, ipaddress.IPv6Address):
family = native_socket.AF_INET6
else:
family = native_socket.AF_INET
socket = native_socket.socket(family=family)
_asyncio_apply_socket_options(socket)
socket.bind((host, port))
except IOError as io_error:
return socket_error("bind", str(io_error))
else:
aio_socket = _AsyncioSocket.create_with_py_socket(grpc_socket, socket)
cpython.Py_INCREF(aio_socket)
cpython.Py_INCREF(aio_socket) # Py_DECREF in asyncio_socket_destroy
grpc_socket.impl = <void*>aio_socket
return grpc_error_none()

@ -112,9 +112,7 @@ cdef class _AsyncioSocket:
object host,
object port,
grpc_custom_connect_callback grpc_connect_cb):
if self._reader:
return
assert not self._reader
assert not self._task_connect
self._task_connect = asyncio.ensure_future(
@ -163,11 +161,11 @@ cdef class _AsyncioSocket:
)
self._grpc_client_socket.impl = <void*>client_socket
cpython.Py_INCREF(client_socket)
cpython.Py_INCREF(client_socket) # Py_DECREF in asyncio_socket_destroy
# Accept callback expects to be called with:
# * An grpc custom socket for server
# * An grpc custom socket for client (with new Socket instance)
# * An error object
# grpc_custom_socket: A grpc custom socket for server
# grpc_custom_socket: A grpc custom socket for client (with new Socket instance)
# grpc_error: An error object
self._grpc_accept_cb(self._grpc_socket, self._grpc_client_socket, grpc_error_none())
cdef listen(self):

@ -32,12 +32,13 @@ cdef enum AioServerStatus:
AIO_SERVER_STATUS_STOPPED
cdef class _AioServerState:
cdef Server server
cdef grpc_completion_queue *cq
cdef list generic_handlers
cdef AioServerStatus status
cdef class _CallbackCompletionQueue:
cdef grpc_completion_queue *_cq
cdef grpc_completion_queue* c_ptr(self)
cdef class AioServer:
cdef _AioServerState _state
cdef Server _server
cdef _CallbackCompletionQueue _cq
cdef list _generic_handlers
cdef AioServerStatus _status

@ -25,12 +25,12 @@ class _ServicerContextPlaceHolder(object): pass
# Apply this to the client-side
cdef class CallbackWrapper:
cdef CallbackContext context
cdef object _keep_reference
cdef object _reference
def __cinit__(self, object future):
self.context.functor.functor_run = self.functor_run
self.context.waiter = <cpython.PyObject*>(future)
self._keep_reference = future
self._reference = future
@staticmethod
cdef void functor_run(
@ -63,10 +63,10 @@ cdef class RPCState:
grpc_call_unref(self.call)
cdef _find_method_handler(RPCState rpc_state, list generic_handlers):
cdef _find_method_handler(str method, list generic_handlers):
# TODO(lidiz) connects Metadata to call details
cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(
rpc_state.method().decode(),
method,
tuple()
)
@ -77,8 +77,9 @@ cdef _find_method_handler(RPCState rpc_state, list generic_handlers):
return None
async def callback_start_batch(RPCState rpc_state, tuple operations, object
loop):
async def callback_start_batch(RPCState rpc_state,
tuple operations,
object loop):
"""The callback version of start batch operations."""
cdef _BatchOperationTag batch_operation_tag = _BatchOperationTag(None, operations, None)
batch_operation_tag.prepare()
@ -100,10 +101,13 @@ loop):
await future
cpython.Py_DECREF(wrapper)
cdef grpc_event c_event
# Tag.event must be called, otherwise messages won't be parsed from C
batch_operation_tag.event(c_event)
async def _handle_unary_unary_rpc(object method_handler, RPCState rpc_state, object loop):
async def _handle_unary_unary_rpc(object method_handler,
RPCState rpc_state,
object loop):
# Receives request message
cdef tuple receive_ops = (
ReceiveMessageOperation(_EMPTY_FLAGS),
@ -138,11 +142,11 @@ async def _handle_unary_unary_rpc(object method_handler, RPCState rpc_state, obj
await callback_start_batch(rpc_state, send_ops, loop)
async def _handle_rpc(_AioServerState server_state, RPCState rpc_state, object loop):
async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
# Finds the method handler (application logic)
cdef object method_handler = _find_method_handler(
rpc_state,
server_state.generic_handlers
rpc_state.method().decode(),
generic_handlers
)
if method_handler is None:
# TODO(lidiz) return unimplemented error to client side
@ -158,7 +162,9 @@ async def _handle_rpc(_AioServerState server_state, RPCState rpc_state, object l
)
async def _server_call_request_call(_AioServerState server_state, object loop):
async def _server_call_request_call(Server server,
_CallbackCompletionQueue cq,
object loop):
cdef grpc_call_error error
cdef RPCState rpc_state = RPCState()
cdef object future = loop.create_future()
@ -167,9 +173,9 @@ async def _server_call_request_call(_AioServerState server_state, object loop):
# when calling "await". This is an over-optimization by Cython.
cpython.Py_INCREF(wrapper)
error = grpc_server_request_call(
server_state.server.c_server, &rpc_state.call, &rpc_state.details,
server.c_server, &rpc_state.call, &rpc_state.details,
&rpc_state.request_metadata,
server_state.cq, server_state.cq,
cq.c_ptr(), cq.c_ptr(),
wrapper.c_functor()
)
if error != GRPC_CALL_OK:
@ -180,45 +186,52 @@ async def _server_call_request_call(_AioServerState server_state, object loop):
return rpc_state
async def _server_main_loop(_AioServerState server_state):
async def _server_main_loop(Server server,
_CallbackCompletionQueue cq,
list generic_handlers):
cdef object loop = asyncio.get_event_loop()
cdef RPCState rpc_state
while True:
rpc_state = await _server_call_request_call(
server_state,
server,
cq,
loop)
loop.create_task(_handle_rpc(server_state, rpc_state, loop))
loop.create_task(_handle_rpc(generic_handlers, rpc_state, loop))
await asyncio.sleep(0)
async def _server_start(_AioServerState server_state):
server_state.server.start()
await _server_main_loop(server_state)
async def _server_start(Server server,
_CallbackCompletionQueue cq,
list generic_handlers):
server.start()
await _server_main_loop(server, cq, generic_handlers)
cdef class _CallbackCompletionQueue:
cdef class _AioServerState:
def __cinit__(self):
self.server = None
self.cq = NULL
self.generic_handlers = []
self._cq = grpc_completion_queue_create_for_callback(
NULL,
NULL
)
cdef grpc_completion_queue* c_ptr(self):
return self._cq
cdef class AioServer:
def __init__(self, thread_pool, generic_handlers, interceptors, options,
maximum_concurrent_rpcs, compression):
self._state = _AioServerState()
self._state.server = Server(options)
self._state.cq = grpc_completion_queue_create_for_callback(
NULL,
NULL
)
self._state.status = AIO_SERVER_STATUS_READY
self._server = Server(options)
self._cq = _CallbackCompletionQueue()
self._status = AIO_SERVER_STATUS_READY
self._generic_handlers = []
grpc_server_register_completion_queue(
self._state.server.c_server,
self._state.cq,
self._server.c_server,
self._cq.c_ptr(),
NULL
)
self.add_generic_rpc_handlers(generic_handlers)
@ -234,24 +247,28 @@ cdef class AioServer:
def add_generic_rpc_handlers(self, generic_rpc_handlers):
for h in generic_rpc_handlers:
self._state.generic_handlers.append(h)
self._generic_handlers.append(h)
def add_insecure_port(self, address):
return self._state.server.add_http2_port(address)
return self._server.add_http2_port(address)
def add_secure_port(self, address, server_credentials):
return self._state.server.add_http2_port(address,
return self._server.add_http2_port(address,
server_credentials._credentials)
async def start(self):
if self._state.status == AIO_SERVER_STATUS_RUNNING:
if self._status == AIO_SERVER_STATUS_RUNNING:
return
elif self._state.status != AIO_SERVER_STATUS_READY:
elif self._status != AIO_SERVER_STATUS_READY:
raise RuntimeError('Server not in ready state')
self._state.status = AIO_SERVER_STATUS_RUNNING
self._status = AIO_SERVER_STATUS_RUNNING
loop = asyncio.get_event_loop()
loop.create_task(_server_start(self._state))
loop.create_task(_server_start(
self._server,
self._cq,
self._generic_handlers,
))
await asyncio.sleep(0)
# TODO(https://github.com/grpc/grpc/issues/20668)

@ -16,6 +16,7 @@
from typing import Text, Optional
import asyncio
import grpc
from grpc import _common
from grpc._cython import cygrpc
@ -50,12 +51,12 @@ class Server:
Args:
address: The address for which to open a port. If the port is 0,
or not specified in the address, then gRPC runtime will choose a port.
or not specified in the address, then the gRPC runtime will choose a port.
Returns:
An integer port on which server will accept RPC requests.
An integer port on which the server will accept RPC requests.
"""
return self._server.add_insecure_port(address)
return self._server.add_insecure_port(_common.encode(address))
def add_secure_port(self, address: Text,
server_credentials: grpc.ServerCredentials) -> int:
@ -65,14 +66,15 @@ class Server:
Args:
address: The address for which to open a port.
if the port is 0, or not specified in the address, then gRPC
if the port is 0, or not specified in the address, then the gRPC
runtime will choose a port.
server_credentials: A ServerCredentials object.
Returns:
An integer port on which server will accept RPC requests.
An integer port on which the server will accept RPC requests.
"""
return self._server.add_secure_port(address, server_credentials)
return self._server.add_secure_port(
_common.encode(address), server_credentials)
async def start(self) -> None:
"""Starts this Server.
@ -84,7 +86,8 @@ class Server:
def stop(self, grace: Optional[float]) -> asyncio.Event:
"""Stops this Server.
This method immediately stop service of new RPCs in all cases.
"This method immediately stops the server from servicing new RPCs in
all cases.
If a grace period is specified, this method returns immediately
and all RPCs active at the end of the grace period are aborted.
@ -139,7 +142,7 @@ class Server:
await future
def server(thread_pool=None,
def server(migration_thread_pool=None,
handlers=None,
interceptors=None,
options=None,
@ -148,8 +151,8 @@ def server(thread_pool=None,
"""Creates a Server with which RPCs can be serviced.
Args:
thread_pool: A futures.ThreadPoolExecutor to be used by the Server
to execute RPC handlers.
migration_thread_pool: A futures.ThreadPoolExecutor to be used by the
Server to execute non-AsyncIO RPC handlers for migration purpose.
handlers: An optional list of GenericRpcHandlers used for executing RPCs.
More handlers may be added by calling add_generic_rpc_handlers any time
before the server is started.
@ -169,7 +172,8 @@ def server(thread_pool=None,
Returns:
A Server object.
"""
return Server(thread_pool, () if handlers is None else handlers, ()
return Server(migration_thread_pool, ()
if handlers is None else handlers, ()
if interceptors is None else interceptors, ()
if options is None else options, maximum_concurrent_rpcs,
compression)

@ -44,7 +44,7 @@ class TestServer(unittest.TestCase):
async def test_unary_unary_body():
server = aio.server()
port = server.add_insecure_port(('[::]:0').encode('ASCII'))
port = server.add_insecure_port('[::]:0')
server.add_generic_rpc_handlers((GenericHandler(),))
await server.start()

Loading…
Cancel
Save