diff --git a/src/python/grpcio/grpc/_cython/BUILD.bazel b/src/python/grpcio/grpc/_cython/BUILD.bazel index 3a355527a00..7212ceec4be 100644 --- a/src/python/grpcio/grpc/_cython/BUILD.bazel +++ b/src/python/grpcio/grpc/_cython/BUILD.bazel @@ -13,6 +13,8 @@ pyx_library( "_cygrpc/aio/rpc_error.pxd.pxi", "_cygrpc/aio/rpc_error.pyx.pxi", "_cygrpc/aio/callbackcontext.pxd.pxi", + "_cygrpc/aio/cancel_status.pxd.pxi", + "_cygrpc/aio/cancel_status.pyx.pxi", "_cygrpc/aio/channel.pxd.pxi", "_cygrpc/aio/channel.pyx.pxi", "_cygrpc/aio/grpc_aio.pxd.pxi", diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi index 1166551fd5c..687c5999d0b 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi @@ -20,8 +20,9 @@ cdef class _AioCall: grpc_completion_queue * _cq grpc_experimental_completion_queue_functor _functor object _waiter_call + list _references @staticmethod - cdef void functor_run(grpc_experimental_completion_queue_functor* functor, int succeed) + cdef void functor_run(grpc_experimental_completion_queue_functor* functor, int success) with gil @staticmethod - cdef void watcher_call_functor_run(grpc_experimental_completion_queue_functor* functor, int succeed) + cdef void watcher_call_functor_run(grpc_experimental_completion_queue_functor* functor, int success) with gil diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index d758c1eedcb..0ce716a9d6e 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -34,6 +34,7 @@ cdef class _AioCall: self._watcher_call.functor.functor_run = _AioCall.watcher_call_functor_run self._watcher_call.waiter = self self._waiter_call = None + self._references = [] def __dealloc__(self): grpc_completion_queue_shutdown(self._cq) @@ -45,21 +46,20 @@ cdef class _AioCall: return f"<{class_name} {id_}>" @staticmethod - cdef void functor_run(grpc_experimental_completion_queue_functor* functor, int succeed): + cdef void functor_run(grpc_experimental_completion_queue_functor* functor, int success) with gil: pass @staticmethod - cdef void watcher_call_functor_run(grpc_experimental_completion_queue_functor* functor, int succeed): + cdef void watcher_call_functor_run(grpc_experimental_completion_queue_functor* functor, int success) with gil: call = <_AioCall>(functor).waiter - assert call._waiter_call + if not call._waiter_call.done(): + if success == 0: + call._waiter_call.set_exception(Exception("Some error occurred")) + else: + call._waiter_call.set_result(None) - if succeed == 0: - call._waiter_call.set_exception(Exception("Some error occurred")) - else: - call._waiter_call.set_result(None) - - async def unary_unary(self, method, request, timeout): + async def unary_unary(self, bytes method, bytes request, object timeout, AioCancelStatus cancel_status): cdef grpc_call * call cdef grpc_slice method_slice cdef grpc_op * ops @@ -73,6 +73,7 @@ cdef class _AioCall: cdef grpc_call_error call_status cdef gpr_timespec deadline = _timespec_from_time(timeout) + cdef char *c_details = NULL method_slice = grpc_slice_from_copied_buffer( method, @@ -133,8 +134,21 @@ cdef class _AioCall: self._waiter_call = None raise Exception("Error with grpc_call_start_batch {}".format(call_status)) - await self._waiter_call - + try: + await self._waiter_call + except asyncio.CancelledError: + if cancel_status: + details = str_to_bytes(cancel_status.details()) + self._references.append(details) + c_details = details + call_status = grpc_call_cancel_with_status( + call, cancel_status.code(), c_details, NULL) + else: + call_status = grpc_call_cancel( + call, NULL) + if call_status != GRPC_CALL_OK: + raise Exception("RPC call couldn't be cancelled. Error {}".format(call_status)) + raise finally: initial_metadata_operation.un_c() send_message_operation.un_c() @@ -149,7 +163,7 @@ cdef class _AioCall: if receive_status_on_client_operation.code() == StatusCode.ok: return receive_message_operation.message() - raise grpc.experimental.aio.AioRpcError( + raise AioRpcError( receive_initial_metadata_operation.initial_metadata(), receive_status_on_client_operation.code(), receive_status_on_client_operation.details(), diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pxd.pxi new file mode 100644 index 00000000000..47670e5deb1 --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pxd.pxi @@ -0,0 +1,23 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Desired cancellation status for canceling an ongoing RPC calls.""" + + +cdef class AioCancelStatus: + cdef readonly: + object _code + str _details + + cpdef object code(self) + cpdef str details(self) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pyx.pxi new file mode 100644 index 00000000000..e2026458e3c --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/cancel_status.pyx.pxi @@ -0,0 +1,36 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Desired cancellation status for canceling an ongoing RPC call.""" + + +cdef class AioCancelStatus: + + def __cinit__(self): + self._code = None + self._details = None + + def __len__(self): + if self._code is None: + return 0 + return 1 + + def cancel(self, grpc_status_code code, str details=None): + self._code = code + self._details = details + + cpdef object code(self): + return self._code + + cpdef str details(self): + return self._details diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index cbcd4553864..526dade7f51 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -25,6 +25,6 @@ cdef class AioChannel: def close(self): grpc_channel_destroy(self.channel) - async def unary_unary(self, method, request, timeout): + async def unary_unary(self, method, request, timeout, cancel_status): call = _AioCall(self) - return await call.unary_unary(method, request, timeout) + return await call.unary_unary(method, request, timeout, cancel_status) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_error.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_error.pyx.pxi index 95b9144eff9..ca8a584d7a7 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_error.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_error.pyx.pxi @@ -14,7 +14,7 @@ """Exceptions for the aio version of the RPC calls.""" -cdef class _AioRpcError(Exception): +cdef class AioRpcError(Exception): def __cinit__(self, tuple initial_metadata, int code, str details, tuple trailing_metadata): self._initial_metadata = initial_metadata diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index 5bab542f467..dd6ff8b29d8 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -35,9 +35,9 @@ cdef class CallbackWrapper: @staticmethod cdef void functor_run( grpc_experimental_completion_queue_functor* functor, - int succeed): + int success): cdef CallbackContext *context = functor - if succeed == 0: + if success == 0: (context.waiter).set_exception(RuntimeError()) else: (context.waiter).set_result(None) diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pxd b/src/python/grpcio/grpc/_cython/cygrpc.pxd index b0a2033cc40..9ab15528a6b 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pxd +++ b/src/python/grpcio/grpc/_cython/cygrpc.pxd @@ -46,5 +46,6 @@ include "_cygrpc/aio/iomgr/resolver.pxd.pxi" include "_cygrpc/aio/grpc_aio.pxd.pxi" include "_cygrpc/aio/callbackcontext.pxd.pxi" include "_cygrpc/aio/call.pxd.pxi" +include "_cygrpc/aio/cancel_status.pxd.pxi" include "_cygrpc/aio/channel.pxd.pxi" include "_cygrpc/aio/server.pxd.pxi" diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx index 5f980bb46f0..7be32661c87 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pyx +++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx @@ -62,6 +62,7 @@ include "_cygrpc/aio/iomgr/timer.pyx.pxi" include "_cygrpc/aio/iomgr/resolver.pyx.pxi" include "_cygrpc/aio/grpc_aio.pyx.pxi" include "_cygrpc/aio/call.pyx.pxi" +include "_cygrpc/aio/cancel_status.pyx.pxi" include "_cygrpc/aio/channel.pyx.pxi" include "_cygrpc/aio/rpc_error.pyx.pxi" include "_cygrpc/aio/server.pyx.pxi" diff --git a/src/python/grpcio/grpc/experimental/BUILD.bazel b/src/python/grpcio/grpc/experimental/BUILD.bazel index c9f0484c886..5654d08a45b 100644 --- a/src/python/grpcio/grpc/experimental/BUILD.bazel +++ b/src/python/grpcio/grpc/experimental/BUILD.bazel @@ -4,6 +4,7 @@ py_library( name = "aio", srcs = [ "aio/__init__.py", + "aio/_call.py", "aio/_channel.py", "aio/_server.py", ], diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index 5e919f500bd..696db001133 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -14,16 +14,16 @@ """gRPC's Asynchronous Python API.""" import abc -import types import six import grpc -from grpc._cython import cygrpc from grpc._cython.cygrpc import init_grpc_aio -from ._server import server +from ._call import AioRpcError +from ._call import Call from ._channel import Channel from ._channel import UnaryUnaryMultiCallable +from ._server import server def insecure_channel(target, options=None, compression=None): @@ -39,39 +39,11 @@ def insecure_channel(target, options=None, compression=None): Returns: A Channel. """ - from grpc.experimental.aio import _channel # pylint: disable=cyclic-import - return _channel.Channel(target, () - if options is None else options, None, compression) - - -class _AioRpcError: - """Private implementation of AioRpcError""" - - -class AioRpcError: - """An RpcError to be used by the asynchronous API. - - Parent classes: (cygrpc._AioRpcError, RpcError) - """ - # Dynamically registered as subclass of _AioRpcError and RpcError, because the former one is - # only available after the cython code has been compiled. - _class_built = _AioRpcError - - def __new__(cls, *args, **kwargs): - if cls._class_built is _AioRpcError: - cls._class_built = types.new_class( - "AioRpcError", (cygrpc._AioRpcError, grpc.RpcError)) - cls._class_built.__doc__ = cls.__doc__ - - return cls._class_built(*args, **kwargs) + return Channel(target, () + if options is None else options, None, compression) ################################### __all__ ################################# -__all__ = ( - 'init_grpc_aio', - 'Channel', - 'UnaryUnaryMultiCallable', - 'insecure_channel', - 'AioRpcError', -) +__all__ = ('AioRpcError', 'Call', 'init_grpc_aio', 'Channel', + 'UnaryUnaryMultiCallable', 'insecure_channel', 'server') diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py new file mode 100644 index 00000000000..70ac3628971 --- /dev/null +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -0,0 +1,262 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Invocation-side implementation of gRPC Asyncio Python.""" +import asyncio +import enum +from typing import Callable, Dict, Optional, ClassVar + +import grpc +from grpc import _common +from grpc._cython import cygrpc + +DeserializingFunction = Callable[[bytes], str] + + +class AioRpcError(grpc.RpcError): + """An RpcError to be used by the asynchronous API.""" + + # TODO(https://github.com/grpc/grpc/issues/20144) Metadata + # type returned by `initial_metadata` and `trailing_metadata` + # and also taken in the constructor needs to be revisit and make + # it more specific. + + _code: grpc.StatusCode + _details: Optional[str] + _initial_metadata: Optional[Dict] + _trailing_metadata: Optional[Dict] + + def __init__(self, + code: grpc.StatusCode, + details: Optional[str] = None, + initial_metadata: Optional[Dict] = None, + trailing_metadata: Optional[Dict] = None): + """Constructor. + + Args: + code: The status code with which the RPC has been finalized. + details: Optional details explaining the reason of the error. + initial_metadata: Optional initial metadata that could be sent by the + Server. + trailing_metadata: Optional metadata that could be sent by the Server. + """ + + super().__init__(self) + self._code = code + self._details = details + self._initial_metadata = initial_metadata + self._trailing_metadata = trailing_metadata + + def code(self) -> grpc.StatusCode: + """ + Returns: + The `grpc.StatusCode` status code. + """ + return self._code + + def details(self) -> Optional[str]: + """ + Returns: + The description of the error. + """ + return self._details + + def initial_metadata(self) -> Optional[Dict]: + """ + Returns: + The inital metadata received. + """ + return self._initial_metadata + + def trailing_metadata(self) -> Optional[Dict]: + """ + Returns: + The trailing metadata received. + """ + return self._trailing_metadata + + +@enum.unique +class _RpcState(enum.Enum): + """Identifies the state of the RPC.""" + ONGOING = 1 + CANCELLED = 2 + FINISHED = 3 + ABORT = 4 + + +class Call: + """Object for managing RPC calls, + returned when an instance of `UnaryUnaryMultiCallable` object is called. + """ + + _cancellation_details: ClassVar[str] = 'Locally cancelled by application!' + + _state: _RpcState + _exception: Optional[Exception] + _response: Optional[bytes] + _code: grpc.StatusCode + _details: Optional[str] + _initial_metadata: Optional[Dict] + _trailing_metadata: Optional[Dict] + _call: asyncio.Task + _call_cancel_status: cygrpc.AioCancelStatus + _response_deserializer: DeserializingFunction + + def __init__(self, call: asyncio.Task, + response_deserializer: DeserializingFunction, + call_cancel_status: cygrpc.AioCancelStatus) -> None: + """Constructor. + + Args: + call: Asyncio Task that holds the RPC execution. + response_deserializer: Deserializer used for parsing the reponse. + call_cancel_status: A cygrpc.AioCancelStatus used for giving a + specific error when the RPC is canceled. + """ + + self._state = _RpcState.ONGOING + self._exception = None + self._response = None + self._code = grpc.StatusCode.UNKNOWN + self._details = None + self._initial_metadata = None + self._trailing_metadata = None + self._call = call + self._call_cancel_status = call_cancel_status + self._response_deserializer = response_deserializer + + def __del__(self): + self.cancel() + + def cancel(self) -> bool: + """Cancels the ongoing RPC request. + + Returns: + True if the RPC can be canceled, False if was already cancelled or terminated. + """ + if self.cancelled() or self.done(): + return False + + code = grpc.StatusCode.CANCELLED + self._call_cancel_status.cancel( + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], + details=Call._cancellation_details) + self._call.cancel() + self._details = Call._cancellation_details + self._code = code + self._state = _RpcState.CANCELLED + return True + + def cancelled(self) -> bool: + """Returns if the RPC was cancelled. + + Returns: + True if the requests was cancelled, False if not. + """ + return self._state is _RpcState.CANCELLED + + def running(self) -> bool: + """Returns if the RPC is running. + + Returns: + True if the requests is running, False if it already terminated. + """ + return not self.done() + + def done(self) -> bool: + """Returns if the RPC has finished. + + Returns: + True if the requests has finished, False is if still ongoing. + """ + return self._state is not _RpcState.ONGOING + + async def initial_metadata(self): + raise NotImplementedError() + + async def trailing_metadata(self): + raise NotImplementedError() + + async def code(self) -> grpc.StatusCode: + """Returns the `grpc.StatusCode` if the RPC is finished, + otherwise first waits until the RPC finishes. + + Returns: + The `grpc.StatusCode` status code. + """ + if not self.done(): + try: + await self + except (asyncio.CancelledError, AioRpcError): + pass + + return self._code + + async def details(self) -> str: + """Returns the details if the RPC is finished, otherwise first waits till the + RPC finishes. + + Returns: + The details. + """ + if not self.done(): + try: + await self + except (asyncio.CancelledError, AioRpcError): + pass + + return self._details + + def __await__(self): + """Wait till the ongoing RPC request finishes. + + Returns: + Response of the RPC call. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + # We can not relay on the `done()` method since some exceptions + # might be pending to be catched, like `asyncio.CancelledError`. + if self._response: + return self._response + elif self._exception: + raise self._exception + + try: + buffer_ = yield from self._call.__await__() + except cygrpc.AioRpcError as aio_rpc_error: + self._state = _RpcState.ABORT + self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[ + aio_rpc_error.code()] + self._details = aio_rpc_error.details() + self._initial_metadata = aio_rpc_error.initial_metadata() + self._trailing_metadata = aio_rpc_error.trailing_metadata() + + # Propagates the pure Python class + self._exception = AioRpcError(self._code, self._details, + self._initial_metadata, + self._trailing_metadata) + raise self._exception from aio_rpc_error + except asyncio.CancelledError as cancel_error: + # _state, _code, _details are managed in the `cancel` method + self._exception = cancel_error + raise + + self._response = _common.deserialize(buffer_, + self._response_deserializer) + self._code = grpc.StatusCode.OK + self._state = _RpcState.FINISHED + return self._response diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 9bef7cbeaa8..389b952b0b2 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -18,6 +18,8 @@ from typing import Callable, Optional from grpc import _common from grpc._cython import cygrpc +from ._call import Call + SerializingFunction = Callable[[str], bytes] DeserializingFunction = Callable[[bytes], str] @@ -39,14 +41,14 @@ class UnaryUnaryMultiCallable: return None return self._loop.time() + timeout - async def __call__(self, - request, - *, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def __call__(self, + request, + *, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None) -> Call: """Asynchronously invokes the underlying RPC. Args: @@ -63,7 +65,7 @@ class UnaryUnaryMultiCallable: grpc.compression.Gzip. This is an EXPERIMENTAL option. Returns: - The response value for the RPC. + A Call object instance which is an awaitable object. Raises: RpcError: Indicating that the RPC terminated with non-OK status. The @@ -87,9 +89,12 @@ class UnaryUnaryMultiCallable: serialized_request = _common.serialize(request, self._request_serializer) timeout = self._timeout_to_deadline(timeout) - response = await self._channel.unary_unary(self._method, - serialized_request, timeout) - return _common.deserialize(response, self._response_deserializer) + aio_cancel_status = cygrpc.AioCancelStatus() + aio_call = asyncio.ensure_future( + self._channel.unary_unary(self._method, serialized_request, timeout, + aio_cancel_status), + loop=self._loop) + return Call(aio_call, self._response_deserializer, aio_cancel_status) class Channel: diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 0fab86e49bc..8d51c9aaf8d 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -1,7 +1,8 @@ [ "_sanity._sanity_test.AioSanityTest", + "unit.call_test.TestAioRpcError", + "unit.call_test.TestCall", "unit.channel_test.TestChannel", - "unit.init_test.TestAioRpcError", "unit.init_test.TestInsecureChannel", "unit.server_test.TestServer" ] diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 5f3661f42cf..4b6ceebc816 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse - -from concurrent import futures from time import sleep -import grpc from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import test_pb2_grpc diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py new file mode 100644 index 00000000000..b1f470a3756 --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -0,0 +1,196 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import unittest + +import grpc + +from grpc.experimental import aio +from src.proto.grpc.testing import messages_pb2 +from tests.unit.framework.common import test_constants +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_base import AioTestBase + + +class TestAioRpcError(unittest.TestCase): + _TEST_INITIAL_METADATA = ("initial metadata",) + _TEST_TRAILING_METADATA = ("trailing metadata",) + + def test_attributes(self): + aio_rpc_error = aio.AioRpcError( + grpc.StatusCode.CANCELLED, + "details", + initial_metadata=self._TEST_INITIAL_METADATA, + trailing_metadata=self._TEST_TRAILING_METADATA) + self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(aio_rpc_error.details(), "details") + self.assertEqual(aio_rpc_error.initial_metadata(), + self._TEST_INITIAL_METADATA) + self.assertEqual(aio_rpc_error.trailing_metadata(), + self._TEST_TRAILING_METADATA) + + +class TestCall(AioTestBase): + + def test_call_ok(self): + + async def coro(): + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel(server_target) as channel: + hi = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest. + SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString + ) + call = hi(messages_pb2.SimpleRequest()) + + self.assertFalse(call.done()) + + response = await call + + self.assertTrue(call.done()) + self.assertEqual(type(response), messages_pb2.SimpleResponse) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + # Response is cached at call object level, reentrance + # returns again the same response + response_retry = await call + self.assertIs(response, response_retry) + + self.loop.run_until_complete(coro()) + + def test_call_rpc_error(self): + + async def coro(): + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel(server_target) as channel: + empty_call_with_sleep = channel.unary_unary( + "/grpc.testing.TestService/EmptyCall", + request_serializer=messages_pb2.SimpleRequest. + SerializeToString, + response_deserializer=messages_pb2.SimpleResponse. + FromString, + ) + timeout = test_constants.SHORT_TIMEOUT / 2 + # TODO(https://github.com/grpc/grpc/issues/20869 + # Update once the async server is ready, change the + # synchronization mechanism by removing the sleep() + # as both components (client & server) will be on the same + # process. + call = empty_call_with_sleep( + messages_pb2.SimpleRequest(), timeout=timeout) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call + + self.assertTrue(call.done()) + self.assertEqual(await call.code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + + # Exception is cached at call object level, reentrance + # returns again the same exception + with self.assertRaises( + grpc.RpcError) as exception_context_retry: + await call + + self.assertIs(exception_context.exception, + exception_context_retry.exception) + + self.loop.run_until_complete(coro()) + + def test_call_code_awaitable(self): + + async def coro(): + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel(server_target) as channel: + hi = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest. + SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString + ) + call = hi(messages_pb2.SimpleRequest()) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + self.loop.run_until_complete(coro()) + + def test_call_details_awaitable(self): + + async def coro(): + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel(server_target) as channel: + hi = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest. + SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString + ) + call = hi(messages_pb2.SimpleRequest()) + self.assertEqual(await call.details(), None) + + self.loop.run_until_complete(coro()) + + def test_cancel(self): + + async def coro(): + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel(server_target) as channel: + hi = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest. + SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString + ) + call = hi(messages_pb2.SimpleRequest()) + + self.assertFalse(call.cancelled()) + + # Force the loop to execute the RPC task, cython + # code is executed. + await asyncio.sleep(0) + + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + self.assertFalse(call.cancel()) + + with self.assertRaises( + asyncio.CancelledError) as exception_context: + await call + + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + 'Locally cancelled by application!') + + # Exception is cached at call object level, reentrance + # returns again the same exception + with self.assertRaises( + asyncio.CancelledError) as exception_context_retry: + await call + + self.assertIs(exception_context.exception, + exception_context_retry.exception) + + self.loop.run_until_complete(coro()) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index e18b6da6d39..96817c61a6f 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import asyncio import logging import unittest @@ -30,7 +28,7 @@ class TestChannel(AioTestBase): def test_async_context(self): async def coro(): - server_target, unused_server = await start_test_server() + server_target, _ = await start_test_server() # pylint: disable=unused-variable async with aio.insecure_channel(server_target) as channel: hi = channel.unary_unary( @@ -46,7 +44,7 @@ class TestChannel(AioTestBase): def test_unary_unary(self): async def coro(): - server_target, unused_server = await start_test_server() + server_target, _ = await start_test_server() # pylint: disable=unused-variable channel = aio.insecure_channel(server_target) hi = channel.unary_unary( @@ -55,7 +53,7 @@ class TestChannel(AioTestBase): response_deserializer=messages_pb2.SimpleResponse.FromString) response = await hi(messages_pb2.SimpleRequest()) - self.assertEqual(type(response), messages_pb2.SimpleResponse) + self.assertIs(type(response), messages_pb2.SimpleResponse) await channel.close() @@ -64,7 +62,7 @@ class TestChannel(AioTestBase): def test_unary_call_times_out(self): async def coro(): - server_target, unused_server = await start_test_server() + server_target, _ = await start_test_server() # pylint: disable=unused-variable async with aio.insecure_channel(server_target) as channel: empty_call_with_sleep = channel.unary_unary( @@ -75,15 +73,18 @@ class TestChannel(AioTestBase): FromString, ) timeout = test_constants.SHORT_TIMEOUT / 2 - # TODO: Update once the async server is ready, change the synchronization mechanism by removing the - # sleep() as both components (client & server) will be on the same process. + # TODO(https://github.com/grpc/grpc/issues/20869) + # Update once the async server is ready, change the + # synchronization mechanism by removing the sleep() + # as both components (client & server) will be on the same + # process. with self.assertRaises(grpc.RpcError) as exception_context: await empty_call_with_sleep( messages_pb2.SimpleRequest(), timeout=timeout) - status_code, details = grpc.StatusCode.DEADLINE_EXCEEDED.value + _, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable self.assertEqual(exception_context.exception.code(), - status_code) + grpc.StatusCode.DEADLINE_EXCEEDED) self.assertEqual(exception_context.exception.details(), details.title()) self.assertIsNotNone( diff --git a/src/python/grpcio_tests/tests_aio/unit/init_test.py b/src/python/grpcio_tests/tests_aio/unit/init_test.py index 297f178ee44..9f5d8bb0d85 100644 --- a/src/python/grpcio_tests/tests_aio/unit/init_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/init_test.py @@ -11,62 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import asyncio import logging import unittest -import grpc from grpc.experimental import aio from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase -class TestAioRpcError(unittest.TestCase): - _TEST_INITIAL_METADATA = ("initial metadata",) - _TEST_TRAILING_METADATA = ("trailing metadata",) - - def test_attributes(self): - aio_rpc_error = aio.AioRpcError(self._TEST_INITIAL_METADATA, 0, - "details", self._TEST_TRAILING_METADATA) - self.assertEqual(aio_rpc_error.initial_metadata(), - self._TEST_INITIAL_METADATA) - self.assertEqual(aio_rpc_error.code(), 0) - self.assertEqual(aio_rpc_error.details(), "details") - self.assertEqual(aio_rpc_error.trailing_metadata(), - self._TEST_TRAILING_METADATA) - - def test_class_hierarchy(self): - aio_rpc_error = aio.AioRpcError(self._TEST_INITIAL_METADATA, 0, - "details", self._TEST_TRAILING_METADATA) - - self.assertIsInstance(aio_rpc_error, grpc.RpcError) - - def test_class_attributes(self): - aio_rpc_error = aio.AioRpcError(self._TEST_INITIAL_METADATA, 0, - "details", self._TEST_TRAILING_METADATA) - self.assertEqual(aio_rpc_error.__class__.__name__, "AioRpcError") - self.assertEqual(aio_rpc_error.__class__.__doc__, - aio.AioRpcError.__doc__) - - def test_class_singleton(self): - first_aio_rpc_error = aio.AioRpcError(self._TEST_INITIAL_METADATA, 0, - "details", - self._TEST_TRAILING_METADATA) - second_aio_rpc_error = aio.AioRpcError(self._TEST_INITIAL_METADATA, 0, - "details", - self._TEST_TRAILING_METADATA) - - self.assertIs(first_aio_rpc_error.__class__, - second_aio_rpc_error.__class__) - - class TestInsecureChannel(AioTestBase): def test_insecure_channel(self): async def coro(): - server_target, unused_server = await start_test_server() + server_target, _ = await start_test_server() # pylint: disable=unused-variable channel = aio.insecure_channel(server_target) self.assertIsInstance(channel, aio.Channel) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py index 15f4ff182d6..937cce9eebb 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import logging import unittest import grpc from grpc.experimental import aio -from src.proto.grpc.testing import messages_pb2 -from src.proto.grpc.testing import benchmark_service_pb2_grpc from tests_aio.unit._test_base import AioTestBase _TEST_METHOD_PATH = ''