From f912ddf7d435c410fe3bce49d977871e4f023a06 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Wed, 15 Jan 2020 12:11:54 -0800 Subject: [PATCH] Split the seen_metadata function & assign tuple() as default value --- src/python/grpcio/grpc/experimental/aio/_call.py | 2 +- .../grpcio/grpc/experimental/aio/_channel.py | 3 +++ .../grpcio/grpc/experimental/aio/_interceptor.py | 3 +-- .../grpcio/grpc/experimental/aio/_typing.py | 3 ++- .../grpcio_tests/tests_aio/unit/_common.py | 16 ++++++++-------- .../tests_aio/unit/interceptor_test.py | 15 ++++++--------- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index ebf2e935f3e..0d302f98bec 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -280,7 +280,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): # pylint: disable=too-many-arguments def __init__(self, request: RequestType, deadline: Optional[float], - metadata: Optional[MetadataType], + metadata: MetadataType, credentials: Optional[grpc.CallCredentials], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index e8ad9598473..ec56bbc1072 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -101,6 +101,9 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): if compression: raise NotImplementedError("TODO: compression not implemented yet") + if metadata is None: + metadata = tuple() + if not self._interceptors: return UnaryUnaryCall( request, diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index e28a7632a62..85f51b19e57 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -106,8 +106,7 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): def __init__( # pylint: disable=R0913 self, interceptors: Sequence[UnaryUnaryClientInterceptor], request: RequestType, timeout: Optional[float], - metadata: Optional[MetadataType], - credentials: Optional[grpc.CallCredentials], + metadata: MetadataType, credentials: Optional[grpc.CallCredentials], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction) -> None: diff --git a/src/python/grpcio/grpc/experimental/aio/_typing.py b/src/python/grpcio/grpc/experimental/aio/_typing.py index 6428fb72f98..c60eab85449 100644 --- a/src/python/grpcio/grpc/experimental/aio/_typing.py +++ b/src/python/grpcio/grpc/experimental/aio/_typing.py @@ -20,6 +20,7 @@ RequestType = TypeVar('RequestType') ResponseType = TypeVar('ResponseType') SerializingFunction = Callable[[Any], bytes] DeserializingFunction = Callable[[bytes], Any] -MetadataType = Sequence[Tuple[Text, AnyStr]] +MetadatumType = Tuple[Text, AnyStr] +MetadataType = Sequence[MetadatumType] ChannelArgumentType = Sequence[Tuple[Text, Any]] EOFType = type(EOF) diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index 4f645606613..e8aa13da548 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from grpc.experimental.aio._typing import MetadataType, MetadatumType -def seen_metadata(expected, actual): + +def seen_metadata(expected: MetadataType, actual: MetadataType): + return bool(set(expected) - set(actual)) + + +def seen_metadatum(expected: MetadatumType, actual: MetadataType): metadata_dict = dict(actual) - if type(expected[0]) != tuple: - return metadata_dict.get(expected[0]) == expected[1] - else: - for metadatum in expected: - if metadata_dict.get(metadatum[0]) != metadatum[1]: - return False - return True + return metadata_dict.get(expected[0]) == expected[1] diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index 89f6678451d..6d1ae543b34 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -546,14 +546,11 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): async def intercept_unary_unary(self, continuation, client_call_details, request): - if client_call_details.metadata is not None: - new_metadata = client_call_details.metadata + _INITIAL_METADATA_TO_INJECT - else: - new_metadata = _INITIAL_METADATA_TO_INJECT new_details = aio.ClientCallDetails( method=client_call_details.method, timeout=client_call_details.timeout, - metadata=new_metadata, + metadata=client_call_details.metadata + + _INITIAL_METADATA_TO_INJECT, credentials=client_call_details.credentials, ) return await continuation(new_details, request) @@ -566,13 +563,13 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): # Expected to see the echoed initial metadata self.assertTrue( - _common.seen_metadata(_INITIAL_METADATA_TO_INJECT[0], await - call.initial_metadata())) + _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[0], await + call.initial_metadata())) # Expected to see the echoed trailing metadata self.assertTrue( - _common.seen_metadata(_INITIAL_METADATA_TO_INJECT[1], await - call.trailing_metadata())) + _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[1], await + call.trailing_metadata())) self.assertEqual(await call.code(), grpc.StatusCode.OK)