From e04fcd29981186f83d8f9b3c38dd3d496482a2dd Mon Sep 17 00:00:00 2001 From: Mariano Anaya Date: Wed, 22 Apr 2020 16:38:22 +0200 Subject: [PATCH] [issue-21953] Use the Metadata type In all places where a tuple was used for metadata (in the aio version), replace it by the new ``Metadata`` class. --- src/python/grpcio/grpc/_compression.py | 2 +- .../grpcio/grpc/experimental/aio/_call.py | 10 ++---- .../grpcio/grpc/experimental/aio/_channel.py | 35 ++++++++++++------- .../grpc/experimental/aio/_interceptor.py | 2 +- .../grpcio/grpc/experimental/aio/_typing.py | 9 ++--- .../grpcio_tests/tests_aio/interop/methods.py | 6 ++-- .../grpcio_tests/tests_aio/unit/_common.py | 13 +++---- .../tests_aio/unit/aio_rpc_error_test.py | 7 ++-- .../client_unary_unary_interceptor_test.py | 27 ++++++++------ .../tests_aio/unit/metadata_test.py | 27 ++++++-------- 10 files changed, 76 insertions(+), 62 deletions(-) diff --git a/src/python/grpcio/grpc/_compression.py b/src/python/grpcio/grpc/_compression.py index 45339c3afe2..4035844db97 100644 --- a/src/python/grpcio/grpc/_compression.py +++ b/src/python/grpcio/grpc/_compression.py @@ -39,7 +39,7 @@ def create_channel_option(compression): int(compression)),) if compression else () -def augment_metadata(metadata, compression): +def augment_metadata(metadata, compression) -> tuple: if not metadata and not compression: return None base_metadata = tuple(metadata) if metadata else () diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index a0693921461..35728433987 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -28,6 +28,7 @@ from . import _base_call from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, MetadatumType, RequestIterableType, RequestType, ResponseType, SerializingFunction) +from ._metadata import Metadata __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -58,11 +59,6 @@ class AioRpcError(grpc.RpcError): determined. Hence, its methods no longer needs to be coroutines. """ - # TODO(https://github.com/grpc/grpc/issues/20144) Metadata - # type returned by `initial_metadata` and `trailing_metadata` - # and also taken in the constructor needs to be revisit and make - # it more specific. - _code: grpc.StatusCode _details: Optional[str] _initial_metadata: Optional[MetadataType] @@ -88,8 +84,8 @@ class AioRpcError(grpc.RpcError): super().__init__(self) self._code = code self._details = details - self._initial_metadata = initial_metadata - self._trailing_metadata = trailing_metadata + self._initial_metadata = initial_metadata or Metadata() + self._trailing_metadata = trailing_metadata or Metadata() self._debug_error_string = debug_error_string def code(self) -> grpc.StatusCode: diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 7427872e0b3..3361af477b7 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -32,8 +32,8 @@ from ._interceptor import ( from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, SerializingFunction, RequestIterableType) from ._utils import _timeout_to_deadline +from ._metadata import Metadata -_IMMUTABLE_EMPTY_TUPLE = tuple() _USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) if sys.version_info[1] < 7: @@ -88,6 +88,19 @@ class _BaseMultiCallable: self._response_deserializer = response_deserializer self._interceptors = interceptors + @staticmethod + def _init_metadata(metadata: Optional[Metadata] = None, + compression: Optional[grpc.Compression] = None + ) -> Metadata: + """Based on the provided values for or initialise the final + metadata, as it should be used for the current call. + """ + metadata = metadata or Metadata() + if compression: + metadata = Metadata( + *_compression.augment_metadata(metadata, compression)) + return metadata + class UnaryUnaryMultiCallable(_BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable): @@ -96,14 +109,13 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable, request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.UnaryUnaryCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) if not self._interceptors: call = UnaryUnaryCall(request, _timeout_to_deadline(timeout), metadata, credentials, wait_for_ready, @@ -127,14 +139,13 @@ class UnaryStreamMultiCallable(_BaseMultiCallable, request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.UnaryStreamCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: @@ -158,14 +169,13 @@ class StreamUnaryMultiCallable(_BaseMultiCallable, def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.StreamUnaryCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: @@ -189,14 +199,13 @@ class StreamStreamMultiCallable(_BaseMultiCallable, def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.StreamStreamCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index e276ae0c922..8a28a61c8ba 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -248,7 +248,7 @@ class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): class InterceptedCall: - """Base implementation for all intecepted call arities. + """Base implementation for all intercepted call arities. Interceptors might have some work to do before the RPC invocation with the capacity of changing the invocation parameters, and some work to do diff --git a/src/python/grpcio/grpc/experimental/aio/_typing.py b/src/python/grpcio/grpc/experimental/aio/_typing.py index a02ec8ff803..7e2e8da8a06 100644 --- a/src/python/grpcio/grpc/experimental/aio/_typing.py +++ b/src/python/grpcio/grpc/experimental/aio/_typing.py @@ -13,17 +13,18 @@ # limitations under the License. """Common types for gRPC Async API""" -from typing import (Any, AnyStr, AsyncIterable, Callable, Iterable, Sequence, - Tuple, TypeVar, Union) +from typing import (Any, AsyncIterable, Callable, Iterable, Sequence, Tuple, + TypeVar, Union) from grpc._cython.cygrpc import EOF +from ._metadata import Metadata, MetadataKey, MetadataValue RequestType = TypeVar('RequestType') ResponseType = TypeVar('ResponseType') SerializingFunction = Callable[[Any], bytes] DeserializingFunction = Callable[[bytes], Any] -MetadatumType = Tuple[str, AnyStr] -MetadataType = Sequence[MetadatumType] +MetadatumType = Tuple[MetadataKey, MetadataValue] +MetadataType = Metadata ChannelArgumentType = Sequence[Tuple[str, Any]] EOFType = type(EOF) DoneCallbackType = Callable[[Any], None] diff --git a/src/python/grpcio_tests/tests_aio/interop/methods.py b/src/python/grpcio_tests/tests_aio/interop/methods.py index 019b3bca894..4ff10ff1572 100644 --- a/src/python/grpcio_tests/tests_aio/interop/methods.py +++ b/src/python/grpcio_tests/tests_aio/interop/methods.py @@ -287,8 +287,10 @@ async def _unimplemented_service(stub: test_pb2_grpc.UnimplementedServiceStub): async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub): initial_metadata_value = "test_initial_metadata_value" trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b" - metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value), - (_TRAILING_METADATA_KEY, trailing_metadata_value)) + metadata = aio.Metadata( + (_INITIAL_METADATA_KEY, initial_metadata_value), + (_TRAILING_METADATA_KEY, trailing_metadata_value), + ) async def _validate_metadata(call): initial_metadata = dict(await call.initial_metadata()) diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index dab9454c58d..7f9e2e89701 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -16,18 +16,19 @@ import asyncio import grpc from typing import AsyncIterable from grpc.experimental import aio -from grpc.experimental.aio._typing import MetadataType, MetadatumType +from grpc.experimental.aio._typing import MetadataType, MetadatumType, MetadataKey, MetadataValue from tests.unit.framework.common import test_constants def seen_metadata(expected: MetadataType, actual: MetadataType): - return not bool(set(expected) - set(actual)) + return not bool(set(tuple(expected)) - set(tuple(actual))) -def seen_metadatum(expected: MetadatumType, actual: MetadataType): - metadata_dict = dict(actual) - return metadata_dict.get(expected[0]) == expected[1] +def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue, + actual: MetadataType) -> bool: + obtained = actual[expected_key] + assert obtained == expected_value async def block_until_certain_state(channel: aio.Channel, @@ -50,7 +51,7 @@ def inject_callbacks(call: aio.Call): second_callback_ran = asyncio.Event() def second_callback(call): - # Validate that all resopnses have been received + # Validate that all responses have been received # and the call is an end state. assert call.done() second_callback_ran.set() diff --git a/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py index fbcd99b39e3..9bd652a43a6 100644 --- a/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py @@ -18,11 +18,14 @@ import unittest import grpc +from grpc.experimental import aio from grpc.experimental.aio._call import AioRpcError from tests_aio.unit._test_base import AioTestBase -_TEST_INITIAL_METADATA = ('initial metadata',) -_TEST_TRAILING_METADATA = ('trailing metadata',) +_TEST_INITIAL_METADATA = aio.Metadata( + ('initial metadata key', 'initial metadata value')) +_TEST_TRAILING_METADATA = aio.Metadata( + ('trailing metadata key', 'trailing metadata value')) _TEST_DEBUG_ERROR_STRING = '{This is a debug string}' diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py index 8f5a356ca4a..ae1ad54acd9 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py @@ -25,7 +25,7 @@ from tests_aio.unit._test_base import AioTestBase from src.proto.grpc.testing import messages_pb2, test_pb2_grpc _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' -_INITIAL_METADATA_TO_INJECT = ( +_INITIAL_METADATA_TO_INJECT = aio.Metadata( (_INITIAL_METADATA_KEY, 'extra info'), (_TRAILING_METADATA_KEY, b'\x13\x37'), ) @@ -162,7 +162,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): async def test_retry(self): class RetryInterceptor(aio.UnaryUnaryClientInterceptor): - """Simulates a Retry Interceptor which ends up by making + """Simulates a Retry Interceptor which ends up by making two RPC calls.""" def __init__(self): @@ -550,11 +550,12 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): async def intercept_unary_unary(self, continuation, client_call_details, request): + new_metadata = aio.Metadata(*client_call_details.metadata, + *_INITIAL_METADATA_TO_INJECT) new_details = aio.ClientCallDetails( method=client_call_details.method, timeout=client_call_details.timeout, - metadata=client_call_details.metadata + - _INITIAL_METADATA_TO_INJECT, + metadata=new_metadata, credentials=client_call_details.credentials, wait_for_ready=client_call_details.wait_for_ready, ) @@ -568,14 +569,20 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): # Expected to see the echoed initial metadata self.assertTrue( - _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[0], await - call.initial_metadata())) - + _common.seen_metadatum( + expected_key=_INITIAL_METADATA_KEY, + expected_value=_INITIAL_METADATA_TO_INJECT[ + _INITIAL_METADATA_KEY], + actual=await call.initial_metadata(), + )) # Expected to see the echoed trailing metadata self.assertTrue( - _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[1], await - call.trailing_metadata())) - + _common.seen_metadatum( + expected_key=_TRAILING_METADATA_KEY, + expected_value=_INITIAL_METADATA_TO_INJECT[ + _TRAILING_METADATA_KEY], + actual=await call.trailing_metadata(), + )) self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_add_done_callback_before_finishes(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 6551e4ca084..16f91430fce 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -37,38 +37,33 @@ _TEST_STREAM_STREAM = '/test/TestStreamStream' _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x01\x01\x01' -_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = ( +_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata( ('client-to-server', 'question'), ('client-to-server-bin', b'\x07\x07\x07'), ) -_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = ( +_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata( ('server-to-client', 'answer'), ('server-to-client-bin', b'\x06\x06\x06'), ) -_TRAILING_METADATA = (('a-trailing-metadata', 'stack-trace'), - ('a-trailing-metadata-bin', b'\x05\x05\x05')) -_INITIAL_METADATA_FOR_GENERIC_HANDLER = (('a-must-have-key', 'secret'),) +_TRAILING_METADATA = aio.Metadata( + ('a-trailing-metadata', 'stack-trace'), + ('a-trailing-metadata-bin', b'\x05\x05\x05'), +) +_INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata( + ('a-must-have-key', 'secret'),) _INVALID_METADATA_TEST_CASES = ( ( TypeError, - ((42, 42),), - ), - ( - TypeError, - (({}, {}),), - ), - ( - TypeError, - (('normal', object()),), + aio.Metadata((42, 42),), ), ( TypeError, - object(), + aio.Metadata(({}, {}),), ), ( TypeError, - (object(),), + aio.Metadata(('normal', object()),), ), )