diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 35728433987..7ecc2dd8e1b 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -25,10 +25,10 @@ from grpc import _common from grpc._cython import cygrpc from . import _base_call +from ._metadata import Metadata from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, MetadatumType, RequestIterableType, RequestType, ResponseType, SerializingFunction) -from ._metadata import Metadata __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -84,8 +84,8 @@ class AioRpcError(grpc.RpcError): super().__init__(self) self._code = code self._details = details - self._initial_metadata = initial_metadata or Metadata() - self._trailing_metadata = trailing_metadata or Metadata() + self._initial_metadata = Metadata(*(initial_metadata or ())) + self._trailing_metadata = Metadata(*(trailing_metadata or ())) self._debug_error_string = debug_error_string def code(self) -> grpc.StatusCode: @@ -205,10 +205,13 @@ class Call: return self._cython_call.time_remaining() async def initial_metadata(self) -> MetadataType: - return await self._cython_call.initial_metadata() + raw_metadata_tuple = await self._cython_call.initial_metadata() + return Metadata(*(raw_metadata_tuple or ())) async def trailing_metadata(self) -> MetadataType: - return (await self._cython_call.status()).trailing_metadata() + raw_metadata_tuple = (await + self._cython_call.status()).trailing_metadata() + return Metadata(*(raw_metadata_tuple or ())) async def code(self) -> grpc.StatusCode: cygrpc_code = (await self._cython_call.status()).code() diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 3361af477b7..3ac12bf6139 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -29,10 +29,10 @@ from ._interceptor import ( InterceptedStreamUnaryCall, InterceptedStreamStreamCall, ClientInterceptor, UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, StreamUnaryClientInterceptor, StreamStreamClientInterceptor) +from ._metadata import Metadata from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, SerializingFunction, RequestIterableType) from ._utils import _timeout_to_deadline -from ._metadata import Metadata _USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) diff --git a/src/python/grpcio_tests/tests_aio/interop/methods.py b/src/python/grpcio_tests/tests_aio/interop/methods.py index 4ff10ff1572..706f4249be3 100644 --- a/src/python/grpcio_tests/tests_aio/interop/methods.py +++ b/src/python/grpcio_tests/tests_aio/interop/methods.py @@ -293,12 +293,13 @@ async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub): ) async def _validate_metadata(call): - initial_metadata = dict(await call.initial_metadata()) + initial_metadata = await call.initial_metadata() if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value: raise ValueError('expected initial metadata %s, got %s' % (initial_metadata_value, initial_metadata[_INITIAL_METADATA_KEY])) - trailing_metadata = dict(await call.trailing_metadata()) + + trailing_metadata = await call.trailing_metadata() if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value: raise ValueError('expected trailing metadata %s, got %s' % (trailing_metadata_value, diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index 7f9e2e89701..a4a9236069c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -28,7 +28,7 @@ def seen_metadata(expected: MetadataType, actual: MetadataType): def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue, actual: MetadataType) -> bool: obtained = actual[expected_key] - assert obtained == expected_value + return obtained == expected_value async def block_until_certain_state(channel: aio.Channel, diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 3ce2a2f7b52..94a36a4c070 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -102,11 +102,11 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): async def test_call_initial_metadata_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - self.assertEqual((), await call.initial_metadata()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) async def test_call_trailing_metadata_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - self.assertEqual((), await call.trailing_metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_initial_metadata_cancelable(self): coro_started = asyncio.Event() @@ -122,7 +122,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): # Test that initial metadata can still be asked thought # a cancellation happened with the previous task - self.assertEqual((), await call.initial_metadata()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) async def test_call_initial_metadata_multiple_waiters(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -134,8 +134,8 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): task2 = self.loop.create_task(coro()) await call - - self.assertEqual([(), ()], await asyncio.gather(*[task1, task2])) + expected = [aio.Metadata() for _ in range(2)] + self.assertEqual(await asyncio.gather(*[task1, task2]), expected) async def test_call_code_cancelable(self): coro_started = asyncio.Event() diff --git a/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py index 2d58cff0e41..b9a04af00dc 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py @@ -92,8 +92,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase): self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, response.aggregated_payload_size) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) @@ -131,8 +131,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase): self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, response.aggregated_payload_size) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) @@ -230,8 +230,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase): self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, response.aggregated_payload_size) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py index 6137538ffca..fd542fd16e9 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py @@ -96,8 +96,8 @@ class TestUnaryStreamClientInterceptor(AioTestBase): self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py index ae1ad54acd9..e64daec7df4 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py @@ -302,8 +302,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(type(response), messages_pb2.SimpleResponse) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.details(), '') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_ok_awaited(self): @@ -331,8 +331,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(type(response), messages_pb2.SimpleResponse) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.details(), '') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_rpc_error(self): @@ -364,8 +364,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) self.assertEqual(await call.details(), 'Deadline Exceeded') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_rpc_error_awaited(self): @@ -398,8 +398,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) self.assertEqual(await call.details(), 'Deadline Exceeded') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_cancel_before_rpc(self): @@ -541,8 +541,10 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) self.assertEqual(await call.details(), _LOCAL_CANCEL_DETAILS_EXPECTATION) - self.assertEqual(await call.initial_metadata(), tuple()) - self.assertEqual(await call.trailing_metadata(), None) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual( + await call.trailing_metadata(), aio.Metadata(), + "When the raw response is None, empty metadata is returned") async def test_initial_metadata_modification(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py b/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py index 1e6e3598b8b..0bb3a3acc89 100644 --- a/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py @@ -255,7 +255,8 @@ class TestCompatibility(AioTestBase): self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary) call = self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST) self.assertTrue( - _common.seen_metadata(metadata, await call.initial_metadata())) + _common.seen_metadata(aio.Metadata(*metadata), await + call.initial_metadata())) async def test_sync_unary_unary_abort(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 16f91430fce..59f5596c70b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -55,15 +55,15 @@ _INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata( _INVALID_METADATA_TEST_CASES = ( ( TypeError, - aio.Metadata((42, 42),), + ((42, 42),), ), ( TypeError, - aio.Metadata(({}, {}),), + ((None, {}),), ), ( TypeError, - aio.Metadata(('normal', object()),), + (('normal', object()),), ), ) @@ -100,13 +100,13 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): async def _test_server_to_client(request, context): assert _REQUEST == request await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) return _RESPONSE @staticmethod async def _test_trailing_metadata(request, context): assert _REQUEST == request - context.set_trailing_metadata(_TRAILING_METADATA) + context.set_trailing_metadata(tuple(_TRAILING_METADATA)) return _RESPONSE @staticmethod @@ -115,21 +115,21 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, context.invocation_metadata()) await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) yield _RESPONSE - context.set_trailing_metadata(_TRAILING_METADATA) + context.set_trailing_metadata(tuple(_TRAILING_METADATA)) @staticmethod async def _test_stream_unary(request_iterator, context): assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, context.invocation_metadata()) await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) async for request in request_iterator: assert _REQUEST == request - context.set_trailing_metadata(_TRAILING_METADATA) + context.set_trailing_metadata(tuple(_TRAILING_METADATA)) return _RESPONSE @staticmethod @@ -137,13 +137,13 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, context.invocation_metadata()) await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)) async for request in request_iterator: assert _REQUEST == request yield _RESPONSE - context.set_trailing_metadata(_TRAILING_METADATA) + context.set_trailing_metadata(tuple(_TRAILING_METADATA)) def service(self, handler_call_details): return self._routing_table.get(handler_call_details.method) @@ -193,6 +193,7 @@ class TestMetadata(AioTestBase): async def test_from_server_to_client(self): multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT) call = multicallable(_REQUEST) + self.assertEqual(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await call.initial_metadata()) self.assertEqual(_RESPONSE, await call) @@ -207,8 +208,8 @@ class TestMetadata(AioTestBase): async def test_from_client_to_server_with_list(self): multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) - call = multicallable( - _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)) + call = multicallable(_REQUEST, + metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) self.assertEqual(_RESPONSE, await call) self.assertEqual(grpc.StatusCode.OK, await call.code()) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index f85e46c379a..d891ecdb771 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -198,7 +198,7 @@ class TestServerInterceptor(AioTestBase): request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) - metadata = (('key', 'value'),) + metadata = aio.Metadata(('key', 'value'),) call = multicallable(messages_pb2.SimpleRequest(), metadata=metadata) await call @@ -208,7 +208,7 @@ class TestServerInterceptor(AioTestBase): ], record) record.clear() - metadata = (('key', 'value'), ('secret', '42')) + metadata = aio.Metadata(('key', 'value'), ('secret', '42')) call = multicallable(messages_pb2.SimpleRequest(), metadata=metadata) await call