Refactorize Cython and Python call communications

Now the status and the initial metadata, as awaitable methods, are
provided by the Cython layer. Any time the Python layer, like the Call
object, needs to know the status of the initial metadata uses the new
methods published by the AioCall
pull/21696/head
Pau Freixes 5 years ago
parent eba60d8dbe
commit 53c41de3e0
  1. 14
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 276
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
  4. 9
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  5. 8
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  6. 167
      src/python/grpcio/grpc/experimental/aio/_call.py
  7. 115
      src/python/grpcio/grpc/experimental/aio/_channel.py
  8. 14
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  9. 69
      src/python/grpcio_tests/tests_aio/unit/call_test.py

@ -17,6 +17,9 @@ cdef class _AioCall(GrpcCallWrapper):
cdef: cdef:
AioChannel _channel AioChannel _channel
list _references list _references
object _deadline
list _done_callbacks
# Caches the picked event loop, so we can avoid the 30ns overhead each # Caches the picked event loop, so we can avoid the 30ns overhead each
# time we need access to the event loop. # time we need access to the event loop.
object _loop object _loop
@ -28,6 +31,15 @@ cdef class _AioCall(GrpcCallWrapper):
# because Core is holding a pointer for the callback handler. # because Core is holding a pointer for the callback handler.
bint _is_locally_cancelled bint _is_locally_cancelled
object _deadline # Following attributes are used for storing the status of the call and
# the initial metadata. Waiters are used for pausing the execution of
# tasks that are asking for one of the field when they are not yet
# available.
object _status
object _initial_metadata
list _waiters_status
list _waiters_initial_metadata
cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except * cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except *
cdef void _set_status(self, AioRpcStatus status) except *
cdef void _set_initial_metadata(self, tuple initial_metadata) except *

@ -18,34 +18,68 @@ import grpc
_EMPTY_FLAGS = 0 _EMPTY_FLAGS = 0
_EMPTY_MASK = 0 _EMPTY_MASK = 0
_EMPTY_METADATA = None _IMMUTABLE_EMPTY_METADATA = tuple()
_UNKNOWN_CANCELLATION_DETAILS = 'RPC cancelled for unknown reason.' _UNKNOWN_CANCELLATION_DETAILS = 'RPC cancelled for unknown reason.'
_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
'\tdetails = "{}"\n'
'>')
_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
'\tdetails = "{}"\n'
'\tdebug_error_string = "{}"\n'
'>')
cdef class _AioCall(GrpcCallWrapper): cdef class _AioCall(GrpcCallWrapper):
def __cinit__(self, def __cinit__(self, AioChannel channel, object deadline,
AioChannel channel, bytes method, CallCredentials call_credentials):
object deadline,
bytes method,
CallCredentials call_credentials):
self.call = NULL self.call = NULL
self._channel = channel self._channel = channel
self._loop = channel.loop
self._references = [] self._references = []
self._loop = asyncio.get_event_loop() self._status = None
self._create_grpc_call(deadline, method, call_credentials) self._initial_metadata = None
self._waiters_status = []
self._waiters_initial_metadata = []
self._done_callbacks = []
self._is_locally_cancelled = False self._is_locally_cancelled = False
self._deadline = deadline self._deadline = deadline
self._create_grpc_call(deadline, method, call_credentials)
def __dealloc__(self): def __dealloc__(self):
if self.call: if self.call:
grpc_call_unref(self.call) grpc_call_unref(self.call)
def __repr__(self): def _repr(self) -> str:
class_name = self.__class__.__name__ """Assembles the RPC representation string."""
id_ = id(self) # This needs to be loaded at run time once everything
return f"<{class_name} {id_}>" # has been loaded.
from grpc import _common
if not self.done():
return '<{} object>'.format(self.__class__.__name__)
if self._status.code() is StatusCode.ok:
return _OK_CALL_REPRESENTATION.format(
self.__class__.__name__,
_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[self._status.code()],
self._status.details())
else:
return _NON_OK_CALL_REPRESENTATION.format(
self.__class__.__name__,
self._status.details(),
_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[self._status.code()],
self._status.debug_error_string())
def __repr__(self) -> str:
return self._repr()
def __str__(self) -> str:
return self._repr()
cdef void _create_grpc_call(self, cdef void _create_grpc_call(self,
object deadline, object deadline,
@ -85,13 +119,55 @@ cdef class _AioCall(GrpcCallWrapper):
grpc_slice_unref(method_slice) grpc_slice_unref(method_slice)
cdef void _set_status(self, AioRpcStatus status) except *:
cdef list waiters
if self._initial_metadata is None:
self._set_initial_metadata(_IMMUTABLE_EMPTY_METADATA)
self._status = status
waiters = self._waiters_status
# No more waiters should be expected since status
# has been set.
self._waiters_status = None
for waiter in waiters:
if not waiter.done():
waiter.set_result(None)
for callback in self._done_callbacks:
callback()
cdef void _set_initial_metadata(self, tuple initial_metadata) except *:
cdef list waiters
self._initial_metadata = initial_metadata
waiters = self._waiters_initial_metadata
# No more waiters should be expected since initial metadata
# has been set.
self._waiters_initial_metadata = None
for waiter in waiters:
if not waiter.done():
waiter.set_result(None)
def add_done_callback(self, callback):
if self.done():
callback()
else:
self._done_callbacks.append(callback)
def time_remaining(self): def time_remaining(self):
if self._deadline is None: if self._deadline is None:
return None return None
else: else:
return max(0, self._deadline - time.time()) return max(0, self._deadline - time.time())
def cancel(self, AioRpcStatus status): def cancel(self, str details):
"""Cancels the RPC in Core with given RPC status. """Cancels the RPC in Core with given RPC status.
Above abstractions must invoke this method to set Core objects into Above abstractions must invoke this method to set Core objects into
@ -99,44 +175,108 @@ cdef class _AioCall(GrpcCallWrapper):
""" """
self._is_locally_cancelled = True self._is_locally_cancelled = True
cdef object details cdef object details_bytes
cdef char *c_details cdef char *c_details
cdef grpc_call_error error cdef grpc_call_error error
# Try to fetch application layer cancellation details in the future.
# * If cancellation details present, cancel with status; self._set_status(AioRpcStatus(
# * If details not present, cancel with unknown reason. StatusCode.cancelled,
if status is not None: details,
details = str_to_bytes(status.details()) None,
self._references.append(details) None,
c_details = <char *>details ))
details_bytes = str_to_bytes(details)
self._references.append(details_bytes)
c_details = <char *>details_bytes
# By implementation, grpc_call_cancel_with_status always return OK # By implementation, grpc_call_cancel_with_status always return OK
error = grpc_call_cancel_with_status( error = grpc_call_cancel_with_status(
self.call, self.call,
status.c_code(), StatusCode.cancelled,
c_details, c_details,
NULL, NULL,
) )
assert error == GRPC_CALL_OK assert error == GRPC_CALL_OK
else:
# By implementation, grpc_call_cancel always return OK def done(self):
error = grpc_call_cancel(self.call, NULL) """Returns if the RPC call has finished.
assert error == GRPC_CALL_OK
Checks if the status has been provided, either
because the RPC finished or because was cancelled..
Returns:
True if the RPC can be considered finished.
"""
return self._status is not None
def cancelled(self):
"""Returns if the RPC was cancelled.
Returns:
True if the RPC was cancelled.
"""
if not self.done():
return False
return self._status.code() == StatusCode.cancelled
async def status(self):
"""Returns the status of the RPC call.
It returns the finshed status of the RPC. If the RPC
has not finished yet this function will wait until the RPC
gets finished.
Returns:
Finished status of the RPC as an AioRpcStatus object.
"""
if self._status is not None:
return self._status
future = self._loop.create_future()
self._waiters_status.append(future)
await future
return self._status
async def initial_metadata(self):
"""Returns the initial metadata of the RPC call.
If the initial metadata has not been received yet this function will
wait until the RPC gets finished.
Returns:
The tuple object with the initial metadata.
"""
if self._initial_metadata is not None:
return self._initial_metadata
future = self._loop.create_future()
self._waiters_initial_metadata.append(future)
await future
return self._initial_metadata
def is_locally_cancelled(self):
"""Returns if the RPC was cancelled locally.
Returns:
True when was cancelled locally, False when was cancelled remotelly or
is still ongoing.
"""
if self._is_locally_cancelled:
return True
return False
async def unary_unary(self, async def unary_unary(self,
bytes request, bytes request,
tuple outbound_initial_metadata, tuple outbound_initial_metadata):
object initial_metadata_observer,
object status_observer):
"""Performs a unary unary RPC. """Performs a unary unary RPC.
Args: Args:
method: name of the calling method in bytes.
request: the serialized requests in bytes. request: the serialized requests in bytes.
deadline: optional deadline of the RPC in float. outbound_initial_metadata: optional outbound metadata.
cancellation_future: the future that meant to transport the
cancellation reason from the application layer.
initial_metadata_observer: a callback for received initial metadata.
status_observer: a callback for received final status.
""" """
cdef tuple ops cdef tuple ops
@ -159,25 +299,24 @@ cdef class _AioCall(GrpcCallWrapper):
ops, ops,
self._loop) self._loop)
# Reports received initial metadata. self._set_initial_metadata(receive_initial_metadata_op.initial_metadata())
initial_metadata_observer(receive_initial_metadata_op.initial_metadata())
status = AioRpcStatus( cdef grpc_status_code code
receive_status_on_client_op.code(), code = receive_status_on_client_op.code()
self._set_status(AioRpcStatus(
code,
receive_status_on_client_op.details(), receive_status_on_client_op.details(),
receive_status_on_client_op.trailing_metadata(), receive_status_on_client_op.trailing_metadata(),
receive_status_on_client_op.error_string(), receive_status_on_client_op.error_string(),
) ))
# Reports the final status of the RPC to Python layer. The observer
# pattern is used here to unify unary and streaming code path.
status_observer(status)
if status.code() == StatusCode.ok: if code == StatusCode.ok:
return receive_message_op.message() return receive_message_op.message()
else: else:
return None return None
async def _handle_status_once_received(self, object status_observer): async def _handle_status_once_received(self):
"""Handles the status sent by peer once received.""" """Handles the status sent by peer once received."""
cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS) cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
cdef tuple ops = (op,) cdef tuple ops = (op,)
@ -187,13 +326,12 @@ cdef class _AioCall(GrpcCallWrapper):
if self._is_locally_cancelled: if self._is_locally_cancelled:
return return
cdef AioRpcStatus status = AioRpcStatus( self._set_status(AioRpcStatus(
op.code(), op.code(),
op.details(), op.details(),
op.trailing_metadata(), op.trailing_metadata(),
op.error_string(), op.error_string(),
) ))
status_observer(status)
async def receive_serialized_message(self): async def receive_serialized_message(self):
"""Receives one single raw message in bytes.""" """Receives one single raw message in bytes."""
@ -227,13 +365,11 @@ cdef class _AioCall(GrpcCallWrapper):
async def initiate_unary_stream(self, async def initiate_unary_stream(self,
bytes request, bytes request,
tuple outbound_initial_metadata, tuple outbound_initial_metadata):
object initial_metadata_observer,
object status_observer):
"""Implementation of the start of a unary-stream call.""" """Implementation of the start of a unary-stream call."""
# Peer may prematurely end this RPC at any point. We need a corutine # Peer may prematurely end this RPC at any point. We need a corutine
# that watches if the server sends the final status. # that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received(status_observer)) self._loop.create_task(self._handle_status_once_received())
cdef tuple outbound_ops cdef tuple outbound_ops
cdef Operation initial_metadata_op = SendInitialMetadataOperation( cdef Operation initial_metadata_op = SendInitialMetadataOperation(
@ -257,16 +393,14 @@ cdef class _AioCall(GrpcCallWrapper):
self._loop) self._loop)
# Receives initial metadata. # Receives initial metadata.
initial_metadata_observer( self._set_initial_metadata(
await _receive_initial_metadata(self, await _receive_initial_metadata(self,
self._loop), self._loop),
) )
async def stream_unary(self, async def stream_unary(self,
tuple outbound_initial_metadata, tuple outbound_initial_metadata,
object metadata_sent_observer, object metadata_sent_observer):
object initial_metadata_observer,
object status_observer):
"""Actual implementation of the complete unary-stream call. """Actual implementation of the complete unary-stream call.
Needs to pay extra attention to the raise mechanism. If we want to Needs to pay extra attention to the raise mechanism. If we want to
@ -281,9 +415,8 @@ cdef class _AioCall(GrpcCallWrapper):
metadata_sent_observer() metadata_sent_observer()
# Receives initial metadata. # Receives initial metadata.
initial_metadata_observer( self._set_initial_metadata(
await _receive_initial_metadata(self, await _receive_initial_metadata(self, self._loop)
self._loop),
) )
cdef tuple inbound_ops cdef tuple inbound_ops
@ -296,26 +429,24 @@ cdef class _AioCall(GrpcCallWrapper):
inbound_ops, inbound_ops,
self._loop) self._loop)
status = AioRpcStatus( cdef grpc_status_code code
receive_status_on_client_op.code(), code = receive_status_on_client_op.code()
self._set_status(AioRpcStatus(
code,
receive_status_on_client_op.details(), receive_status_on_client_op.details(),
receive_status_on_client_op.trailing_metadata(), receive_status_on_client_op.trailing_metadata(),
receive_status_on_client_op.error_string(), receive_status_on_client_op.error_string(),
) ))
# Reports the final status of the RPC to Python layer. The observer
# pattern is used here to unify unary and streaming code path.
status_observer(status)
if status.code() == StatusCode.ok: if code == StatusCode.ok:
return receive_message_op.message() return receive_message_op.message()
else: else:
return None return None
async def initiate_stream_stream(self, async def initiate_stream_stream(self,
tuple outbound_initial_metadata, tuple outbound_initial_metadata,
object metadata_sent_observer, object metadata_sent_observer):
object initial_metadata_observer,
object status_observer):
"""Actual implementation of the complete stream-stream call. """Actual implementation of the complete stream-stream call.
Needs to pay extra attention to the raise mechanism. If we want to Needs to pay extra attention to the raise mechanism. If we want to
@ -324,7 +455,7 @@ cdef class _AioCall(GrpcCallWrapper):
""" """
# Peer may prematurely end this RPC at any point. We need a corutine # Peer may prematurely end this RPC at any point. We need a corutine
# that watches if the server sends the final status. # that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received(status_observer)) self._loop.create_task(self._handle_status_once_received())
# Sends out initial_metadata ASAP. # Sends out initial_metadata ASAP.
await _send_initial_metadata(self, await _send_initial_metadata(self,
@ -334,7 +465,6 @@ cdef class _AioCall(GrpcCallWrapper):
metadata_sent_observer() metadata_sent_observer()
# Receives initial metadata. # Receives initial metadata.
initial_metadata_observer( self._set_initial_metadata(
await _receive_initial_metadata(self, await _receive_initial_metadata(self, self._loop)
self._loop),
) )

@ -21,6 +21,6 @@ cdef class AioChannel:
cdef: cdef:
grpc_channel * channel grpc_channel * channel
CallbackCompletionQueue cq CallbackCompletionQueue cq
object loop
bytes _target bytes _target
object _loop
AioChannelStatus _status AioChannelStatus _status

@ -25,13 +25,13 @@ cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailur
cdef class AioChannel: cdef class AioChannel:
def __cinit__(self, bytes target, tuple options, ChannelCredentials credentials): def __cinit__(self, bytes target, tuple options, ChannelCredentials credentials, object loop):
if options is None: if options is None:
options = () options = ()
cdef _ChannelArgs channel_args = _ChannelArgs(options) cdef _ChannelArgs channel_args = _ChannelArgs(options)
self._target = target self._target = target
self.cq = CallbackCompletionQueue() self.cq = CallbackCompletionQueue()
self._loop = asyncio.get_event_loop() self.loop = loop
self._status = AIO_CHANNEL_STATUS_READY self._status = AIO_CHANNEL_STATUS_READY
if credentials is None: if credentials is None:
@ -71,7 +71,7 @@ cdef class AioChannel:
raise RuntimeError('Channel is closed.') raise RuntimeError('Channel is closed.')
cdef gpr_timespec c_deadline = _timespec_from_time(deadline) cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
cdef object future = self._loop.create_future() cdef object future = self.loop.create_future()
cdef CallbackWrapper wrapper = CallbackWrapper( cdef CallbackWrapper wrapper = CallbackWrapper(
future, future,
_WATCH_CONNECTIVITY_FAILURE_HANDLER) _WATCH_CONNECTIVITY_FAILURE_HANDLER)
@ -112,5 +112,4 @@ cdef class AioChannel:
else: else:
cython_call_credentials = None cython_call_credentials = None
cdef _AioCall call = _AioCall(self, deadline, method, cython_call_credentials) return _AioCall(self, deadline, method, cython_call_credentials)
return call

@ -40,7 +40,7 @@ cdef class RPCState:
self.abort_exception = None self.abort_exception = None
self.metadata_sent = False self.metadata_sent = False
self.status_sent = False self.status_sent = False
self.trailing_metadata = _EMPTY_METADATA self.trailing_metadata = _IMMUTABLE_EMPTY_METADATA
cdef bytes method(self): cdef bytes method(self):
return _slice_bytes(self.details.method) return _slice_bytes(self.details.method)
@ -129,7 +129,7 @@ cdef class _ServicerContext:
async def abort(self, async def abort(self,
object code, object code,
str details='', str details='',
tuple trailing_metadata=_EMPTY_METADATA): tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA):
if self._rpc_state.abort_exception is not None: if self._rpc_state.abort_exception is not None:
raise RuntimeError('Abort already called!') raise RuntimeError('Abort already called!')
else: else:
@ -138,7 +138,7 @@ cdef class _ServicerContext:
# could lead to undefined behavior. # could lead to undefined behavior.
self._rpc_state.abort_exception = AbortError('Locally aborted.') self._rpc_state.abort_exception = AbortError('Locally aborted.')
if trailing_metadata == _EMPTY_METADATA and self._rpc_state.trailing_metadata: if trailing_metadata == _IMMUTABLE_EMPTY_METADATA and self._rpc_state.trailing_metadata:
trailing_metadata = self._rpc_state.trailing_metadata trailing_metadata = self._rpc_state.trailing_metadata
self._rpc_state.status_sent = True self._rpc_state.status_sent = True
@ -471,7 +471,7 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
rpc_state, rpc_state,
StatusCode.unimplemented, StatusCode.unimplemented,
'Method not found!', 'Method not found!',
_EMPTY_METADATA, _IMMUTABLE_EMPTY_METADATA,
rpc_state.metadata_sent, rpc_state.metadata_sent,
loop loop
) )

@ -14,7 +14,8 @@
"""Invocation-side implementation of gRPC Asyncio Python.""" """Invocation-side implementation of gRPC Asyncio Python."""
import asyncio import asyncio
from typing import AsyncIterable, Awaitable, List, Dict, Optional from functools import partial
from typing import AsyncIterable, List, Dict, Optional
import grpc import grpc
from grpc import _common from grpc import _common
@ -42,8 +43,6 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tdebug_error_string = "{}"\n' '\tdebug_error_string = "{}"\n'
'>') '>')
_EMPTY_METADATA = tuple()
class AioRpcError(grpc.RpcError): class AioRpcError(grpc.RpcError):
"""An implementation of RpcError to be used by the asynchronous API. """An implementation of RpcError to be used by the asynchronous API.
@ -153,116 +152,69 @@ class Call(_base_call.Call):
""" """
_loop: asyncio.AbstractEventLoop _loop: asyncio.AbstractEventLoop
_code: grpc.StatusCode _code: grpc.StatusCode
_status: Awaitable[cygrpc.AioRpcStatus]
_initial_metadata: Awaitable[MetadataType]
_locally_cancelled: bool
_cython_call: cygrpc._AioCall _cython_call: cygrpc._AioCall
_done_callbacks: List[DoneCallbackType] _done_callbacks: List[DoneCallbackType]
def __init__(self, cython_call: cygrpc._AioCall) -> None: def __init__(self, cython_call: cygrpc._AioCall,
self._loop = asyncio.get_event_loop() loop: asyncio.AbstractEventLoop) -> None:
self._code = None self._loop = loop
self._status = self._loop.create_future()
self._initial_metadata = self._loop.create_future()
self._locally_cancelled = False
self._cython_call = cython_call self._cython_call = cython_call
self._done_callbacks = [] self._done_callbacks = []
def __del__(self) -> None: def __del__(self) -> None:
if not self._status.done(): if not self._cython_call.done():
self._cancel( self._cancel(_GC_CANCELLATION_DETAILS)
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None))
def cancelled(self) -> bool: def cancelled(self) -> bool:
return self._code == grpc.StatusCode.CANCELLED return self._cython_call.cancelled()
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: def _cancel(self, details: str) -> bool:
"""Forwards the application cancellation reasoning.""" """Forwards the application cancellation reasoning."""
if not self._status.done(): if not self._cython_call.done():
self._set_status(status) self._cython_call.cancel(details)
self._cython_call.cancel(status)
return True return True
else: else:
return False return False
def cancel(self) -> bool: def cancel(self) -> bool:
return self._cancel( return self._cancel(_LOCAL_CANCELLATION_DETAILS)
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_LOCAL_CANCELLATION_DETAILS, None, None))
def done(self) -> bool: def done(self) -> bool:
return self._status.done() return self._cython_call.done()
def add_done_callback(self, callback: DoneCallbackType) -> None: def add_done_callback(self, callback: DoneCallbackType) -> None:
if self.done(): cb = partial(callback, self)
callback(self) self._cython_call.add_done_callback(cb)
else:
self._done_callbacks.append(callback)
def time_remaining(self) -> Optional[float]: def time_remaining(self) -> Optional[float]:
return self._cython_call.time_remaining() return self._cython_call.time_remaining()
async def initial_metadata(self) -> MetadataType: async def initial_metadata(self) -> MetadataType:
return await self._initial_metadata return await self._cython_call.initial_metadata()
async def trailing_metadata(self) -> MetadataType: async def trailing_metadata(self) -> MetadataType:
return (await self._status).trailing_metadata() return (await self._cython_call.status()).trailing_metadata()
async def code(self) -> grpc.StatusCode: async def code(self) -> grpc.StatusCode:
await self._status cygrpc_code = (await self._cython_call.status()).code()
return self._code return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code]
async def details(self) -> str: async def details(self) -> str:
return (await self._status).details() return (await self._cython_call.status()).details()
async def debug_error_string(self) -> str: async def debug_error_string(self) -> str:
return (await self._status).debug_error_string() return (await self._cython_call.status()).debug_error_string()
def _set_initial_metadata(self, metadata: MetadataType) -> None:
self._initial_metadata.set_result(metadata)
def _set_status(self, status: cygrpc.AioRpcStatus) -> None:
"""Private method to set final status of the RPC.
This method should only be invoked once.
"""
# In case of local cancellation, flip the flag.
if status.details() is _LOCAL_CANCELLATION_DETAILS:
self._locally_cancelled = True
# In case of the RPC finished without receiving metadata.
if not self._initial_metadata.done():
self._initial_metadata.set_result(_EMPTY_METADATA)
# Sets final status
self._status.set_result(status)
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
for callback in self._done_callbacks:
callback(self)
async def _raise_for_status(self) -> None: async def _raise_for_status(self) -> None:
if self._locally_cancelled: if self._cython_call.is_locally_cancelled():
raise asyncio.CancelledError() raise asyncio.CancelledError()
await self._status code = await self.code()
if self._code != grpc.StatusCode.OK: if code != grpc.StatusCode.OK:
raise _create_rpc_error(await self.initial_metadata(), raise _create_rpc_error(await self.initial_metadata(), await
self._status.result()) self._cython_call.status())
def _repr(self) -> str: def _repr(self) -> str:
"""Assembles the RPC representation string.""" return repr(self._cython_call)
if not self._status.done():
return '<{} object>'.format(self.__class__.__name__)
if self._code is grpc.StatusCode.OK:
return _OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code,
self._status.result().details())
else:
return _NON_OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code,
self._status.result().details(),
self._status.result().debug_error_string())
def __repr__(self) -> str: def __repr__(self) -> str:
return self._repr() return self._repr()
@ -288,13 +240,14 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None: response_deserializer: DeserializingFunction,
super().__init__(channel.call(method, deadline, credentials)) loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._request = request self._request = request
self._metadata = metadata self._metadata = metadata
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._call = self._loop.create_task(self._invoke()) self._call = loop.create_task(self._invoke())
def cancel(self) -> bool: def cancel(self) -> bool:
if super().cancel(): if super().cancel():
@ -312,11 +265,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
try: try:
serialized_response = await self._cython_call.unary_unary( serialized_response = await self._cython_call.unary_unary(
serialized_request, serialized_request, self._metadata)
self._metadata,
self._set_initial_metadata,
self._set_status,
)
except asyncio.CancelledError: except asyncio.CancelledError:
if not self.cancelled(): if not self.cancelled():
self.cancel() self.cancel()
@ -360,13 +309,14 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None: response_deserializer: DeserializingFunction,
super().__init__(channel.call(method, deadline, credentials)) loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._request = request self._request = request
self._metadata = metadata self._metadata = metadata
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._send_unary_request_task = self._loop.create_task( self._send_unary_request_task = loop.create_task(
self._send_unary_request()) self._send_unary_request())
self._message_aiter = None self._message_aiter = None
@ -382,8 +332,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self._request_serializer) self._request_serializer)
try: try:
await self._cython_call.initiate_unary_stream( await self._cython_call.initiate_unary_stream(
serialized_request, self._metadata, self._set_initial_metadata, serialized_request, self._metadata)
self._set_status)
except asyncio.CancelledError: except asyncio.CancelledError:
if not self.cancelled(): if not self.cancelled():
self.cancel() self.cancel()
@ -419,7 +368,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self._response_deserializer) self._response_deserializer)
async def read(self) -> ResponseType: async def read(self) -> ResponseType:
if self._status.done(): if self._cython_call.done():
await self._raise_for_status() await self._raise_for_status()
return cygrpc.EOF return cygrpc.EOF
@ -452,16 +401,17 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None: response_deserializer: DeserializingFunction,
super().__init__(channel.call(method, deadline, credentials)) loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._metadata = metadata self._metadata = metadata
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._metadata_sent = asyncio.Event(loop=self._loop) self._metadata_sent = asyncio.Event(loop=loop)
self._done_writing = False self._done_writing = False
self._call_finisher = self._loop.create_task(self._conduct_rpc()) self._call_finisher = loop.create_task(self._conduct_rpc())
# If user passes in an async iterator, create a consumer Task. # If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None: if request_async_iterator is not None:
@ -485,11 +435,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
async def _conduct_rpc(self) -> ResponseType: async def _conduct_rpc(self) -> ResponseType:
try: try:
serialized_response = await self._cython_call.stream_unary( serialized_response = await self._cython_call.stream_unary(
self._metadata, self._metadata, self._metadata_sent_observer)
self._metadata_sent_observer,
self._set_initial_metadata,
self._set_status,
)
except asyncio.CancelledError: except asyncio.CancelledError:
if not self.cancelled(): if not self.cancelled():
self.cancel() self.cancel()
@ -517,7 +463,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
return response return response
async def write(self, request: RequestType) -> None: async def write(self, request: RequestType) -> None:
if self._status.done(): if self._cython_call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing: if self._done_writing:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
@ -536,7 +482,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
async def done_writing(self) -> None: async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent.""" """Implementation of done_writing is idempotent."""
if self._status.done(): if self._cython_call.done():
# If the RPC is finished, do nothing. # If the RPC is finished, do nothing.
return return
if not self._done_writing: if not self._done_writing:
@ -572,20 +518,21 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None: response_deserializer: DeserializingFunction,
super().__init__(channel.call(method, deadline, credentials)) loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._metadata = metadata self._metadata = metadata
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._metadata_sent = asyncio.Event(loop=self._loop) self._metadata_sent = asyncio.Event(loop=loop)
self._done_writing = False self._done_writing = False
self._initializer = self._loop.create_task(self._prepare_rpc()) self._initializer = self._loop.create_task(self._prepare_rpc())
# If user passes in an async iterator, create a consumer coroutine. # If user passes in an async iterator, create a consumer coroutine.
if request_async_iterator is not None: if request_async_iterator is not None:
self._async_request_poller = self._loop.create_task( self._async_request_poller = loop.create_task(
self._consume_request_iterator(request_async_iterator)) self._consume_request_iterator(request_async_iterator))
else: else:
self._async_request_poller = None self._async_request_poller = None
@ -611,11 +558,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
""" """
try: try:
await self._cython_call.initiate_stream_stream( await self._cython_call.initiate_stream_stream(
self._metadata, self._metadata, self._metadata_sent_observer)
self._metadata_sent_observer,
self._set_initial_metadata,
self._set_status,
)
except asyncio.CancelledError: except asyncio.CancelledError:
if not self.cancelled(): if not self.cancelled():
self.cancel() self.cancel()
@ -629,7 +572,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
await self.done_writing() await self.done_writing()
async def write(self, request: RequestType) -> None: async def write(self, request: RequestType) -> None:
if self._status.done(): if self._cython_call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing: if self._done_writing:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
@ -648,7 +591,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
async def done_writing(self) -> None: async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent.""" """Implementation of done_writing is idempotent."""
if self._status.done(): if self._cython_call.done():
# If the RPC is finished, do nothing. # If the RPC is finished, do nothing.
return return
if not self._done_writing: if not self._done_writing:
@ -692,7 +635,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
self._response_deserializer) self._response_deserializer)
async def read(self) -> ResponseType: async def read(self) -> ResponseType:
if self._status.done(): if self._cython_call.done():
await self._raise_for_status() await self._raise_for_status()
return cygrpc.EOF return cygrpc.EOF

@ -28,6 +28,8 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction) SerializingFunction)
from ._utils import _timeout_to_deadline from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple()
class _BaseMultiCallable: class _BaseMultiCallable:
"""Base class of all multi callable objects. """Base class of all multi callable objects.
@ -47,12 +49,14 @@ class _BaseMultiCallable:
_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] _interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_loop: asyncio.AbstractEventLoop _loop: asyncio.AbstractEventLoop
def __init__(self, channel: cygrpc.AioChannel, method: bytes, def __init__(self, channel: cygrpc.AioChannel,
method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction, response_deserializer: DeserializingFunction,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]],
loop: asyncio.AbstractEventLoop,
) -> None: ) -> None:
self._loop = asyncio.get_event_loop() self._loop = loop
self._channel = channel self._channel = channel
self._method = method self._method = method
self._request_serializer = request_serializer self._request_serializer = request_serializer
@ -102,31 +106,20 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
raise NotImplementedError("TODO: compression not implemented yet") raise NotImplementedError("TODO: compression not implemented yet")
if metadata is None: if metadata is None:
metadata = tuple() metadata = _IMMUTABLE_EMPTY_TUPLE
if not self._interceptors: if not self._interceptors:
return UnaryUnaryCall( return UnaryUnaryCall(request, _timeout_to_deadline(timeout),
request, metadata, credentials, self._channel,
_timeout_to_deadline(timeout), self._method, self._request_serializer,
metadata, self._response_deserializer, self._loop)
credentials,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
else: else:
return InterceptedUnaryUnaryCall( return InterceptedUnaryUnaryCall(self._interceptors, request,
self._interceptors, timeout, metadata, credentials,
request, self._channel, self._method,
timeout,
metadata,
credentials,
self._channel,
self._method,
self._request_serializer, self._request_serializer,
self._response_deserializer, self._response_deserializer,
) self._loop)
class UnaryStreamMultiCallable(_BaseMultiCallable): class UnaryStreamMultiCallable(_BaseMultiCallable):
@ -168,18 +161,12 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)
if metadata is None: if metadata is None:
metadata = tuple() metadata = _IMMUTABLE_EMPTY_TUPLE
return UnaryStreamCall( return UnaryStreamCall(request, deadline, metadata, credentials,
request, self._channel, self._method,
deadline,
metadata,
credentials,
self._channel,
self._method,
self._request_serializer, self._request_serializer,
self._response_deserializer, self._response_deserializer, self._loop)
)
class StreamUnaryMultiCallable(_BaseMultiCallable): class StreamUnaryMultiCallable(_BaseMultiCallable):
@ -225,18 +212,12 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)
if metadata is None: if metadata is None:
metadata = tuple() metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamUnaryCall( return StreamUnaryCall(request_async_iterator, deadline, metadata,
request_async_iterator, credentials, self._channel, self._method,
deadline,
metadata,
credentials,
self._channel,
self._method,
self._request_serializer, self._request_serializer,
self._response_deserializer, self._response_deserializer, self._loop)
)
class StreamStreamMultiCallable(_BaseMultiCallable): class StreamStreamMultiCallable(_BaseMultiCallable):
@ -282,18 +263,12 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)
if metadata is None: if metadata is None:
metadata = tuple() metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamStreamCall( return StreamStreamCall(request_async_iterator, deadline, metadata,
request_async_iterator, credentials, self._channel, self._method,
deadline,
metadata,
credentials,
self._channel,
self._method,
self._request_serializer, self._request_serializer,
self._response_deserializer, self._response_deserializer, self._loop)
)
class Channel: class Channel:
@ -301,6 +276,7 @@ class Channel:
A cygrpc.AioChannel-backed implementation. A cygrpc.AioChannel-backed implementation.
""" """
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel _channel: cygrpc.AioChannel
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
@ -341,8 +317,9 @@ class Channel:
"UnaryUnaryClientInterceptors, the following are invalid: {}"\ "UnaryUnaryClientInterceptors, the following are invalid: {}"\
.format(invalid_interceptors)) .format(invalid_interceptors))
self._loop = asyncio.get_event_loop()
self._channel = cygrpc.AioChannel(_common.encode(target), options, self._channel = cygrpc.AioChannel(_common.encode(target), options,
credentials) credentials, self._loop)
def get_state(self, def get_state(self,
try_to_connect: bool = False) -> grpc.ChannelConnectivity: try_to_connect: bool = False) -> grpc.ChannelConnectivity:
@ -408,10 +385,12 @@ class Channel:
Returns: Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method. A UnaryUnaryMultiCallable value for the named unary-unary method.
""" """
return UnaryUnaryMultiCallable(self._channel, _common.encode(method), return UnaryUnaryMultiCallable(self._channel,
_common.encode(method),
request_serializer, request_serializer,
response_deserializer, response_deserializer,
self._unary_unary_interceptors) self._unary_unary_interceptors,
self._loop)
def unary_stream( def unary_stream(
self, self,
@ -419,9 +398,11 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryStreamMultiCallable: ) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, _common.encode(method), return UnaryStreamMultiCallable(self._channel,
_common.encode(method),
request_serializer, request_serializer,
response_deserializer, None) response_deserializer,
None, self._loop)
def stream_unary( def stream_unary(
self, self,
@ -429,9 +410,11 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None response_deserializer: Optional[DeserializingFunction] = None
) -> StreamUnaryMultiCallable: ) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(self._channel, _common.encode(method), return StreamUnaryMultiCallable(self._channel,
_common.encode(method),
request_serializer, request_serializer,
response_deserializer, None) response_deserializer,
None, self._loop)
def stream_stream( def stream_stream(
self, self,
@ -439,9 +422,11 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None response_deserializer: Optional[DeserializingFunction] = None
) -> StreamStreamMultiCallable: ) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(self._channel, _common.encode(method), return StreamStreamMultiCallable(self._channel,
_common.encode(method),
request_serializer, request_serializer,
response_deserializer, None) response_deserializer,
None, self._loop)
async def _close(self): async def _close(self):
# TODO: Send cancellation status # TODO: Send cancellation status

@ -110,12 +110,14 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None: response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._channel = channel self._channel = channel
self._loop = asyncio.get_event_loop() self._loop = loop
self._interceptors_task = asyncio.ensure_future( self._interceptors_task = asyncio.ensure_future(self._invoke(
self._invoke(interceptors, method, timeout, metadata, credentials, interceptors, method, timeout, metadata, credentials, request,
request, request_serializer, response_deserializer)) request_serializer, response_deserializer),
loop=loop)
def __del__(self): def __del__(self):
self.cancel() self.cancel()
@ -154,7 +156,7 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
client_call_details.metadata, client_call_details.metadata,
client_call_details.credentials, self._channel, client_call_details.credentials, self._channel,
client_call_details.method, request_serializer, client_call_details.method, request_serializer,
response_deserializer) response_deserializer, self._loop)
client_call_details = ClientCallDetails(method, timeout, metadata, client_call_details = ClientCallDetails(method, timeout, metadata,
credentials) credentials)

@ -48,6 +48,16 @@ class _MulticallableTestMixin():
class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
async def test_call_to_string(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertTrue(str(call) is not None)
self.assertTrue(repr(call) is not None)
response = await call
self.assertTrue(str(call) is not None)
self.assertTrue(repr(call) is not None)
async def test_call_ok(self): async def test_call_ok(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
@ -105,6 +115,65 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual((), await call.trailing_metadata()) self.assertEqual((), await call.trailing_metadata())
async def test_call_initial_metadata_cancelable(self):
coro_started = asyncio.Event()
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro():
coro_started.set()
await call.initial_metadata()
task = self.loop.create_task(coro())
await coro_started.wait()
task.cancel()
# Test that initial metadata can still be asked thought
# a cancellation happened with the previous task
self.assertEqual((), await call.initial_metadata())
async def test_call_initial_metadata_multiple_waiters(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro():
return await call.initial_metadata()
task1 = self.loop.create_task(coro())
task2 = self.loop.create_task(coro())
await call
self.assertEqual([(), ()], await asyncio.gather(*[task1, task2]))
async def test_call_code_cancelable(self):
coro_started = asyncio.Event()
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro():
coro_started.set()
await call.code()
task = self.loop.create_task(coro())
await coro_started.wait()
task.cancel()
# Test that code can still be asked thought
# a cancellation happened with the previous task
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_call_code_multiple_waiters(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro():
return await call.code()
task1 = self.loop.create_task(coro())
task2 = self.loop.create_task(coro())
await call
self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await
asyncio.gather(task1, task2))
async def test_cancel_unary_unary(self): async def test_cancel_unary_unary(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

Loading…
Cancel
Save