diff --git a/src/python/grpcio/grpc/_compression.py b/src/python/grpcio/grpc/_compression.py index 4035844db97..45339c3afe2 100644 --- a/src/python/grpcio/grpc/_compression.py +++ b/src/python/grpcio/grpc/_compression.py @@ -39,7 +39,7 @@ def create_channel_option(compression): int(compression)),) if compression else () -def augment_metadata(metadata, compression) -> tuple: +def augment_metadata(metadata, compression): if not metadata and not compression: return None base_metadata = tuple(metadata) if metadata else () diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 7ecc2dd8e1b..c121bd6b76d 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -26,7 +26,7 @@ from grpc._cython import cygrpc from . import _base_call from ._metadata import Metadata -from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, +from ._typing import (DeserializingFunction, DoneCallbackType, MetadatumType, RequestIterableType, RequestType, ResponseType, SerializingFunction) @@ -61,15 +61,15 @@ class AioRpcError(grpc.RpcError): _code: grpc.StatusCode _details: Optional[str] - _initial_metadata: Optional[MetadataType] - _trailing_metadata: Optional[MetadataType] + _initial_metadata: Optional[Metadata] + _trailing_metadata: Optional[Metadata] _debug_error_string: Optional[str] def __init__(self, code: grpc.StatusCode, details: Optional[str] = None, - initial_metadata: Optional[MetadataType] = None, - trailing_metadata: Optional[MetadataType] = None, + initial_metadata: Optional[Metadata] = None, + trailing_metadata: Optional[Metadata] = None, debug_error_string: Optional[str] = None) -> None: """Constructor. @@ -84,8 +84,8 @@ class AioRpcError(grpc.RpcError): super().__init__(self) self._code = code self._details = details - self._initial_metadata = Metadata(*(initial_metadata or ())) - self._trailing_metadata = Metadata(*(trailing_metadata or ())) + self._initial_metadata = initial_metadata + self._trailing_metadata = trailing_metadata self._debug_error_string = debug_error_string def code(self) -> grpc.StatusCode: @@ -104,7 +104,7 @@ class AioRpcError(grpc.RpcError): """ return self._details - def initial_metadata(self) -> Optional[MetadataType]: + def initial_metadata(self) -> Metadata: """Accesses the initial metadata sent by the server. Returns: @@ -112,7 +112,7 @@ class AioRpcError(grpc.RpcError): """ return self._initial_metadata - def trailing_metadata(self) -> Optional[MetadataType]: + def trailing_metadata(self) -> Metadata: """Accesses the trailing metadata sent by the server. Returns: @@ -141,13 +141,13 @@ class AioRpcError(grpc.RpcError): return self._repr() -def _create_rpc_error(initial_metadata: Optional[MetadataType], +def _create_rpc_error(initial_metadata: Metadata, status: cygrpc.AioRpcStatus) -> AioRpcError: return AioRpcError( _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()], status.details(), - initial_metadata, - status.trailing_metadata(), + Metadata.from_tuple(initial_metadata), + Metadata.from_tuple(status.trailing_metadata()), status.debug_error_string(), ) @@ -164,7 +164,7 @@ class Call: _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction - def __init__(self, cython_call: cygrpc._AioCall, metadata: MetadataType, + def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, loop: asyncio.AbstractEventLoop) -> None: @@ -204,14 +204,14 @@ class Call: def time_remaining(self) -> Optional[float]: return self._cython_call.time_remaining() - async def initial_metadata(self) -> MetadataType: + async def initial_metadata(self) -> Metadata: raw_metadata_tuple = await self._cython_call.initial_metadata() - return Metadata(*(raw_metadata_tuple or ())) + return Metadata.from_tuple(raw_metadata_tuple) - async def trailing_metadata(self) -> MetadataType: + async def trailing_metadata(self) -> Metadata: raw_metadata_tuple = (await self._cython_call.status()).trailing_metadata() - return Metadata(*(raw_metadata_tuple or ())) + return Metadata.from_tuple(raw_metadata_tuple) async def code(self) -> grpc.StatusCode: cygrpc_code = (await self._cython_call.status()).code() @@ -474,7 +474,7 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): # pylint: disable=too-many-arguments def __init__(self, request: RequestType, deadline: Optional[float], - metadata: MetadataType, + metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -523,7 +523,7 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): # pylint: disable=too-many-arguments def __init__(self, request: RequestType, deadline: Optional[float], - metadata: MetadataType, + metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -563,7 +563,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, # pylint: disable=too-many-arguments def __init__(self, request_iterator: Optional[RequestIterableType], - deadline: Optional[float], metadata: MetadataType, + deadline: Optional[float], metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -601,7 +601,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, # pylint: disable=too-many-arguments def __init__(self, request_iterator: Optional[RequestIterableType], - deadline: Optional[float], metadata: MetadataType, + deadline: Optional[float], metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, diff --git a/src/python/grpcio/grpc/experimental/aio/_metadata.py b/src/python/grpcio/grpc/experimental/aio/_metadata.py index ff970106748..3230445d58a 100644 --- a/src/python/grpcio/grpc/experimental/aio/_metadata.py +++ b/src/python/grpcio/grpc/experimental/aio/_metadata.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the metadata abstraction for gRPC Asyncio Python.""" -from typing import List, Tuple, Iterator, Any, Text, Union +from typing import List, Tuple, Iterator, Any, Union from collections import abc, OrderedDict -MetadataKey = Text +MetadataKey = str MetadataValue = Union[str, bytes] @@ -37,6 +37,12 @@ class Metadata(abc.Mapping): for md_key, md_value in args: self.add(md_key, md_value) + @classmethod + def from_tuple(cls, raw_metadata: tuple): + if raw_metadata: + return cls(*raw_metadata) + return cls() + def add(self, key: MetadataKey, value: MetadataValue) -> None: self._metadata.setdefault(key, []) self._metadata[key].append(value) diff --git a/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py index dda58c5ed53..c0594cb06ab 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py @@ -119,6 +119,18 @@ class TestTypeMetadata(unittest.TestCase): with self.assertRaises(KeyError): del metadata["other key"] + def test_metadata_from_tuple(self): + scenarios = ( + (None, Metadata()), + (Metadata(), Metadata()), + (self._DEFAULT_DATA, Metadata(*self._DEFAULT_DATA)), + (self._MULTI_ENTRY_DATA, Metadata(*self._MULTI_ENTRY_DATA)), + (Metadata(*self._DEFAULT_DATA), Metadata(*self._DEFAULT_DATA)), + ) + for source, expected in scenarios: + with self.subTest(raw_metadata=source, expected=expected): + self.assertEqual(expected, Metadata.from_tuple(source)) + if __name__ == '__main__': logging.basicConfig() diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 94a36a4c070..1961226fa6d 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -102,11 +102,11 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): async def test_call_initial_metadata_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(aio.Metadata(), await call.initial_metadata()) async def test_call_trailing_metadata_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + self.assertEqual(aio.Metadata(), await call.trailing_metadata()) async def test_call_initial_metadata_cancelable(self): coro_started = asyncio.Event() @@ -122,7 +122,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): # Test that initial metadata can still be asked thought # a cancellation happened with the previous task - self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(aio.Metadata(), await call.initial_metadata()) async def test_call_initial_metadata_multiple_waiters(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -135,7 +135,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): await call expected = [aio.Metadata() for _ in range(2)] - self.assertEqual(await asyncio.gather(*[task1, task2]), expected) + self.assertEqual(expected, await asyncio.gather(*[task1, task2])) async def test_call_code_cancelable(self): coro_started = asyncio.Event() diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 59f5596c70b..822bd134521 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -57,6 +57,10 @@ _INVALID_METADATA_TEST_CASES = ( TypeError, ((42, 42),), ), + ( + TypeError, + (({}, {}),), + ), ( TypeError, ((None, {}),),