From e04fcd29981186f83d8f9b3c38dd3d496482a2dd Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Wed, 22 Apr 2020 16:38:22 +0200 Subject: [PATCH 01/10] [issue-21953] Use the Metadata type In all places where a tuple was used for metadata (in the aio version), replace it by the new ``Metadata`` class. --- src/python/grpcio/grpc/_compression.py | 2 +- .../grpcio/grpc/experimental/aio/_call.py | 10 ++---- .../grpcio/grpc/experimental/aio/_channel.py | 35 ++++++++++++------- .../grpc/experimental/aio/_interceptor.py | 2 +- .../grpcio/grpc/experimental/aio/_typing.py | 9 ++--- .../grpcio_tests/tests_aio/interop/methods.py | 6 ++-- .../grpcio_tests/tests_aio/unit/_common.py | 13 +++---- .../tests_aio/unit/aio_rpc_error_test.py | 7 ++-- .../client_unary_unary_interceptor_test.py | 27 ++++++++------ .../tests_aio/unit/metadata_test.py | 27 ++++++-------- 10 files changed, 76 insertions(+), 62 deletions(-) diff --git a/src/python/grpcio/grpc/_compression.py b/src/python/grpcio/grpc/_compression.py index 45339c3afe2..4035844db97 100644 --- a/src/python/grpcio/grpc/_compression.py +++ b/src/python/grpcio/grpc/_compression.py @@ -39,7 +39,7 @@ def create_channel_option(compression): int(compression)),) if compression else () -def augment_metadata(metadata, compression): +def augment_metadata(metadata, compression) -> tuple: if not metadata and not compression: return None base_metadata = tuple(metadata) if metadata else () diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index a0693921461..35728433987 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -28,6 +28,7 @@ from . import _base_call from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, MetadatumType, RequestIterableType, RequestType, ResponseType, SerializingFunction) +from ._metadata import Metadata __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -58,11 +59,6 @@ class AioRpcError(grpc.RpcError): determined. Hence, its methods no longer needs to be coroutines. """ - # TODO(https://github.com/grpc/grpc/issues/20144) Metadata - # type returned by `initial_metadata` and `trailing_metadata` - # and also taken in the constructor needs to be revisit and make - # it more specific. - _code: grpc.StatusCode _details: Optional[str] _initial_metadata: Optional[MetadataType] @@ -88,8 +84,8 @@ class AioRpcError(grpc.RpcError): super().__init__(self) self._code = code self._details = details - self._initial_metadata = initial_metadata - self._trailing_metadata = trailing_metadata + self._initial_metadata = initial_metadata or Metadata() + self._trailing_metadata = trailing_metadata or Metadata() self._debug_error_string = debug_error_string def code(self) -> grpc.StatusCode: diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 7427872e0b3..3361af477b7 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -32,8 +32,8 @@ from ._interceptor import ( from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, SerializingFunction, RequestIterableType) from ._utils import _timeout_to_deadline +from ._metadata import Metadata -_IMMUTABLE_EMPTY_TUPLE = tuple() _USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) if sys.version_info[1] < 7: @@ -88,6 +88,19 @@ class _BaseMultiCallable: self._response_deserializer = response_deserializer self._interceptors = interceptors + @staticmethod + def _init_metadata(metadata: Optional[Metadata] = None, + compression: Optional[grpc.Compression] = None + ) -> Metadata: + """Based on the provided values for or initialise the final + metadata, as it should be used for the current call. + """ + metadata = metadata or Metadata() + if compression: + metadata = Metadata( + *_compression.augment_metadata(metadata, compression)) + return metadata + class UnaryUnaryMultiCallable(_BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable): @@ -96,14 +109,13 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable, request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.UnaryUnaryCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) if not self._interceptors: call = UnaryUnaryCall(request, _timeout_to_deadline(timeout), metadata, credentials, wait_for_ready, @@ -127,14 +139,13 @@ class UnaryStreamMultiCallable(_BaseMultiCallable, request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.UnaryStreamCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: @@ -158,14 +169,13 @@ class StreamUnaryMultiCallable(_BaseMultiCallable, def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.StreamUnaryCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: @@ -189,14 +199,13 @@ class StreamStreamMultiCallable(_BaseMultiCallable, def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.StreamStreamCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index e276ae0c922..8a28a61c8ba 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -248,7 +248,7 @@ class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): class InterceptedCall: - """Base implementation for all intecepted call arities. + """Base implementation for all intercepted call arities. Interceptors might have some work to do before the RPC invocation with the capacity of changing the invocation parameters, and some work to do diff --git a/src/python/grpcio/grpc/experimental/aio/_typing.py b/src/python/grpcio/grpc/experimental/aio/_typing.py index a02ec8ff803..7e2e8da8a06 100644 --- a/src/python/grpcio/grpc/experimental/aio/_typing.py +++ b/src/python/grpcio/grpc/experimental/aio/_typing.py @@ -13,17 +13,18 @@ # limitations under the License. """Common types for gRPC Async API""" -from typing import (Any, AnyStr, AsyncIterable, Callable, Iterable, Sequence, - Tuple, TypeVar, Union) +from typing import (Any, AsyncIterable, Callable, Iterable, Sequence, Tuple, + TypeVar, Union) from grpc._cython.cygrpc import EOF +from ._metadata import Metadata, MetadataKey, MetadataValue RequestType = TypeVar('RequestType') ResponseType = TypeVar('ResponseType') SerializingFunction = Callable[[Any], bytes] DeserializingFunction = Callable[[bytes], Any] -MetadatumType = Tuple[str, AnyStr] -MetadataType = Sequence[MetadatumType] +MetadatumType = Tuple[MetadataKey, MetadataValue] +MetadataType = Metadata ChannelArgumentType = Sequence[Tuple[str, Any]] EOFType = type(EOF) DoneCallbackType = Callable[[Any], None] diff --git a/src/python/grpcio_tests/tests_aio/interop/methods.py b/src/python/grpcio_tests/tests_aio/interop/methods.py index 019b3bca894..4ff10ff1572 100644 --- a/src/python/grpcio_tests/tests_aio/interop/methods.py +++ b/src/python/grpcio_tests/tests_aio/interop/methods.py @@ -287,8 +287,10 @@ async def _unimplemented_service(stub: test_pb2_grpc.UnimplementedServiceStub): async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub): initial_metadata_value = "test_initial_metadata_value" trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b" - metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value), - (_TRAILING_METADATA_KEY, trailing_metadata_value)) + metadata = aio.Metadata( + (_INITIAL_METADATA_KEY, initial_metadata_value), + (_TRAILING_METADATA_KEY, trailing_metadata_value), + ) async def _validate_metadata(call): initial_metadata = dict(await call.initial_metadata()) diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index dab9454c58d..7f9e2e89701 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -16,18 +16,19 @@ import asyncio import grpc from typing import AsyncIterable from grpc.experimental import aio -from grpc.experimental.aio._typing import MetadataType, MetadatumType +from grpc.experimental.aio._typing import MetadataType, MetadatumType, MetadataKey, MetadataValue from tests.unit.framework.common import test_constants def seen_metadata(expected: MetadataType, actual: MetadataType): - return not bool(set(expected) - set(actual)) + return not bool(set(tuple(expected)) - set(tuple(actual))) -def seen_metadatum(expected: MetadatumType, actual: MetadataType): - metadata_dict = dict(actual) - return metadata_dict.get(expected[0]) == expected[1] +def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue, + actual: MetadataType) -> bool: + obtained = actual[expected_key] + assert obtained == expected_value async def block_until_certain_state(channel: aio.Channel, @@ -50,7 +51,7 @@ def inject_callbacks(call: aio.Call): second_callback_ran = asyncio.Event() def second_callback(call): - # Validate that all resopnses have been received + # Validate that all responses have been received # and the call is an end state. assert call.done() second_callback_ran.set() diff --git a/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py index fbcd99b39e3..9bd652a43a6 100644 --- a/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py @@ -18,11 +18,14 @@ import unittest import grpc +from grpc.experimental import aio from grpc.experimental.aio._call import AioRpcError from tests_aio.unit._test_base import AioTestBase -_TEST_INITIAL_METADATA = ('initial metadata',) -_TEST_TRAILING_METADATA = ('trailing metadata',) +_TEST_INITIAL_METADATA = aio.Metadata( + ('initial metadata key', 'initial metadata value')) +_TEST_TRAILING_METADATA = aio.Metadata( + ('trailing metadata key', 'trailing metadata value')) _TEST_DEBUG_ERROR_STRING = '{This is a debug string}' diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py index 8f5a356ca4a..ae1ad54acd9 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py @@ -25,7 +25,7 @@ from tests_aio.unit._test_base import AioTestBase from src.proto.grpc.testing import messages_pb2, test_pb2_grpc _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' -_INITIAL_METADATA_TO_INJECT = ( +_INITIAL_METADATA_TO_INJECT = aio.Metadata( (_INITIAL_METADATA_KEY, 'extra info'), (_TRAILING_METADATA_KEY, b'\x13\x37'), ) @@ -162,7 +162,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): async def test_retry(self): class RetryInterceptor(aio.UnaryUnaryClientInterceptor): - """Simulates a Retry Interceptor which ends up by making + """Simulates a Retry Interceptor which ends up by making two RPC calls.""" def __init__(self): @@ -550,11 +550,12 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): async def intercept_unary_unary(self, continuation, client_call_details, request): + new_metadata = aio.Metadata(*client_call_details.metadata, + *_INITIAL_METADATA_TO_INJECT) new_details = aio.ClientCallDetails( method=client_call_details.method, timeout=client_call_details.timeout, - metadata=client_call_details.metadata + - _INITIAL_METADATA_TO_INJECT, + metadata=new_metadata, credentials=client_call_details.credentials, wait_for_ready=client_call_details.wait_for_ready, ) @@ -568,14 +569,20 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): # Expected to see the echoed initial metadata self.assertTrue( - _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[0], await - call.initial_metadata())) - + _common.seen_metadatum( + expected_key=_INITIAL_METADATA_KEY, + expected_value=_INITIAL_METADATA_TO_INJECT[ + _INITIAL_METADATA_KEY], + actual=await call.initial_metadata(), + )) # Expected to see the echoed trailing metadata self.assertTrue( - _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[1], await - call.trailing_metadata())) - + _common.seen_metadatum( + expected_key=_TRAILING_METADATA_KEY, + expected_value=_INITIAL_METADATA_TO_INJECT[ + _TRAILING_METADATA_KEY], + actual=await call.trailing_metadata(), + )) self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_add_done_callback_before_finishes(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 6551e4ca084..16f91430fce 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -37,38 +37,33 @@ _TEST_STREAM_STREAM = '/test/TestStreamStream' _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x01\x01\x01' -_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = ( +_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata( ('client-to-server', 'question'), ('client-to-server-bin', b'\x07\x07\x07'), ) -_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = ( +_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata( ('server-to-client', 'answer'), ('server-to-client-bin', b'\x06\x06\x06'), ) -_TRAILING_METADATA = (('a-trailing-metadata', 'stack-trace'), - ('a-trailing-metadata-bin', b'\x05\x05\x05')) -_INITIAL_METADATA_FOR_GENERIC_HANDLER = (('a-must-have-key', 'secret'),) +_TRAILING_METADATA = aio.Metadata( + ('a-trailing-metadata', 'stack-trace'), + ('a-trailing-metadata-bin', b'\x05\x05\x05'), +) +_INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata( + ('a-must-have-key', 'secret'),) _INVALID_METADATA_TEST_CASES = ( ( TypeError, - ((42, 42),), - ), - ( - TypeError, - (({}, {}),), - ), - ( - TypeError, - (('normal', object()),), + aio.Metadata((42, 42),), ), ( TypeError, - object(), + aio.Metadata(({}, {}),), ), ( TypeError, - (object(),), + aio.Metadata(('normal', object()),), ), ) From e9dadf46bfba716c7c88ce4f69cb10b7b1844518 Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Tue, 26 May 2020 17:00:13 +0200 Subject: [PATCH 02/10] [issue-24953] Fix tests, format, & types Fixes https://github.com/grpc/grpc/issues/21953 --- .../grpcio/grpc/experimental/aio/_call.py | 13 +++++---- .../grpcio/grpc/experimental/aio/_channel.py | 2 +- .../grpcio_tests/tests_aio/interop/methods.py | 5 ++-- .../grpcio_tests/tests_aio/unit/_common.py | 2 +- .../grpcio_tests/tests_aio/unit/call_test.py | 10 +++---- .../client_stream_unary_interceptor_test.py | 12 ++++----- .../client_unary_stream_interceptor_test.py | 4 +-- .../client_unary_unary_interceptor_test.py | 22 ++++++++------- .../tests_aio/unit/compatibility_test.py | 3 ++- .../tests_aio/unit/metadata_test.py | 27 ++++++++++--------- .../tests_aio/unit/server_interceptor_test.py | 4 +-- 11 files changed, 56 insertions(+), 48 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 35728433987..7ecc2dd8e1b 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -25,10 +25,10 @@ from grpc import _common from grpc._cython import cygrpc from . import _base_call +from ._metadata import Metadata from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, MetadatumType, RequestIterableType, RequestType, ResponseType, SerializingFunction) -from ._metadata import Metadata __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -84,8 +84,8 @@ class AioRpcError(grpc.RpcError): super().__init__(self) self._code = code self._details = details - self._initial_metadata = initial_metadata or Metadata() - self._trailing_metadata = trailing_metadata or Metadata() + self._initial_metadata = Metadata(*(initial_metadata or ())) + self._trailing_metadata = Metadata(*(trailing_metadata or ())) self._debug_error_string = debug_error_string def code(self) -> grpc.StatusCode: @@ -205,10 +205,13 @@ class Call: return self._cython_call.time_remaining() async def initial_metadata(self) -> MetadataType: - return await self._cython_call.initial_metadata() + raw_metadata_tuple = await self._cython_call.initial_metadata() + return Metadata(*(raw_metadata_tuple or ())) async def trailing_metadata(self) -> MetadataType: - return (await self._cython_call.status()).trailing_metadata() + raw_metadata_tuple = (await + self._cython_call.status()).trailing_metadata() + return Metadata(*(raw_metadata_tuple or ())) async def code(self) -> grpc.StatusCode: cygrpc_code = (await self._cython_call.status()).code() diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 3361af477b7..3ac12bf6139 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -29,10 +29,10 @@ from ._interceptor import ( InterceptedStreamUnaryCall, InterceptedStreamStreamCall, ClientInterceptor, UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, StreamUnaryClientInterceptor, StreamStreamClientInterceptor) +from ._metadata import Metadata from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, SerializingFunction, RequestIterableType) from ._utils import _timeout_to_deadline -from ._metadata import Metadata _USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) diff --git a/src/python/grpcio_tests/tests_aio/interop/methods.py b/src/python/grpcio_tests/tests_aio/interop/methods.py index 4ff10ff1572..706f4249be3 100644 --- a/src/python/grpcio_tests/tests_aio/interop/methods.py +++ b/src/python/grpcio_tests/tests_aio/interop/methods.py @@ -293,12 +293,13 @@ async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub): ) async def _validate_metadata(call): - initial_metadata = dict(await call.initial_metadata()) + initial_metadata = await call.initial_metadata() if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value: raise ValueError('expected initial metadata %s, got %s' % (initial_metadata_value, initial_metadata[_INITIAL_METADATA_KEY])) - trailing_metadata = dict(await call.trailing_metadata()) + + trailing_metadata = await call.trailing_metadata() if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value: raise ValueError('expected trailing metadata %s, got %s' % (trailing_metadata_value, diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index 7f9e2e89701..a4a9236069c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -28,7 +28,7 @@ def seen_metadata(expected: MetadataType, actual: MetadataType): def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue, actual: MetadataType) -> bool: obtained = actual[expected_key] - assert obtained == expected_value + return obtained == expected_value async def block_until_certain_state(channel: aio.Channel, diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 3ce2a2f7b52..94a36a4c070 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -102,11 +102,11 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): async def test_call_initial_metadata_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - self.assertEqual((), await call.initial_metadata()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) async def test_call_trailing_metadata_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - self.assertEqual((), await call.trailing_metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_initial_metadata_cancelable(self): coro_started = asyncio.Event() @@ -122,7 +122,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): # Test that initial metadata can still be asked thought # a cancellation happened with the previous task - self.assertEqual((), await call.initial_metadata()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) async def test_call_initial_metadata_multiple_waiters(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -134,8 +134,8 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): task2 = self.loop.create_task(coro()) await call - - self.assertEqual([(), ()], await asyncio.gather(*[task1, task2])) + expected = [aio.Metadata() for _ in range(2)] + self.assertEqual(await asyncio.gather(*[task1, task2]), expected) async def test_call_code_cancelable(self): coro_started = asyncio.Event() diff --git a/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py index 2d58cff0e41..b9a04af00dc 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py @@ -92,8 +92,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase): self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, response.aggregated_payload_size) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) @@ -131,8 +131,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase): self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, response.aggregated_payload_size) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) @@ -230,8 +230,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase): self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, response.aggregated_payload_size) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py index 6137538ffca..fd542fd16e9 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py @@ -96,8 +96,8 @@ class TestUnaryStreamClientInterceptor(AioTestBase): self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py index ae1ad54acd9..e64daec7df4 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py @@ -302,8 +302,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(type(response), messages_pb2.SimpleResponse) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.details(), '') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_ok_awaited(self): @@ -331,8 +331,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(type(response), messages_pb2.SimpleResponse) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.details(), '') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_rpc_error(self): @@ -364,8 +364,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) self.assertEqual(await call.details(), 'Deadline Exceeded') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_rpc_error_awaited(self): @@ -398,8 +398,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) self.assertEqual(await call.details(), 'Deadline Exceeded') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_cancel_before_rpc(self): @@ -541,8 +541,10 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) self.assertEqual(await call.details(), _LOCAL_CANCEL_DETAILS_EXPECTATION) - self.assertEqual(await call.initial_metadata(), tuple()) - self.assertEqual(await call.trailing_metadata(), None) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual( + await call.trailing_metadata(), aio.Metadata(), + "When the raw response is None, empty metadata is returned") async def test_initial_metadata_modification(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py b/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py index 1e6e3598b8b..0bb3a3acc89 100644 --- a/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py @@ -255,7 +255,8 @@ class TestCompatibility(AioTestBase): self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary) call = self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST) self.assertTrue( - _common.seen_metadata(metadata, await call.initial_metadata())) + _common.seen_metadata(aio.Metadata(*metadata), await + call.initial_metadata())) async def test_sync_unary_unary_abort(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 16f91430fce..59f5596c70b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -55,15 +55,15 @@ _INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata( _INVALID_METADATA_TEST_CASES = ( ( TypeError, - aio.Metadata((42, 42),), + ((42, 42),), ), ( TypeError, - aio.Metadata(({}, {}),), + ((None, {}),), ), ( TypeError, - aio.Metadata(('normal', object()),), + (('normal', object()),), ), ) @@ -100,13 +100,13 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): async def _test_server_to_client(request, context): assert _REQUEST == request await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) return _RESPONSE @staticmethod async def _test_trailing_metadata(request, context): assert _REQUEST == request - context.set_trailing_metadata(_TRAILING_METADATA) + context.set_trailing_metadata(tuple(_TRAILING_METADATA)) return _RESPONSE @staticmethod @@ -115,21 +115,21 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, context.invocation_metadata()) await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) yield _RESPONSE - context.set_trailing_metadata(_TRAILING_METADATA) + context.set_trailing_metadata(tuple(_TRAILING_METADATA)) @staticmethod async def _test_stream_unary(request_iterator, context): assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, context.invocation_metadata()) await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) async for request in request_iterator: assert _REQUEST == request - context.set_trailing_metadata(_TRAILING_METADATA) + context.set_trailing_metadata(tuple(_TRAILING_METADATA)) return _RESPONSE @staticmethod @@ -137,13 +137,13 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, context.invocation_metadata()) await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) async for request in request_iterator: assert _REQUEST == request yield _RESPONSE - context.set_trailing_metadata(_TRAILING_METADATA) + context.set_trailing_metadata(tuple(_TRAILING_METADATA)) def service(self, handler_call_details): return self._routing_table.get(handler_call_details.method) @@ -193,6 +193,7 @@ class TestMetadata(AioTestBase): async def test_from_server_to_client(self): multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT) call = multicallable(_REQUEST) + self.assertEqual(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await call.initial_metadata()) self.assertEqual(_RESPONSE, await call) @@ -207,8 +208,8 @@ class TestMetadata(AioTestBase): async def test_from_client_to_server_with_list(self): multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) - call = multicallable( - _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)) + call = multicallable(_REQUEST, + metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) self.assertEqual(_RESPONSE, await call) self.assertEqual(grpc.StatusCode.OK, await call.code()) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index f85e46c379a..d891ecdb771 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -198,7 +198,7 @@ class TestServerInterceptor(AioTestBase): request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) - metadata = (('key', 'value'),) + metadata = aio.Metadata(('key', 'value'),) call = multicallable(messages_pb2.SimpleRequest(), metadata=metadata) await call @@ -208,7 +208,7 @@ class TestServerInterceptor(AioTestBase): ], record) record.clear() - metadata = (('key', 'value'), ('secret', '42')) + metadata = aio.Metadata(('key', 'value'), ('secret', '42')) call = multicallable(messages_pb2.SimpleRequest(), metadata=metadata) await call From 8fcc77a3109080d41d9a62763a17cde18aaa90d1 Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Thu, 28 May 2020 10:41:38 +0200 Subject: [PATCH 03/10] [issue-21953] Improvements from review * Replace ``MetadataType`` by ``Metadata`` in all places * Fix annotations * Use the new ``Metadata.from_tuple`` to create Metadata objects --- src/python/grpcio/grpc/_compression.py | 2 +- .../grpcio/grpc/experimental/aio/_call.py | 42 +++++++++---------- .../grpcio/grpc/experimental/aio/_metadata.py | 10 ++++- .../tests_aio/unit/_metadata_test.py | 12 ++++++ .../grpcio_tests/tests_aio/unit/call_test.py | 8 ++-- .../tests_aio/unit/metadata_test.py | 4 ++ 6 files changed, 50 insertions(+), 28 deletions(-) diff --git a/src/python/grpcio/grpc/_compression.py b/src/python/grpcio/grpc/_compression.py index 4035844db97..45339c3afe2 100644 --- a/src/python/grpcio/grpc/_compression.py +++ b/src/python/grpcio/grpc/_compression.py @@ -39,7 +39,7 @@ def create_channel_option(compression): int(compression)),) if compression else () -def augment_metadata(metadata, compression) -> tuple: +def augment_metadata(metadata, compression): if not metadata and not compression: return None base_metadata = tuple(metadata) if metadata else () diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 7ecc2dd8e1b..c121bd6b76d 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -26,7 +26,7 @@ from grpc._cython import cygrpc from . import _base_call from ._metadata import Metadata -from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, +from ._typing import (DeserializingFunction, DoneCallbackType, MetadatumType, RequestIterableType, RequestType, ResponseType, SerializingFunction) @@ -61,15 +61,15 @@ class AioRpcError(grpc.RpcError): _code: grpc.StatusCode _details: Optional[str] - _initial_metadata: Optional[MetadataType] - _trailing_metadata: Optional[MetadataType] + _initial_metadata: Optional[Metadata] + _trailing_metadata: Optional[Metadata] _debug_error_string: Optional[str] def __init__(self, code: grpc.StatusCode, details: Optional[str] = None, - initial_metadata: Optional[MetadataType] = None, - trailing_metadata: Optional[MetadataType] = None, + initial_metadata: Optional[Metadata] = None, + trailing_metadata: Optional[Metadata] = None, debug_error_string: Optional[str] = None) -> None: """Constructor. @@ -84,8 +84,8 @@ class AioRpcError(grpc.RpcError): super().__init__(self) self._code = code self._details = details - self._initial_metadata = Metadata(*(initial_metadata or ())) - self._trailing_metadata = Metadata(*(trailing_metadata or ())) + self._initial_metadata = initial_metadata + self._trailing_metadata = trailing_metadata self._debug_error_string = debug_error_string def code(self) -> grpc.StatusCode: @@ -104,7 +104,7 @@ class AioRpcError(grpc.RpcError): """ return self._details - def initial_metadata(self) -> Optional[MetadataType]: + def initial_metadata(self) -> Metadata: """Accesses the initial metadata sent by the server. Returns: @@ -112,7 +112,7 @@ class AioRpcError(grpc.RpcError): """ return self._initial_metadata - def trailing_metadata(self) -> Optional[MetadataType]: + def trailing_metadata(self) -> Metadata: """Accesses the trailing metadata sent by the server. Returns: @@ -141,13 +141,13 @@ class AioRpcError(grpc.RpcError): return self._repr() -def _create_rpc_error(initial_metadata: Optional[MetadataType], +def _create_rpc_error(initial_metadata: Metadata, status: cygrpc.AioRpcStatus) -> AioRpcError: return AioRpcError( _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()], status.details(), - initial_metadata, - status.trailing_metadata(), + Metadata.from_tuple(initial_metadata), + Metadata.from_tuple(status.trailing_metadata()), status.debug_error_string(), ) @@ -164,7 +164,7 @@ class Call: _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction - def __init__(self, cython_call: cygrpc._AioCall, metadata: MetadataType, + def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, loop: asyncio.AbstractEventLoop) -> None: @@ -204,14 +204,14 @@ class Call: def time_remaining(self) -> Optional[float]: return self._cython_call.time_remaining() - async def initial_metadata(self) -> MetadataType: + async def initial_metadata(self) -> Metadata: raw_metadata_tuple = await self._cython_call.initial_metadata() - return Metadata(*(raw_metadata_tuple or ())) + return Metadata.from_tuple(raw_metadata_tuple) - async def trailing_metadata(self) -> MetadataType: + async def trailing_metadata(self) -> Metadata: raw_metadata_tuple = (await self._cython_call.status()).trailing_metadata() - return Metadata(*(raw_metadata_tuple or ())) + return Metadata.from_tuple(raw_metadata_tuple) async def code(self) -> grpc.StatusCode: cygrpc_code = (await self._cython_call.status()).code() @@ -474,7 +474,7 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): # pylint: disable=too-many-arguments def __init__(self, request: RequestType, deadline: Optional[float], - metadata: MetadataType, + metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -523,7 +523,7 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): # pylint: disable=too-many-arguments def __init__(self, request: RequestType, deadline: Optional[float], - metadata: MetadataType, + metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -563,7 +563,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, # pylint: disable=too-many-arguments def __init__(self, request_iterator: Optional[RequestIterableType], - deadline: Optional[float], metadata: MetadataType, + deadline: Optional[float], metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -601,7 +601,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, # pylint: disable=too-many-arguments def __init__(self, request_iterator: Optional[RequestIterableType], - deadline: Optional[float], metadata: MetadataType, + deadline: Optional[float], metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, diff --git a/src/python/grpcio/grpc/experimental/aio/_metadata.py b/src/python/grpcio/grpc/experimental/aio/_metadata.py index ff970106748..3230445d58a 100644 --- a/src/python/grpcio/grpc/experimental/aio/_metadata.py +++ b/src/python/grpcio/grpc/experimental/aio/_metadata.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the metadata abstraction for gRPC Asyncio Python.""" -from typing import List, Tuple, Iterator, Any, Text, Union +from typing import List, Tuple, Iterator, Any, Union from collections import abc, OrderedDict -MetadataKey = Text +MetadataKey = str MetadataValue = Union[str, bytes] @@ -37,6 +37,12 @@ class Metadata(abc.Mapping): for md_key, md_value in args: self.add(md_key, md_value) + @classmethod + def from_tuple(cls, raw_metadata: tuple): + if raw_metadata: + return cls(*raw_metadata) + return cls() + def add(self, key: MetadataKey, value: MetadataValue) -> None: self._metadata.setdefault(key, []) self._metadata[key].append(value) diff --git a/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py index dda58c5ed53..c0594cb06ab 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py @@ -119,6 +119,18 @@ class TestTypeMetadata(unittest.TestCase): with self.assertRaises(KeyError): del metadata["other key"] + def test_metadata_from_tuple(self): + scenarios = ( + (None, Metadata()), + (Metadata(), Metadata()), + (self._DEFAULT_DATA, Metadata(*self._DEFAULT_DATA)), + (self._MULTI_ENTRY_DATA, Metadata(*self._MULTI_ENTRY_DATA)), + (Metadata(*self._DEFAULT_DATA), Metadata(*self._DEFAULT_DATA)), + ) + for source, expected in scenarios: + with self.subTest(raw_metadata=source, expected=expected): + self.assertEqual(expected, Metadata.from_tuple(source)) + if __name__ == '__main__': logging.basicConfig() diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 94a36a4c070..1961226fa6d 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -102,11 +102,11 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): async def test_call_initial_metadata_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(aio.Metadata(), await call.initial_metadata()) async def test_call_trailing_metadata_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + self.assertEqual(aio.Metadata(), await call.trailing_metadata()) async def test_call_initial_metadata_cancelable(self): coro_started = asyncio.Event() @@ -122,7 +122,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): # Test that initial metadata can still be asked thought # a cancellation happened with the previous task - self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(aio.Metadata(), await call.initial_metadata()) async def test_call_initial_metadata_multiple_waiters(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -135,7 +135,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): await call expected = [aio.Metadata() for _ in range(2)] - self.assertEqual(await asyncio.gather(*[task1, task2]), expected) + self.assertEqual(expected, await asyncio.gather(*[task1, task2])) async def test_call_code_cancelable(self): coro_started = asyncio.Event() diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 59f5596c70b..822bd134521 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -57,6 +57,10 @@ _INVALID_METADATA_TEST_CASES = ( TypeError, ((42, 42),), ), + ( + TypeError, + (({}, {}),), + ), ( TypeError, ((None, {}),), From 5a5a5784462c3e3a369f0a75f5ce5e90ce2182bc Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Tue, 2 Jun 2020 11:44:14 +0200 Subject: [PATCH 04/10] Fix new metadata tests Using the new aio.Metadata() type instead of tuple. --- .../unit/client_stream_stream_interceptor_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py index 9ab54b39a6b..ce6a7bc04d6 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py @@ -98,8 +98,8 @@ class TestStreamStreamClientInterceptor(AioTestBase): self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) @@ -140,8 +140,8 @@ class TestStreamStreamClientInterceptor(AioTestBase): await call.done_writing() self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) @@ -183,8 +183,8 @@ class TestStreamStreamClientInterceptor(AioTestBase): await call.done_writing() self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) From a6bf093af8dcb13a802137d42e965b1b59152587 Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Tue, 2 Jun 2020 11:49:57 +0200 Subject: [PATCH 05/10] Use metadata types in the service context Replace the signature to allow methods to use the metadata object. Internally, they'll still wrap the data in a tuple, but the interface makes it clear that the ``aio.Metadata()`` object is supported. Remove the ``tuple()`` conversions done in the tests. --- .../grpc/_cython/_cygrpc/aio/server.pyx.pxi | 16 ++++++++-------- .../grpcio_tests/tests_aio/unit/metadata_test.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index f37769c0038..b842ec6f2ba 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -143,7 +143,7 @@ cdef class _ServicerContext: self._loop) self._rpc_state.metadata_sent = True - async def send_initial_metadata(self, tuple metadata): + async def send_initial_metadata(self, object metadata): self._rpc_state.raise_for_termination() if self._rpc_state.metadata_sent: @@ -151,7 +151,7 @@ cdef class _ServicerContext: else: await _send_initial_metadata( self._rpc_state, - _augment_metadata(metadata, self._rpc_state.compression_algorithm), + _augment_metadata(tuple(metadata), self._rpc_state.compression_algorithm), _EMPTY_FLAG, self._loop ) @@ -192,8 +192,8 @@ cdef class _ServicerContext: async def abort_with_status(self, object status): await self.abort(status.code, status.details, status.trailing_metadata) - def set_trailing_metadata(self, tuple metadata): - self._rpc_state.trailing_metadata = metadata + def set_trailing_metadata(self, object metadata): + self._rpc_state.trailing_metadata = tuple(metadata) def invocation_metadata(self): return self._rpc_state.invocation_metadata() @@ -233,13 +233,13 @@ cdef class _SyncServicerContext: # Abort should raise an AbortError future.exception() - def send_initial_metadata(self, tuple metadata): + def send_initial_metadata(self, object metadata): future = asyncio.run_coroutine_threadsafe( self._context.send_initial_metadata(metadata), self._loop) future.result() - def set_trailing_metadata(self, tuple metadata): + def set_trailing_metadata(self, object metadata): self._context.set_trailing_metadata(metadata) def invocation_metadata(self): @@ -303,7 +303,7 @@ async def _finish_handler_with_unary_response(RPCState rpc_state, object response_serializer, object loop): """Finishes server method handler with a single response. - + This function executes the application handler, and handles response sending, as well as errors. It is shared between unary-unary and stream-unary handlers. @@ -378,7 +378,7 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state, """ cdef object async_response_generator cdef object response_message - + if inspect.iscoroutinefunction(stream_handler): # Case 1: Coroutine async handler - using reader-writer API # The handler uses reader / writer API, returns None. diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 822bd134521..6fee0c62630 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -104,13 +104,13 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): async def _test_server_to_client(request, context): assert _REQUEST == request await context.send_initial_metadata( - tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) return _RESPONSE @staticmethod async def _test_trailing_metadata(request, context): assert _REQUEST == request - context.set_trailing_metadata(tuple(_TRAILING_METADATA)) + context.set_trailing_metadata(_TRAILING_METADATA) return _RESPONSE @staticmethod @@ -119,21 +119,21 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, context.invocation_metadata()) await context.send_initial_metadata( - tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) yield _RESPONSE - context.set_trailing_metadata(tuple(_TRAILING_METADATA)) + context.set_trailing_metadata(_TRAILING_METADATA) @staticmethod async def _test_stream_unary(request_iterator, context): assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, context.invocation_metadata()) await context.send_initial_metadata( - tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) async for request in request_iterator: assert _REQUEST == request - context.set_trailing_metadata(tuple(_TRAILING_METADATA)) + context.set_trailing_metadata(_TRAILING_METADATA) return _RESPONSE @staticmethod @@ -141,13 +141,13 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, context.invocation_metadata()) await context.send_initial_metadata( - tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) async for request in request_iterator: assert _REQUEST == request yield _RESPONSE - context.set_trailing_metadata(tuple(_TRAILING_METADATA)) + context.set_trailing_metadata(_TRAILING_METADATA) def service(self, handler_call_details): return self._routing_table.get(handler_call_details.method) From 7b3430ef3ed76f6a5cc15f45e350735009bb1370 Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Tue, 2 Jun 2020 11:53:22 +0200 Subject: [PATCH 06/10] Restore test that passes metadata in a list The old interface of accepting the metadata as a list, should be kept due to a backwards incompatibility with a client. The new ``aio.Metadata()`` type supports iteration, so creating a list from it, is possible. --- src/python/grpcio_tests/tests_aio/unit/metadata_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 6fee0c62630..9e4de909216 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -213,7 +213,7 @@ class TestMetadata(AioTestBase): async def test_from_client_to_server_with_list(self): multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) call = multicallable(_REQUEST, - metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) + metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)) self.assertEqual(_RESPONSE, await call) self.assertEqual(grpc.StatusCode.OK, await call.code()) From 6e83eb79f48d8abb09fb3cb1c853422e44b52b6b Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Tue, 2 Jun 2020 12:38:46 +0200 Subject: [PATCH 07/10] Apply formatting & fix typing --- src/python/grpcio/grpc/experimental/aio/_call.py | 6 +++--- src/python/grpcio_tests/tests_aio/unit/metadata_test.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index c121bd6b76d..bf5865a3af0 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -26,9 +26,9 @@ from grpc._cython import cygrpc from . import _base_call from ._metadata import Metadata -from ._typing import (DeserializingFunction, DoneCallbackType, - MetadatumType, RequestIterableType, RequestType, - ResponseType, SerializingFunction) +from ._typing import (DeserializingFunction, DoneCallbackType, MetadatumType, + RequestIterableType, RequestType, ResponseType, + SerializingFunction) __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 9e4de909216..0c8956537ce 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -212,8 +212,8 @@ class TestMetadata(AioTestBase): async def test_from_client_to_server_with_list(self): multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) - call = multicallable(_REQUEST, - metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)) + call = multicallable( + _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)) # pytype: disable=wrong-arg-types self.assertEqual(_RESPONSE, await call) self.assertEqual(grpc.StatusCode.OK, await call.code()) From 36f79adaf9cda4179a4612a07ecf92a95b3616e8 Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Mon, 8 Jun 2020 16:52:31 +0200 Subject: [PATCH 08/10] Remove references to the old MetadataType --- .../grpc/experimental/aio/_base_call.py | 7 ++-- .../grpc/experimental/aio/_base_channel.py | 13 ++++---- .../grpc/experimental/aio/_base_server.py | 11 ++++--- .../grpcio/grpc/experimental/aio/_channel.py | 10 +++--- .../grpc/experimental/aio/_interceptor.py | 33 ++++++++++--------- .../grpcio_tests/tests_aio/unit/_common.py | 7 ++-- 6 files changed, 42 insertions(+), 39 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_base_call.py b/src/python/grpcio/grpc/experimental/aio/_base_call.py index 214e208c005..c07b4ca6b14 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_call.py @@ -23,8 +23,9 @@ from typing import AsyncIterable, Awaitable, Generic, Optional, Union import grpc -from ._typing import (DoneCallbackType, EOFType, MetadataType, RequestType, +from ._typing import (DoneCallbackType, EOFType, RequestType, ResponseType) +from ._metadata import Metadata __all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -86,7 +87,7 @@ class Call(RpcContext, metaclass=ABCMeta): """The abstract base class of an RPC on the client-side.""" @abstractmethod - async def initial_metadata(self) -> MetadataType: + async def initial_metadata(self) -> Metadata: """Accesses the initial metadata sent by the server. Returns: @@ -94,7 +95,7 @@ class Call(RpcContext, metaclass=ABCMeta): """ @abstractmethod - async def trailing_metadata(self) -> MetadataType: + async def trailing_metadata(self) -> Metadata: """Accesses the trailing metadata sent by the server. Returns: diff --git a/src/python/grpcio/grpc/experimental/aio/_base_channel.py b/src/python/grpcio/grpc/experimental/aio/_base_channel.py index 33efa7789cb..4b4ea1355b4 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_channel.py @@ -19,10 +19,9 @@ from typing import Any, Optional import grpc from . import _base_call -from ._typing import (DeserializingFunction, MetadataType, RequestIterableType, +from ._typing import (DeserializingFunction, RequestIterableType, SerializingFunction) - -_IMMUTABLE_EMPTY_TUPLE = tuple() +from ._metadata import Metadata class UnaryUnaryMultiCallable(abc.ABC): @@ -33,7 +32,7 @@ class UnaryUnaryMultiCallable(abc.ABC): request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -71,7 +70,7 @@ class UnaryStreamMultiCallable(abc.ABC): request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -108,7 +107,7 @@ class StreamUnaryMultiCallable(abc.ABC): def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -146,7 +145,7 @@ class StreamStreamMultiCallable(abc.ABC): def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None diff --git a/src/python/grpcio/grpc/experimental/aio/_base_server.py b/src/python/grpcio/grpc/experimental/aio/_base_server.py index 72e4288c94f..842e9b15c9e 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_server.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_server.py @@ -18,7 +18,8 @@ from typing import Generic, Optional, Sequence import grpc -from ._typing import MetadataType, RequestType, ResponseType +from ._typing import RequestType, ResponseType +from ._metadata import Metadata class Server(abc.ABC): @@ -158,7 +159,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): @abc.abstractmethod async def send_initial_metadata(self, - initial_metadata: MetadataType) -> None: + initial_metadata: Metadata) -> None: """Sends the initial metadata value to the client. This method need not be called by implementations if they have no @@ -170,7 +171,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): @abc.abstractmethod async def abort(self, code: grpc.StatusCode, details: str, - trailing_metadata: MetadataType) -> None: + trailing_metadata: Metadata) -> None: """Raises an exception to terminate the RPC with a non-OK status. The code and details passed as arguments will supercede any existing @@ -191,7 +192,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): @abc.abstractmethod async def set_trailing_metadata(self, - trailing_metadata: MetadataType) -> None: + trailing_metadata: Metadata) -> None: """Sends the trailing metadata for the RPC. This method need not be called by implementations if they have no @@ -202,7 +203,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): """ @abc.abstractmethod - def invocation_metadata(self) -> Optional[MetadataType]: + def invocation_metadata(self) -> Optional[Metadata]: """Accesses the metadata from the sent by the client. Returns: diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 3ac12bf6139..1995db13bf5 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -30,7 +30,7 @@ from ._interceptor import ( UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, StreamUnaryClientInterceptor, StreamStreamClientInterceptor) from ._metadata import Metadata -from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, +from ._typing import (ChannelArgumentType, DeserializingFunction, SerializingFunction, RequestIterableType) from ._utils import _timeout_to_deadline @@ -109,7 +109,7 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable, request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -139,7 +139,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable, request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -169,7 +169,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable, def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -199,7 +199,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable, def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 8a28a61c8ba..c8f185afb56 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -27,8 +27,9 @@ from ._call import _RPC_ALREADY_FINISHED_DETAILS, _RPC_HALF_CLOSED_DETAILS from ._call import _API_STYLE_ERROR from ._utils import _timeout_to_deadline from ._typing import (RequestType, SerializingFunction, DeserializingFunction, - MetadataType, ResponseType, DoneCallbackType, + ResponseType, DoneCallbackType, RequestIterableType, ResponseIterableType) +from ._metadata import Metadata _LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' @@ -82,7 +83,7 @@ class ClientCallDetails( method: str timeout: Optional[float] - metadata: Optional[MetadataType] + metadata: Optional[Metadata] credentials: Optional[grpc.CallCredentials] wait_for_ready: Optional[bool] @@ -370,7 +371,7 @@ class InterceptedCall: def time_remaining(self) -> Optional[float]: raise NotImplementedError() - async def initial_metadata(self) -> Optional[MetadataType]: + async def initial_metadata(self) -> Optional[Metadata]: try: call = await self._interceptors_task except AioRpcError as err: @@ -380,7 +381,7 @@ class InterceptedCall: return await call.initial_metadata() - async def trailing_metadata(self) -> Optional[MetadataType]: + async def trailing_metadata(self) -> Optional[Metadata]: try: call = await self._interceptors_task except AioRpcError as err: @@ -556,7 +557,7 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, # pylint: disable=too-many-arguments def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor], request: RequestType, timeout: Optional[float], - metadata: MetadataType, + metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -573,7 +574,7 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, # pylint: disable=too-many-arguments async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor], method: bytes, timeout: Optional[float], - metadata: Optional[MetadataType], + metadata: Optional[Metadata], credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], request: RequestType, request_serializer: SerializingFunction, @@ -628,7 +629,7 @@ class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin, # pylint: disable=too-many-arguments def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor], request: RequestType, timeout: Optional[float], - metadata: MetadataType, + metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -647,7 +648,7 @@ class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin, # pylint: disable=too-many-arguments async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor], method: bytes, timeout: Optional[float], - metadata: Optional[MetadataType], + metadata: Optional[Metadata], credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], request: RequestType, request_serializer: SerializingFunction, @@ -712,7 +713,7 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, # pylint: disable=too-many-arguments def __init__(self, interceptors: Sequence[StreamUnaryClientInterceptor], request_iterator: Optional[RequestIterableType], - timeout: Optional[float], metadata: MetadataType, + timeout: Optional[float], metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -731,7 +732,7 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, async def _invoke( self, interceptors: Sequence[StreamUnaryClientInterceptor], method: bytes, timeout: Optional[float], - metadata: Optional[MetadataType], + metadata: Optional[Metadata], credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], request_iterator: RequestIterableType, @@ -783,7 +784,7 @@ class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin, # pylint: disable=too-many-arguments def __init__(self, interceptors: Sequence[StreamStreamClientInterceptor], request_iterator: Optional[RequestIterableType], - timeout: Optional[float], metadata: MetadataType, + timeout: Optional[float], metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -804,7 +805,7 @@ class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin, async def _invoke( self, interceptors: Sequence[StreamStreamClientInterceptor], method: bytes, timeout: Optional[float], - metadata: Optional[MetadataType], + metadata: Optional[Metadata], credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], request_iterator: RequestIterableType, @@ -876,10 +877,10 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): def time_remaining(self) -> Optional[float]: raise NotImplementedError() - async def initial_metadata(self) -> Optional[MetadataType]: + async def initial_metadata(self) -> Optional[Metadata]: return None - async def trailing_metadata(self) -> Optional[MetadataType]: + async def trailing_metadata(self) -> Optional[Metadata]: return None async def code(self) -> grpc.StatusCode: @@ -928,10 +929,10 @@ class _StreamCallResponseIterator: def time_remaining(self) -> Optional[float]: return self._call.time_remaining() - async def initial_metadata(self) -> Optional[MetadataType]: + async def initial_metadata(self) -> Optional[Metadata]: return await self._call.initial_metadata() - async def trailing_metadata(self) -> Optional[MetadataType]: + async def trailing_metadata(self) -> Optional[Metadata]: return await self._call.trailing_metadata() async def code(self) -> grpc.StatusCode: diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index a4a9236069c..7fdd120e31b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -16,17 +16,18 @@ import asyncio import grpc from typing import AsyncIterable from grpc.experimental import aio -from grpc.experimental.aio._typing import MetadataType, MetadatumType, MetadataKey, MetadataValue +from grpc.experimental.aio._typing import MetadatumType, MetadataKey, MetadataValue +from grpc.experimental.aio._metadata import Metadata from tests.unit.framework.common import test_constants -def seen_metadata(expected: MetadataType, actual: MetadataType): +def seen_metadata(expected: Metadata, actual: Metadata): return not bool(set(tuple(expected)) - set(tuple(actual))) def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue, - actual: MetadataType) -> bool: + actual: Metadata) -> bool: obtained = actual[expected_key] return obtained == expected_value From 18e0f9f53313a316dd44a0bb94ef3dff8922f1ee Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Mon, 8 Jun 2020 17:02:32 +0200 Subject: [PATCH 09/10] Remove metadata as optional from AioRpcError --- src/python/grpcio/grpc/experimental/aio/_call.py | 8 ++++---- .../grpcio_tests/tests_aio/unit/aio_rpc_error_test.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index bf5865a3af0..ba229f35c39 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -67,9 +67,9 @@ class AioRpcError(grpc.RpcError): def __init__(self, code: grpc.StatusCode, + initial_metadata: Metadata, + trailing_metadata: Metadata, details: Optional[str] = None, - initial_metadata: Optional[Metadata] = None, - trailing_metadata: Optional[Metadata] = None, debug_error_string: Optional[str] = None) -> None: """Constructor. @@ -145,10 +145,10 @@ def _create_rpc_error(initial_metadata: Metadata, status: cygrpc.AioRpcStatus) -> AioRpcError: return AioRpcError( _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()], - status.details(), Metadata.from_tuple(initial_metadata), Metadata.from_tuple(status.trailing_metadata()), - status.debug_error_string(), + details=status.details(), + debug_error_string=status.debug_error_string(), ) diff --git a/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py index 9bd652a43a6..416c51a7080 100644 --- a/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py @@ -33,9 +33,9 @@ class TestAioRpcError(unittest.TestCase): def test_attributes(self): aio_rpc_error = AioRpcError(grpc.StatusCode.CANCELLED, - 'details', initial_metadata=_TEST_INITIAL_METADATA, trailing_metadata=_TEST_TRAILING_METADATA, + details="details", debug_error_string=_TEST_DEBUG_ERROR_STRING) self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED) self.assertEqual(aio_rpc_error.details(), 'details') From 376c0f0767058bd42dc67b53c2444b20401abfd1 Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Thu, 11 Jun 2020 09:32:55 +0200 Subject: [PATCH 10/10] Add missing metadata TypeError case Code formatted. --- src/python/grpcio/grpc/experimental/aio/_base_call.py | 3 +-- src/python/grpcio/grpc/experimental/aio/_base_server.py | 6 ++---- src/python/grpcio/grpc/experimental/aio/_interceptor.py | 4 ++-- src/python/grpcio_tests/tests_aio/unit/metadata_test.py | 4 ++++ 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_base_call.py b/src/python/grpcio/grpc/experimental/aio/_base_call.py index c07b4ca6b14..4ccbb3be132 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_call.py @@ -23,8 +23,7 @@ from typing import AsyncIterable, Awaitable, Generic, Optional, Union import grpc -from ._typing import (DoneCallbackType, EOFType, RequestType, - ResponseType) +from ._typing import (DoneCallbackType, EOFType, RequestType, ResponseType) from ._metadata import Metadata __all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' diff --git a/src/python/grpcio/grpc/experimental/aio/_base_server.py b/src/python/grpcio/grpc/experimental/aio/_base_server.py index 842e9b15c9e..86c15fc86b0 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_server.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_server.py @@ -158,8 +158,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): """ @abc.abstractmethod - async def send_initial_metadata(self, - initial_metadata: Metadata) -> None: + async def send_initial_metadata(self, initial_metadata: Metadata) -> None: """Sends the initial metadata value to the client. This method need not be called by implementations if they have no @@ -191,8 +190,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): """ @abc.abstractmethod - async def set_trailing_metadata(self, - trailing_metadata: Metadata) -> None: + async def set_trailing_metadata(self, trailing_metadata: Metadata) -> None: """Sends the trailing metadata for the RPC. This method need not be called by implementations if they have no diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index c8f185afb56..80e9625c553 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -27,8 +27,8 @@ from ._call import _RPC_ALREADY_FINISHED_DETAILS, _RPC_HALF_CLOSED_DETAILS from ._call import _API_STYLE_ERROR from ._utils import _timeout_to_deadline from ._typing import (RequestType, SerializingFunction, DeserializingFunction, - ResponseType, DoneCallbackType, - RequestIterableType, ResponseIterableType) + ResponseType, DoneCallbackType, RequestIterableType, + ResponseIterableType) from ._metadata import Metadata _LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 0c8956537ce..c1fa97b3e4c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -65,6 +65,10 @@ _INVALID_METADATA_TEST_CASES = ( TypeError, ((None, {}),), ), + ( + TypeError, + (({}, {}),), + ), ( TypeError, (('normal', object()),),