Refactorize Cython and Python call communications

Now the status and the initial metadata, as awaitable methods, are
provided by the Cython layer. Any time the Python layer, like the Call
object, needs to know the status of the initial metadata uses the new
methods published by the AioCall
pull/21696/head
Pau Freixes 5 years ago
parent eba60d8dbe
commit 53c41de3e0
  1. 14
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 290
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
  4. 9
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  5. 8
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  6. 167
      src/python/grpcio/grpc/experimental/aio/_call.py
  7. 119
      src/python/grpcio/grpc/experimental/aio/_channel.py
  8. 14
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  9. 69
      src/python/grpcio_tests/tests_aio/unit/call_test.py

@ -17,6 +17,9 @@ cdef class _AioCall(GrpcCallWrapper):
cdef:
AioChannel _channel
list _references
object _deadline
list _done_callbacks
# Caches the picked event loop, so we can avoid the 30ns overhead each
# time we need access to the event loop.
object _loop
@ -28,6 +31,15 @@ cdef class _AioCall(GrpcCallWrapper):
# because Core is holding a pointer for the callback handler.
bint _is_locally_cancelled
object _deadline
# Following attributes are used for storing the status of the call and
# the initial metadata. Waiters are used for pausing the execution of
# tasks that are asking for one of the field when they are not yet
# available.
object _status
object _initial_metadata
list _waiters_status
list _waiters_initial_metadata
cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except *
cdef void _set_status(self, AioRpcStatus status) except *
cdef void _set_initial_metadata(self, tuple initial_metadata) except *

@ -18,34 +18,68 @@ import grpc
_EMPTY_FLAGS = 0
_EMPTY_MASK = 0
_EMPTY_METADATA = None
_IMMUTABLE_EMPTY_METADATA = tuple()
_UNKNOWN_CANCELLATION_DETAILS = 'RPC cancelled for unknown reason.'
_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
'\tdetails = "{}"\n'
'>')
_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
'\tdetails = "{}"\n'
'\tdebug_error_string = "{}"\n'
'>')
cdef class _AioCall(GrpcCallWrapper):
def __cinit__(self,
AioChannel channel,
object deadline,
bytes method,
CallCredentials call_credentials):
def __cinit__(self, AioChannel channel, object deadline,
bytes method, CallCredentials call_credentials):
self.call = NULL
self._channel = channel
self._loop = channel.loop
self._references = []
self._loop = asyncio.get_event_loop()
self._create_grpc_call(deadline, method, call_credentials)
self._status = None
self._initial_metadata = None
self._waiters_status = []
self._waiters_initial_metadata = []
self._done_callbacks = []
self._is_locally_cancelled = False
self._deadline = deadline
self._create_grpc_call(deadline, method, call_credentials)
def __dealloc__(self):
if self.call:
grpc_call_unref(self.call)
def __repr__(self):
class_name = self.__class__.__name__
id_ = id(self)
return f"<{class_name} {id_}>"
def _repr(self) -> str:
"""Assembles the RPC representation string."""
# This needs to be loaded at run time once everything
# has been loaded.
from grpc import _common
if not self.done():
return '<{} object>'.format(self.__class__.__name__)
if self._status.code() is StatusCode.ok:
return _OK_CALL_REPRESENTATION.format(
self.__class__.__name__,
_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[self._status.code()],
self._status.details())
else:
return _NON_OK_CALL_REPRESENTATION.format(
self.__class__.__name__,
self._status.details(),
_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[self._status.code()],
self._status.debug_error_string())
def __repr__(self) -> str:
return self._repr()
def __str__(self) -> str:
return self._repr()
cdef void _create_grpc_call(self,
object deadline,
@ -85,13 +119,55 @@ cdef class _AioCall(GrpcCallWrapper):
grpc_slice_unref(method_slice)
cdef void _set_status(self, AioRpcStatus status) except *:
cdef list waiters
if self._initial_metadata is None:
self._set_initial_metadata(_IMMUTABLE_EMPTY_METADATA)
self._status = status
waiters = self._waiters_status
# No more waiters should be expected since status
# has been set.
self._waiters_status = None
for waiter in waiters:
if not waiter.done():
waiter.set_result(None)
for callback in self._done_callbacks:
callback()
cdef void _set_initial_metadata(self, tuple initial_metadata) except *:
cdef list waiters
self._initial_metadata = initial_metadata
waiters = self._waiters_initial_metadata
# No more waiters should be expected since initial metadata
# has been set.
self._waiters_initial_metadata = None
for waiter in waiters:
if not waiter.done():
waiter.set_result(None)
def add_done_callback(self, callback):
if self.done():
callback()
else:
self._done_callbacks.append(callback)
def time_remaining(self):
if self._deadline is None:
return None
else:
return max(0, self._deadline - time.time())
def cancel(self, AioRpcStatus status):
def cancel(self, str details):
"""Cancels the RPC in Core with given RPC status.
Above abstractions must invoke this method to set Core objects into
@ -99,44 +175,108 @@ cdef class _AioCall(GrpcCallWrapper):
"""
self._is_locally_cancelled = True
cdef object details
cdef object details_bytes
cdef char *c_details
cdef grpc_call_error error
# Try to fetch application layer cancellation details in the future.
# * If cancellation details present, cancel with status;
# * If details not present, cancel with unknown reason.
if status is not None:
details = str_to_bytes(status.details())
self._references.append(details)
c_details = <char *>details
# By implementation, grpc_call_cancel_with_status always return OK
error = grpc_call_cancel_with_status(
self.call,
status.c_code(),
c_details,
NULL,
)
assert error == GRPC_CALL_OK
else:
# By implementation, grpc_call_cancel always return OK
error = grpc_call_cancel(self.call, NULL)
assert error == GRPC_CALL_OK
self._set_status(AioRpcStatus(
StatusCode.cancelled,
details,
None,
None,
))
details_bytes = str_to_bytes(details)
self._references.append(details_bytes)
c_details = <char *>details_bytes
# By implementation, grpc_call_cancel_with_status always return OK
error = grpc_call_cancel_with_status(
self.call,
StatusCode.cancelled,
c_details,
NULL,
)
assert error == GRPC_CALL_OK
def done(self):
"""Returns if the RPC call has finished.
Checks if the status has been provided, either
because the RPC finished or because was cancelled..
Returns:
True if the RPC can be considered finished.
"""
return self._status is not None
def cancelled(self):
"""Returns if the RPC was cancelled.
Returns:
True if the RPC was cancelled.
"""
if not self.done():
return False
return self._status.code() == StatusCode.cancelled
async def status(self):
"""Returns the status of the RPC call.
It returns the finshed status of the RPC. If the RPC
has not finished yet this function will wait until the RPC
gets finished.
Returns:
Finished status of the RPC as an AioRpcStatus object.
"""
if self._status is not None:
return self._status
future = self._loop.create_future()
self._waiters_status.append(future)
await future
return self._status
async def initial_metadata(self):
"""Returns the initial metadata of the RPC call.
If the initial metadata has not been received yet this function will
wait until the RPC gets finished.
Returns:
The tuple object with the initial metadata.
"""
if self._initial_metadata is not None:
return self._initial_metadata
future = self._loop.create_future()
self._waiters_initial_metadata.append(future)
await future
return self._initial_metadata
def is_locally_cancelled(self):
"""Returns if the RPC was cancelled locally.
Returns:
True when was cancelled locally, False when was cancelled remotelly or
is still ongoing.
"""
if self._is_locally_cancelled:
return True
return False
async def unary_unary(self,
bytes request,
tuple outbound_initial_metadata,
object initial_metadata_observer,
object status_observer):
tuple outbound_initial_metadata):
"""Performs a unary unary RPC.
Args:
method: name of the calling method in bytes.
request: the serialized requests in bytes.
deadline: optional deadline of the RPC in float.
cancellation_future: the future that meant to transport the
cancellation reason from the application layer.
initial_metadata_observer: a callback for received initial metadata.
status_observer: a callback for received final status.
outbound_initial_metadata: optional outbound metadata.
"""
cdef tuple ops
@ -159,25 +299,24 @@ cdef class _AioCall(GrpcCallWrapper):
ops,
self._loop)
# Reports received initial metadata.
initial_metadata_observer(receive_initial_metadata_op.initial_metadata())
self._set_initial_metadata(receive_initial_metadata_op.initial_metadata())
cdef grpc_status_code code
code = receive_status_on_client_op.code()
status = AioRpcStatus(
receive_status_on_client_op.code(),
self._set_status(AioRpcStatus(
code,
receive_status_on_client_op.details(),
receive_status_on_client_op.trailing_metadata(),
receive_status_on_client_op.error_string(),
)
# Reports the final status of the RPC to Python layer. The observer
# pattern is used here to unify unary and streaming code path.
status_observer(status)
))
if status.code() == StatusCode.ok:
if code == StatusCode.ok:
return receive_message_op.message()
else:
return None
async def _handle_status_once_received(self, object status_observer):
async def _handle_status_once_received(self):
"""Handles the status sent by peer once received."""
cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
cdef tuple ops = (op,)
@ -187,13 +326,12 @@ cdef class _AioCall(GrpcCallWrapper):
if self._is_locally_cancelled:
return
cdef AioRpcStatus status = AioRpcStatus(
self._set_status(AioRpcStatus(
op.code(),
op.details(),
op.trailing_metadata(),
op.error_string(),
)
status_observer(status)
))
async def receive_serialized_message(self):
"""Receives one single raw message in bytes."""
@ -227,13 +365,11 @@ cdef class _AioCall(GrpcCallWrapper):
async def initiate_unary_stream(self,
bytes request,
tuple outbound_initial_metadata,
object initial_metadata_observer,
object status_observer):
tuple outbound_initial_metadata):
"""Implementation of the start of a unary-stream call."""
# Peer may prematurely end this RPC at any point. We need a corutine
# that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received(status_observer))
self._loop.create_task(self._handle_status_once_received())
cdef tuple outbound_ops
cdef Operation initial_metadata_op = SendInitialMetadataOperation(
@ -257,16 +393,14 @@ cdef class _AioCall(GrpcCallWrapper):
self._loop)
# Receives initial metadata.
initial_metadata_observer(
self._set_initial_metadata(
await _receive_initial_metadata(self,
self._loop),
)
async def stream_unary(self,
tuple outbound_initial_metadata,
object metadata_sent_observer,
object initial_metadata_observer,
object status_observer):
object metadata_sent_observer):
"""Actual implementation of the complete unary-stream call.
Needs to pay extra attention to the raise mechanism. If we want to
@ -281,9 +415,8 @@ cdef class _AioCall(GrpcCallWrapper):
metadata_sent_observer()
# Receives initial metadata.
initial_metadata_observer(
await _receive_initial_metadata(self,
self._loop),
self._set_initial_metadata(
await _receive_initial_metadata(self, self._loop)
)
cdef tuple inbound_ops
@ -296,26 +429,24 @@ cdef class _AioCall(GrpcCallWrapper):
inbound_ops,
self._loop)
status = AioRpcStatus(
receive_status_on_client_op.code(),
cdef grpc_status_code code
code = receive_status_on_client_op.code()
self._set_status(AioRpcStatus(
code,
receive_status_on_client_op.details(),
receive_status_on_client_op.trailing_metadata(),
receive_status_on_client_op.error_string(),
)
# Reports the final status of the RPC to Python layer. The observer
# pattern is used here to unify unary and streaming code path.
status_observer(status)
))
if status.code() == StatusCode.ok:
if code == StatusCode.ok:
return receive_message_op.message()
else:
return None
async def initiate_stream_stream(self,
tuple outbound_initial_metadata,
object metadata_sent_observer,
object initial_metadata_observer,
object status_observer):
object metadata_sent_observer):
"""Actual implementation of the complete stream-stream call.
Needs to pay extra attention to the raise mechanism. If we want to
@ -324,7 +455,7 @@ cdef class _AioCall(GrpcCallWrapper):
"""
# Peer may prematurely end this RPC at any point. We need a corutine
# that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received(status_observer))
self._loop.create_task(self._handle_status_once_received())
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
@ -334,7 +465,6 @@ cdef class _AioCall(GrpcCallWrapper):
metadata_sent_observer()
# Receives initial metadata.
initial_metadata_observer(
await _receive_initial_metadata(self,
self._loop),
self._set_initial_metadata(
await _receive_initial_metadata(self, self._loop)
)

@ -21,6 +21,6 @@ cdef class AioChannel:
cdef:
grpc_channel * channel
CallbackCompletionQueue cq
object loop
bytes _target
object _loop
AioChannelStatus _status

@ -25,13 +25,13 @@ cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailur
cdef class AioChannel:
def __cinit__(self, bytes target, tuple options, ChannelCredentials credentials):
def __cinit__(self, bytes target, tuple options, ChannelCredentials credentials, object loop):
if options is None:
options = ()
cdef _ChannelArgs channel_args = _ChannelArgs(options)
self._target = target
self.cq = CallbackCompletionQueue()
self._loop = asyncio.get_event_loop()
self.loop = loop
self._status = AIO_CHANNEL_STATUS_READY
if credentials is None:
@ -71,7 +71,7 @@ cdef class AioChannel:
raise RuntimeError('Channel is closed.')
cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
cdef object future = self._loop.create_future()
cdef object future = self.loop.create_future()
cdef CallbackWrapper wrapper = CallbackWrapper(
future,
_WATCH_CONNECTIVITY_FAILURE_HANDLER)
@ -112,5 +112,4 @@ cdef class AioChannel:
else:
cython_call_credentials = None
cdef _AioCall call = _AioCall(self, deadline, method, cython_call_credentials)
return call
return _AioCall(self, deadline, method, cython_call_credentials)

@ -40,7 +40,7 @@ cdef class RPCState:
self.abort_exception = None
self.metadata_sent = False
self.status_sent = False
self.trailing_metadata = _EMPTY_METADATA
self.trailing_metadata = _IMMUTABLE_EMPTY_METADATA
cdef bytes method(self):
return _slice_bytes(self.details.method)
@ -129,7 +129,7 @@ cdef class _ServicerContext:
async def abort(self,
object code,
str details='',
tuple trailing_metadata=_EMPTY_METADATA):
tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA):
if self._rpc_state.abort_exception is not None:
raise RuntimeError('Abort already called!')
else:
@ -138,7 +138,7 @@ cdef class _ServicerContext:
# could lead to undefined behavior.
self._rpc_state.abort_exception = AbortError('Locally aborted.')
if trailing_metadata == _EMPTY_METADATA and self._rpc_state.trailing_metadata:
if trailing_metadata == _IMMUTABLE_EMPTY_METADATA and self._rpc_state.trailing_metadata:
trailing_metadata = self._rpc_state.trailing_metadata
self._rpc_state.status_sent = True
@ -471,7 +471,7 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
rpc_state,
StatusCode.unimplemented,
'Method not found!',
_EMPTY_METADATA,
_IMMUTABLE_EMPTY_METADATA,
rpc_state.metadata_sent,
loop
)

@ -14,7 +14,8 @@
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
from typing import AsyncIterable, Awaitable, List, Dict, Optional
from functools import partial
from typing import AsyncIterable, List, Dict, Optional
import grpc
from grpc import _common
@ -42,8 +43,6 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tdebug_error_string = "{}"\n'
'>')
_EMPTY_METADATA = tuple()
class AioRpcError(grpc.RpcError):
"""An implementation of RpcError to be used by the asynchronous API.
@ -153,116 +152,69 @@ class Call(_base_call.Call):
"""
_loop: asyncio.AbstractEventLoop
_code: grpc.StatusCode
_status: Awaitable[cygrpc.AioRpcStatus]
_initial_metadata: Awaitable[MetadataType]
_locally_cancelled: bool
_cython_call: cygrpc._AioCall
_done_callbacks: List[DoneCallbackType]
def __init__(self, cython_call: cygrpc._AioCall) -> None:
self._loop = asyncio.get_event_loop()
self._code = None
self._status = self._loop.create_future()
self._initial_metadata = self._loop.create_future()
self._locally_cancelled = False
def __init__(self, cython_call: cygrpc._AioCall,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._cython_call = cython_call
self._done_callbacks = []
def __del__(self) -> None:
if not self._status.done():
self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None))
if not self._cython_call.done():
self._cancel(_GC_CANCELLATION_DETAILS)
def cancelled(self) -> bool:
return self._code == grpc.StatusCode.CANCELLED
return self._cython_call.cancelled()
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
def _cancel(self, details: str) -> bool:
"""Forwards the application cancellation reasoning."""
if not self._status.done():
self._set_status(status)
self._cython_call.cancel(status)
if not self._cython_call.done():
self._cython_call.cancel(details)
return True
else:
return False
def cancel(self) -> bool:
return self._cancel(
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_LOCAL_CANCELLATION_DETAILS, None, None))
return self._cancel(_LOCAL_CANCELLATION_DETAILS)
def done(self) -> bool:
return self._status.done()
return self._cython_call.done()
def add_done_callback(self, callback: DoneCallbackType) -> None:
if self.done():
callback(self)
else:
self._done_callbacks.append(callback)
cb = partial(callback, self)
self._cython_call.add_done_callback(cb)
def time_remaining(self) -> Optional[float]:
return self._cython_call.time_remaining()
async def initial_metadata(self) -> MetadataType:
return await self._initial_metadata
return await self._cython_call.initial_metadata()
async def trailing_metadata(self) -> MetadataType:
return (await self._status).trailing_metadata()
return (await self._cython_call.status()).trailing_metadata()
async def code(self) -> grpc.StatusCode:
await self._status
return self._code
cygrpc_code = (await self._cython_call.status()).code()
return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code]
async def details(self) -> str:
return (await self._status).details()
return (await self._cython_call.status()).details()
async def debug_error_string(self) -> str:
return (await self._status).debug_error_string()
def _set_initial_metadata(self, metadata: MetadataType) -> None:
self._initial_metadata.set_result(metadata)
def _set_status(self, status: cygrpc.AioRpcStatus) -> None:
"""Private method to set final status of the RPC.
This method should only be invoked once.
"""
# In case of local cancellation, flip the flag.
if status.details() is _LOCAL_CANCELLATION_DETAILS:
self._locally_cancelled = True
# In case of the RPC finished without receiving metadata.
if not self._initial_metadata.done():
self._initial_metadata.set_result(_EMPTY_METADATA)
# Sets final status
self._status.set_result(status)
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
for callback in self._done_callbacks:
callback(self)
return (await self._cython_call.status()).debug_error_string()
async def _raise_for_status(self) -> None:
if self._locally_cancelled:
if self._cython_call.is_locally_cancelled():
raise asyncio.CancelledError()
await self._status
if self._code != grpc.StatusCode.OK:
raise _create_rpc_error(await self.initial_metadata(),
self._status.result())
code = await self.code()
if code != grpc.StatusCode.OK:
raise _create_rpc_error(await self.initial_metadata(), await
self._cython_call.status())
def _repr(self) -> str:
"""Assembles the RPC representation string."""
if not self._status.done():
return '<{} object>'.format(self.__class__.__name__)
if self._code is grpc.StatusCode.OK:
return _OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code,
self._status.result().details())
else:
return _NON_OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code,
self._status.result().details(),
self._status.result().debug_error_string())
return repr(self._cython_call)
def __repr__(self) -> str:
return self._repr()
@ -288,13 +240,14 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._request = request
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._call = self._loop.create_task(self._invoke())
self._call = loop.create_task(self._invoke())
def cancel(self) -> bool:
if super().cancel():
@ -312,11 +265,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
try:
serialized_response = await self._cython_call.unary_unary(
serialized_request,
self._metadata,
self._set_initial_metadata,
self._set_status,
)
serialized_request, self._metadata)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
@ -360,13 +309,14 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._request = request
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._send_unary_request_task = self._loop.create_task(
self._send_unary_request_task = loop.create_task(
self._send_unary_request())
self._message_aiter = None
@ -382,8 +332,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self._request_serializer)
try:
await self._cython_call.initiate_unary_stream(
serialized_request, self._metadata, self._set_initial_metadata,
self._set_status)
serialized_request, self._metadata)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
@ -419,7 +368,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self._response_deserializer)
async def read(self) -> ResponseType:
if self._status.done():
if self._cython_call.done():
await self._raise_for_status()
return cygrpc.EOF
@ -452,16 +401,17 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._metadata_sent = asyncio.Event(loop=self._loop)
self._metadata_sent = asyncio.Event(loop=loop)
self._done_writing = False
self._call_finisher = self._loop.create_task(self._conduct_rpc())
self._call_finisher = loop.create_task(self._conduct_rpc())
# If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None:
@ -485,11 +435,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
async def _conduct_rpc(self) -> ResponseType:
try:
serialized_response = await self._cython_call.stream_unary(
self._metadata,
self._metadata_sent_observer,
self._set_initial_metadata,
self._set_status,
)
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
@ -517,7 +463,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
return response
async def write(self, request: RequestType) -> None:
if self._status.done():
if self._cython_call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
@ -536,7 +482,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
if self._status.done():
if self._cython_call.done():
# If the RPC is finished, do nothing.
return
if not self._done_writing:
@ -572,20 +518,21 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._metadata_sent = asyncio.Event(loop=self._loop)
self._metadata_sent = asyncio.Event(loop=loop)
self._done_writing = False
self._initializer = self._loop.create_task(self._prepare_rpc())
# If user passes in an async iterator, create a consumer coroutine.
if request_async_iterator is not None:
self._async_request_poller = self._loop.create_task(
self._async_request_poller = loop.create_task(
self._consume_request_iterator(request_async_iterator))
else:
self._async_request_poller = None
@ -611,11 +558,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
"""
try:
await self._cython_call.initiate_stream_stream(
self._metadata,
self._metadata_sent_observer,
self._set_initial_metadata,
self._set_status,
)
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
@ -629,7 +572,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
await self.done_writing()
async def write(self, request: RequestType) -> None:
if self._status.done():
if self._cython_call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
@ -648,7 +591,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
if self._status.done():
if self._cython_call.done():
# If the RPC is finished, do nothing.
return
if not self._done_writing:
@ -692,7 +635,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
self._response_deserializer)
async def read(self) -> ResponseType:
if self._status.done():
if self._cython_call.done():
await self._raise_for_status()
return cygrpc.EOF

@ -28,6 +28,8 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction)
from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple()
class _BaseMultiCallable:
"""Base class of all multi callable objects.
@ -47,12 +49,14 @@ class _BaseMultiCallable:
_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_loop: asyncio.AbstractEventLoop
def __init__(self, channel: cygrpc.AioChannel, method: bytes,
def __init__(self, channel: cygrpc.AioChannel,
method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]],
loop: asyncio.AbstractEventLoop,
) -> None:
self._loop = asyncio.get_event_loop()
self._loop = loop
self._channel = channel
self._method = method
self._request_serializer = request_serializer
@ -102,31 +106,20 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
raise NotImplementedError("TODO: compression not implemented yet")
if metadata is None:
metadata = tuple()
metadata = _IMMUTABLE_EMPTY_TUPLE
if not self._interceptors:
return UnaryUnaryCall(
request,
_timeout_to_deadline(timeout),
metadata,
credentials,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
return UnaryUnaryCall(request, _timeout_to_deadline(timeout),
metadata, credentials, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
else:
return InterceptedUnaryUnaryCall(
self._interceptors,
request,
timeout,
metadata,
credentials,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
return InterceptedUnaryUnaryCall(self._interceptors, request,
timeout, metadata, credentials,
self._channel, self._method,
self._request_serializer,
self._response_deserializer,
self._loop)
class UnaryStreamMultiCallable(_BaseMultiCallable):
@ -168,18 +161,12 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = tuple()
metadata = _IMMUTABLE_EMPTY_TUPLE
return UnaryStreamCall(
request,
deadline,
metadata,
credentials,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
return UnaryStreamCall(request, deadline, metadata, credentials,
self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
class StreamUnaryMultiCallable(_BaseMultiCallable):
@ -225,18 +212,12 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = tuple()
metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamUnaryCall(
request_async_iterator,
deadline,
metadata,
credentials,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
return StreamUnaryCall(request_async_iterator, deadline, metadata,
credentials, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
class StreamStreamMultiCallable(_BaseMultiCallable):
@ -282,18 +263,12 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = tuple()
metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamStreamCall(
request_async_iterator,
deadline,
metadata,
credentials,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
return StreamStreamCall(request_async_iterator, deadline, metadata,
credentials, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
class Channel:
@ -301,6 +276,7 @@ class Channel:
A cygrpc.AioChannel-backed implementation.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
@ -341,8 +317,9 @@ class Channel:
"UnaryUnaryClientInterceptors, the following are invalid: {}"\
.format(invalid_interceptors))
self._loop = asyncio.get_event_loop()
self._channel = cygrpc.AioChannel(_common.encode(target), options,
credentials)
credentials, self._loop)
def get_state(self,
try_to_connect: bool = False) -> grpc.ChannelConnectivity:
@ -408,10 +385,12 @@ class Channel:
Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method.
"""
return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
return UnaryUnaryMultiCallable(self._channel,
_common.encode(method),
request_serializer,
response_deserializer,
self._unary_unary_interceptors)
self._unary_unary_interceptors,
self._loop)
def unary_stream(
self,
@ -419,9 +398,11 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, _common.encode(method),
return UnaryStreamMultiCallable(self._channel,
_common.encode(method),
request_serializer,
response_deserializer, None)
response_deserializer,
None, self._loop)
def stream_unary(
self,
@ -429,9 +410,11 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(self._channel, _common.encode(method),
return StreamUnaryMultiCallable(self._channel,
_common.encode(method),
request_serializer,
response_deserializer, None)
response_deserializer,
None, self._loop)
def stream_stream(
self,
@ -439,9 +422,11 @@ class Channel:
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(self._channel, _common.encode(method),
return StreamStreamMultiCallable(self._channel,
_common.encode(method),
request_serializer,
response_deserializer, None)
response_deserializer,
None, self._loop)
async def _close(self):
# TODO: Send cancellation status

@ -110,12 +110,14 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._channel = channel
self._loop = asyncio.get_event_loop()
self._interceptors_task = asyncio.ensure_future(
self._invoke(interceptors, method, timeout, metadata, credentials,
request, request_serializer, response_deserializer))
self._loop = loop
self._interceptors_task = asyncio.ensure_future(self._invoke(
interceptors, method, timeout, metadata, credentials, request,
request_serializer, response_deserializer),
loop=loop)
def __del__(self):
self.cancel()
@ -154,7 +156,7 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
client_call_details.metadata,
client_call_details.credentials, self._channel,
client_call_details.method, request_serializer,
response_deserializer)
response_deserializer, self._loop)
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials)

@ -48,6 +48,16 @@ class _MulticallableTestMixin():
class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
async def test_call_to_string(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertTrue(str(call) is not None)
self.assertTrue(repr(call) is not None)
response = await call
self.assertTrue(str(call) is not None)
self.assertTrue(repr(call) is not None)
async def test_call_ok(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
@ -105,6 +115,65 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual((), await call.trailing_metadata())
async def test_call_initial_metadata_cancelable(self):
coro_started = asyncio.Event()
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro():
coro_started.set()
await call.initial_metadata()
task = self.loop.create_task(coro())
await coro_started.wait()
task.cancel()
# Test that initial metadata can still be asked thought
# a cancellation happened with the previous task
self.assertEqual((), await call.initial_metadata())
async def test_call_initial_metadata_multiple_waiters(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro():
return await call.initial_metadata()
task1 = self.loop.create_task(coro())
task2 = self.loop.create_task(coro())
await call
self.assertEqual([(), ()], await asyncio.gather(*[task1, task2]))
async def test_call_code_cancelable(self):
coro_started = asyncio.Event()
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro():
coro_started.set()
await call.code()
task = self.loop.create_task(coro())
await coro_started.wait()
task.cancel()
# Test that code can still be asked thought
# a cancellation happened with the previous task
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_call_code_multiple_waiters(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro():
return await call.code()
task1 = self.loop.create_task(coro())
task2 = self.loop.create_task(coro())
await call
self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await
asyncio.gather(task1, task2))
async def test_cancel_unary_unary(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

Loading…
Cancel
Save