diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi index caf867b5696..981e420514d 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi @@ -41,6 +41,11 @@ cdef void _store_c_metadata( for index, (key, value) in enumerate(metadata): encoded_key = _encode(key) encoded_value = value if encoded_key[-4:] == b'-bin' else _encode(value) + if type(encoded_value) != bytes: + raise TypeError('Binary metadata key="%s" expected bytes, got %s' % ( + key, + type(encoded_value) + )) c_metadata[0][index].key = _slice_from_bytes(encoded_key) c_metadata[0][index].value = _slice_from_bytes(encoded_value) diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 151523dd7e2..e28a7632a62 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -150,7 +150,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): else: return UnaryUnaryCall( request, _timeout_to_deadline(client_call_details.timeout), - metadata, client_call_details.credentials, self._channel, + client_call_details.metadata, + client_call_details.credentials, self._channel, client_call_details.method, request_serializer, response_deserializer) diff --git a/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel b/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel index fd47d2c33d5..aed975f16ce 100644 --- a/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel +++ b/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel @@ -43,6 +43,12 @@ py_library( srcs_version = "PY3", ) +py_library( + name = "_common", + srcs = ["_common.py"], + srcs_version = "PY3", +) + [ py_test( name = test_file_name[:-3], @@ -55,6 +61,7 @@ py_library( main = test_file_name, python_version = "PY3", deps = [ + ":_common", ":_constants", ":_test_base", ":_test_server", diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py new file mode 100644 index 00000000000..4f645606613 --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -0,0 +1,24 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def seen_metadata(expected, actual): + metadata_dict = dict(actual) + if type(expected[0]) != tuple: + return metadata_dict.get(expected[0]) == expected[1] + else: + for metadatum in expected: + if metadata_dict.get(metadatum[0]) != metadatum[1]: + return False + return True diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index 9970178d0cd..33bad9894c5 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -18,11 +18,18 @@ import unittest import grpc from grpc.experimental import aio -from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE +from tests_aio.unit._test_server import start_test_server, _INITIAL_METADATA_KEY, _TRAILING_METADATA_KEY +from tests_aio.unit import _constants +from tests_aio.unit import _common from tests_aio.unit._test_base import AioTestBase -from src.proto.grpc.testing import messages_pb2 +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc + _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' +_INITIAL_METADATA_TO_INJECT = ( + (_INITIAL_METADATA_KEY, 'extra info'), + (_TRAILING_METADATA_KEY, b'\x13\x37'), +) class TestUnaryUnaryClientInterceptor(AioTestBase): @@ -124,7 +131,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): client_call_details, request): new_client_call_details = aio.ClientCallDetails( method=client_call_details.method, - timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2, + timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2, metadata=client_call_details.metadata, credentials=client_call_details.credentials) return await continuation(new_client_call_details, request) @@ -165,7 +172,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): new_client_call_details = aio.ClientCallDetails( method=client_call_details.method, - timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2, + timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2, metadata=client_call_details.metadata, credentials=client_call_details.credentials) @@ -342,8 +349,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) - call = multicallable(messages_pb2.SimpleRequest(), - timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) + call = multicallable( + messages_pb2.SimpleRequest(), + timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2) with self.assertRaises(aio.AioRpcError) as exception_context: await call @@ -375,8 +383,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) - call = multicallable(messages_pb2.SimpleRequest(), - timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) + call = multicallable( + messages_pb2.SimpleRequest(), + timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2) with self.assertRaises(aio.AioRpcError) as exception_context: await call @@ -532,6 +541,42 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.initial_metadata(), tuple()) self.assertEqual(await call.trailing_metadata(), None) + async def test_initial_metadata_modification(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + if client_call_details.metadata is not None: + new_metadata = client_call_details.metadata + _INITIAL_METADATA_TO_INJECT + else: + new_metadata = _INITIAL_METADATA_TO_INJECT + new_details = aio.ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=new_metadata, + credentials=client_call_details.credentials, + ) + return await continuation(new_details, request) + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.UnaryCall(messages_pb2.SimpleRequest()) + + # Expected to see the echoed initial metadata + self.assertTrue( + _common.seen_metadata(_INITIAL_METADATA_TO_INJECT[0], await + call.initial_metadata())) + + # Expected to see the echoed trailing metadata + self.assertTrue( + _common.seen_metadata(_INITIAL_METADATA_TO_INJECT[1], await + call.trailing_metadata())) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + if __name__ == '__main__': logging.basicConfig() diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 17a9d1a0ecf..cc361d95cbf 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -23,6 +23,7 @@ import grpc from grpc.experimental import aio from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit import _common _TEST_CLIENT_TO_SERVER = '/test/TestClientToServer' _TEST_SERVER_TO_CLIENT = '/test/TestServerToClient' @@ -69,21 +70,13 @@ _INVALID_METADATA_TEST_CASES = ( ) -def _seen_metadata(expected, actual): - metadata_dict = dict(actual) - for metadatum in expected: - if metadata_dict.get(metadatum[0]) != metadatum[1]: - return False - return True - - class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): @staticmethod async def _test_client_to_server(request, context): assert _REQUEST == request - assert _seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, - context.invocation_metadata()) + assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, + context.invocation_metadata()) return _RESPONSE @staticmethod @@ -120,8 +113,8 @@ class _TestGenericHandlerItself(grpc.GenericRpcHandler): return _RESPONSE def service(self, handler_details): - assert _seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER, - handler_details.invocation_metadata) + assert _common.seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER, + handler_details.invocation_metadata) return grpc.unary_unary_rpc_method_handler(self._method)