diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index f37769c0038..b842ec6f2ba 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -143,7 +143,7 @@ cdef class _ServicerContext: self._loop) self._rpc_state.metadata_sent = True - async def send_initial_metadata(self, tuple metadata): + async def send_initial_metadata(self, object metadata): self._rpc_state.raise_for_termination() if self._rpc_state.metadata_sent: @@ -151,7 +151,7 @@ cdef class _ServicerContext: else: await _send_initial_metadata( self._rpc_state, - _augment_metadata(metadata, self._rpc_state.compression_algorithm), + _augment_metadata(tuple(metadata), self._rpc_state.compression_algorithm), _EMPTY_FLAG, self._loop ) @@ -192,8 +192,8 @@ cdef class _ServicerContext: async def abort_with_status(self, object status): await self.abort(status.code, status.details, status.trailing_metadata) - def set_trailing_metadata(self, tuple metadata): - self._rpc_state.trailing_metadata = metadata + def set_trailing_metadata(self, object metadata): + self._rpc_state.trailing_metadata = tuple(metadata) def invocation_metadata(self): return self._rpc_state.invocation_metadata() @@ -233,13 +233,13 @@ cdef class _SyncServicerContext: # Abort should raise an AbortError future.exception() - def send_initial_metadata(self, tuple metadata): + def send_initial_metadata(self, object metadata): future = asyncio.run_coroutine_threadsafe( self._context.send_initial_metadata(metadata), self._loop) future.result() - def set_trailing_metadata(self, tuple metadata): + def set_trailing_metadata(self, object metadata): self._context.set_trailing_metadata(metadata) def invocation_metadata(self): @@ -303,7 +303,7 @@ async def _finish_handler_with_unary_response(RPCState rpc_state, object response_serializer, object loop): """Finishes server method handler with a single response. - + This function executes the application handler, and handles response sending, as well as errors. It is shared between unary-unary and stream-unary handlers. @@ -378,7 +378,7 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state, """ cdef object async_response_generator cdef object response_message - + if inspect.iscoroutinefunction(stream_handler): # Case 1: Coroutine async handler - using reader-writer API # The handler uses reader / writer API, returns None. diff --git a/src/python/grpcio/grpc/experimental/aio/_base_call.py b/src/python/grpcio/grpc/experimental/aio/_base_call.py index 214e208c005..4ccbb3be132 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_call.py @@ -23,8 +23,8 @@ from typing import AsyncIterable, Awaitable, Generic, Optional, Union import grpc -from ._typing import (DoneCallbackType, EOFType, MetadataType, RequestType, - ResponseType) +from ._typing import (DoneCallbackType, EOFType, RequestType, ResponseType) +from ._metadata import Metadata __all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -86,7 +86,7 @@ class Call(RpcContext, metaclass=ABCMeta): """The abstract base class of an RPC on the client-side.""" @abstractmethod - async def initial_metadata(self) -> MetadataType: + async def initial_metadata(self) -> Metadata: """Accesses the initial metadata sent by the server. Returns: @@ -94,7 +94,7 @@ class Call(RpcContext, metaclass=ABCMeta): """ @abstractmethod - async def trailing_metadata(self) -> MetadataType: + async def trailing_metadata(self) -> Metadata: """Accesses the trailing metadata sent by the server. Returns: diff --git a/src/python/grpcio/grpc/experimental/aio/_base_channel.py b/src/python/grpcio/grpc/experimental/aio/_base_channel.py index 33efa7789cb..4b4ea1355b4 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_channel.py @@ -19,10 +19,9 @@ from typing import Any, Optional import grpc from . import _base_call -from ._typing import (DeserializingFunction, MetadataType, RequestIterableType, +from ._typing import (DeserializingFunction, RequestIterableType, SerializingFunction) - -_IMMUTABLE_EMPTY_TUPLE = tuple() +from ._metadata import Metadata class UnaryUnaryMultiCallable(abc.ABC): @@ -33,7 +32,7 @@ class UnaryUnaryMultiCallable(abc.ABC): request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -71,7 +70,7 @@ class UnaryStreamMultiCallable(abc.ABC): request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -108,7 +107,7 @@ class StreamUnaryMultiCallable(abc.ABC): def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -146,7 +145,7 @@ class StreamStreamMultiCallable(abc.ABC): def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None diff --git a/src/python/grpcio/grpc/experimental/aio/_base_server.py b/src/python/grpcio/grpc/experimental/aio/_base_server.py index 72e4288c94f..86c15fc86b0 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_server.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_server.py @@ -18,7 +18,8 @@ from typing import Generic, Optional, Sequence import grpc -from ._typing import MetadataType, RequestType, ResponseType +from ._typing import RequestType, ResponseType +from ._metadata import Metadata class Server(abc.ABC): @@ -157,8 +158,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): """ @abc.abstractmethod - async def send_initial_metadata(self, - initial_metadata: MetadataType) -> None: + async def send_initial_metadata(self, initial_metadata: Metadata) -> None: """Sends the initial metadata value to the client. This method need not be called by implementations if they have no @@ -170,7 +170,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): @abc.abstractmethod async def abort(self, code: grpc.StatusCode, details: str, - trailing_metadata: MetadataType) -> None: + trailing_metadata: Metadata) -> None: """Raises an exception to terminate the RPC with a non-OK status. The code and details passed as arguments will supercede any existing @@ -190,8 +190,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): """ @abc.abstractmethod - async def set_trailing_metadata(self, - trailing_metadata: MetadataType) -> None: + async def set_trailing_metadata(self, trailing_metadata: Metadata) -> None: """Sends the trailing metadata for the RPC. This method need not be called by implementations if they have no @@ -202,7 +201,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): """ @abc.abstractmethod - def invocation_metadata(self) -> Optional[MetadataType]: + def invocation_metadata(self) -> Optional[Metadata]: """Accesses the metadata from the sent by the client. Returns: diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index a0693921461..ba229f35c39 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -25,9 +25,10 @@ from grpc import _common from grpc._cython import cygrpc from . import _base_call -from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, - MetadatumType, RequestIterableType, RequestType, - ResponseType, SerializingFunction) +from ._metadata import Metadata +from ._typing import (DeserializingFunction, DoneCallbackType, MetadatumType, + RequestIterableType, RequestType, ResponseType, + SerializingFunction) __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -58,22 +59,17 @@ class AioRpcError(grpc.RpcError): determined. Hence, its methods no longer needs to be coroutines. """ - # TODO(https://github.com/grpc/grpc/issues/20144) Metadata - # type returned by `initial_metadata` and `trailing_metadata` - # and also taken in the constructor needs to be revisit and make - # it more specific. - _code: grpc.StatusCode _details: Optional[str] - _initial_metadata: Optional[MetadataType] - _trailing_metadata: Optional[MetadataType] + _initial_metadata: Optional[Metadata] + _trailing_metadata: Optional[Metadata] _debug_error_string: Optional[str] def __init__(self, code: grpc.StatusCode, + initial_metadata: Metadata, + trailing_metadata: Metadata, details: Optional[str] = None, - initial_metadata: Optional[MetadataType] = None, - trailing_metadata: Optional[MetadataType] = None, debug_error_string: Optional[str] = None) -> None: """Constructor. @@ -108,7 +104,7 @@ class AioRpcError(grpc.RpcError): """ return self._details - def initial_metadata(self) -> Optional[MetadataType]: + def initial_metadata(self) -> Metadata: """Accesses the initial metadata sent by the server. Returns: @@ -116,7 +112,7 @@ class AioRpcError(grpc.RpcError): """ return self._initial_metadata - def trailing_metadata(self) -> Optional[MetadataType]: + def trailing_metadata(self) -> Metadata: """Accesses the trailing metadata sent by the server. Returns: @@ -145,14 +141,14 @@ class AioRpcError(grpc.RpcError): return self._repr() -def _create_rpc_error(initial_metadata: Optional[MetadataType], +def _create_rpc_error(initial_metadata: Metadata, status: cygrpc.AioRpcStatus) -> AioRpcError: return AioRpcError( _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()], - status.details(), - initial_metadata, - status.trailing_metadata(), - status.debug_error_string(), + Metadata.from_tuple(initial_metadata), + Metadata.from_tuple(status.trailing_metadata()), + details=status.details(), + debug_error_string=status.debug_error_string(), ) @@ -168,7 +164,7 @@ class Call: _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction - def __init__(self, cython_call: cygrpc._AioCall, metadata: MetadataType, + def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, loop: asyncio.AbstractEventLoop) -> None: @@ -208,11 +204,14 @@ class Call: def time_remaining(self) -> Optional[float]: return self._cython_call.time_remaining() - async def initial_metadata(self) -> MetadataType: - return await self._cython_call.initial_metadata() + async def initial_metadata(self) -> Metadata: + raw_metadata_tuple = await self._cython_call.initial_metadata() + return Metadata.from_tuple(raw_metadata_tuple) - async def trailing_metadata(self) -> MetadataType: - return (await self._cython_call.status()).trailing_metadata() + async def trailing_metadata(self) -> Metadata: + raw_metadata_tuple = (await + self._cython_call.status()).trailing_metadata() + return Metadata.from_tuple(raw_metadata_tuple) async def code(self) -> grpc.StatusCode: cygrpc_code = (await self._cython_call.status()).code() @@ -475,7 +474,7 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): # pylint: disable=too-many-arguments def __init__(self, request: RequestType, deadline: Optional[float], - metadata: MetadataType, + metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -524,7 +523,7 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): # pylint: disable=too-many-arguments def __init__(self, request: RequestType, deadline: Optional[float], - metadata: MetadataType, + metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -564,7 +563,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, # pylint: disable=too-many-arguments def __init__(self, request_iterator: Optional[RequestIterableType], - deadline: Optional[float], metadata: MetadataType, + deadline: Optional[float], metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -602,7 +601,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, # pylint: disable=too-many-arguments def __init__(self, request_iterator: Optional[RequestIterableType], - deadline: Optional[float], metadata: MetadataType, + deadline: Optional[float], metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 7427872e0b3..1995db13bf5 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -29,11 +29,11 @@ from ._interceptor import ( InterceptedStreamUnaryCall, InterceptedStreamStreamCall, ClientInterceptor, UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, StreamUnaryClientInterceptor, StreamStreamClientInterceptor) -from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, +from ._metadata import Metadata +from ._typing import (ChannelArgumentType, DeserializingFunction, SerializingFunction, RequestIterableType) from ._utils import _timeout_to_deadline -_IMMUTABLE_EMPTY_TUPLE = tuple() _USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) if sys.version_info[1] < 7: @@ -88,6 +88,19 @@ class _BaseMultiCallable: self._response_deserializer = response_deserializer self._interceptors = interceptors + @staticmethod + def _init_metadata(metadata: Optional[Metadata] = None, + compression: Optional[grpc.Compression] = None + ) -> Metadata: + """Based on the provided values for or initialise the final + metadata, as it should be used for the current call. + """ + metadata = metadata or Metadata() + if compression: + metadata = Metadata( + *_compression.augment_metadata(metadata, compression)) + return metadata + class UnaryUnaryMultiCallable(_BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable): @@ -96,14 +109,13 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable, request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.UnaryUnaryCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) if not self._interceptors: call = UnaryUnaryCall(request, _timeout_to_deadline(timeout), metadata, credentials, wait_for_ready, @@ -127,14 +139,13 @@ class UnaryStreamMultiCallable(_BaseMultiCallable, request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.UnaryStreamCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: @@ -158,14 +169,13 @@ class StreamUnaryMultiCallable(_BaseMultiCallable, def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.StreamUnaryCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: @@ -189,14 +199,13 @@ class StreamStreamMultiCallable(_BaseMultiCallable, def __call__(self, request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, + metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None ) -> _base_call.StreamStreamCall: - if compression: - metadata = _compression.augment_metadata(metadata, compression) + metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index e276ae0c922..80e9625c553 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -27,8 +27,9 @@ from ._call import _RPC_ALREADY_FINISHED_DETAILS, _RPC_HALF_CLOSED_DETAILS from ._call import _API_STYLE_ERROR from ._utils import _timeout_to_deadline from ._typing import (RequestType, SerializingFunction, DeserializingFunction, - MetadataType, ResponseType, DoneCallbackType, - RequestIterableType, ResponseIterableType) + ResponseType, DoneCallbackType, RequestIterableType, + ResponseIterableType) +from ._metadata import Metadata _LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' @@ -82,7 +83,7 @@ class ClientCallDetails( method: str timeout: Optional[float] - metadata: Optional[MetadataType] + metadata: Optional[Metadata] credentials: Optional[grpc.CallCredentials] wait_for_ready: Optional[bool] @@ -248,7 +249,7 @@ class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): class InterceptedCall: - """Base implementation for all intecepted call arities. + """Base implementation for all intercepted call arities. Interceptors might have some work to do before the RPC invocation with the capacity of changing the invocation parameters, and some work to do @@ -370,7 +371,7 @@ class InterceptedCall: def time_remaining(self) -> Optional[float]: raise NotImplementedError() - async def initial_metadata(self) -> Optional[MetadataType]: + async def initial_metadata(self) -> Optional[Metadata]: try: call = await self._interceptors_task except AioRpcError as err: @@ -380,7 +381,7 @@ class InterceptedCall: return await call.initial_metadata() - async def trailing_metadata(self) -> Optional[MetadataType]: + async def trailing_metadata(self) -> Optional[Metadata]: try: call = await self._interceptors_task except AioRpcError as err: @@ -556,7 +557,7 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, # pylint: disable=too-many-arguments def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor], request: RequestType, timeout: Optional[float], - metadata: MetadataType, + metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -573,7 +574,7 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, # pylint: disable=too-many-arguments async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor], method: bytes, timeout: Optional[float], - metadata: Optional[MetadataType], + metadata: Optional[Metadata], credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], request: RequestType, request_serializer: SerializingFunction, @@ -628,7 +629,7 @@ class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin, # pylint: disable=too-many-arguments def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor], request: RequestType, timeout: Optional[float], - metadata: MetadataType, + metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -647,7 +648,7 @@ class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin, # pylint: disable=too-many-arguments async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor], method: bytes, timeout: Optional[float], - metadata: Optional[MetadataType], + metadata: Optional[Metadata], credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], request: RequestType, request_serializer: SerializingFunction, @@ -712,7 +713,7 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, # pylint: disable=too-many-arguments def __init__(self, interceptors: Sequence[StreamUnaryClientInterceptor], request_iterator: Optional[RequestIterableType], - timeout: Optional[float], metadata: MetadataType, + timeout: Optional[float], metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -731,7 +732,7 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, async def _invoke( self, interceptors: Sequence[StreamUnaryClientInterceptor], method: bytes, timeout: Optional[float], - metadata: Optional[MetadataType], + metadata: Optional[Metadata], credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], request_iterator: RequestIterableType, @@ -783,7 +784,7 @@ class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin, # pylint: disable=too-many-arguments def __init__(self, interceptors: Sequence[StreamStreamClientInterceptor], request_iterator: Optional[RequestIterableType], - timeout: Optional[float], metadata: MetadataType, + timeout: Optional[float], metadata: Metadata, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -804,7 +805,7 @@ class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin, async def _invoke( self, interceptors: Sequence[StreamStreamClientInterceptor], method: bytes, timeout: Optional[float], - metadata: Optional[MetadataType], + metadata: Optional[Metadata], credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], request_iterator: RequestIterableType, @@ -876,10 +877,10 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): def time_remaining(self) -> Optional[float]: raise NotImplementedError() - async def initial_metadata(self) -> Optional[MetadataType]: + async def initial_metadata(self) -> Optional[Metadata]: return None - async def trailing_metadata(self) -> Optional[MetadataType]: + async def trailing_metadata(self) -> Optional[Metadata]: return None async def code(self) -> grpc.StatusCode: @@ -928,10 +929,10 @@ class _StreamCallResponseIterator: def time_remaining(self) -> Optional[float]: return self._call.time_remaining() - async def initial_metadata(self) -> Optional[MetadataType]: + async def initial_metadata(self) -> Optional[Metadata]: return await self._call.initial_metadata() - async def trailing_metadata(self) -> Optional[MetadataType]: + async def trailing_metadata(self) -> Optional[Metadata]: return await self._call.trailing_metadata() async def code(self) -> grpc.StatusCode: diff --git a/src/python/grpcio/grpc/experimental/aio/_metadata.py b/src/python/grpcio/grpc/experimental/aio/_metadata.py index ff970106748..3230445d58a 100644 --- a/src/python/grpcio/grpc/experimental/aio/_metadata.py +++ b/src/python/grpcio/grpc/experimental/aio/_metadata.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the metadata abstraction for gRPC Asyncio Python.""" -from typing import List, Tuple, Iterator, Any, Text, Union +from typing import List, Tuple, Iterator, Any, Union from collections import abc, OrderedDict -MetadataKey = Text +MetadataKey = str MetadataValue = Union[str, bytes] @@ -37,6 +37,12 @@ class Metadata(abc.Mapping): for md_key, md_value in args: self.add(md_key, md_value) + @classmethod + def from_tuple(cls, raw_metadata: tuple): + if raw_metadata: + return cls(*raw_metadata) + return cls() + def add(self, key: MetadataKey, value: MetadataValue) -> None: self._metadata.setdefault(key, []) self._metadata[key].append(value) diff --git a/src/python/grpcio/grpc/experimental/aio/_typing.py b/src/python/grpcio/grpc/experimental/aio/_typing.py index a02ec8ff803..7e2e8da8a06 100644 --- a/src/python/grpcio/grpc/experimental/aio/_typing.py +++ b/src/python/grpcio/grpc/experimental/aio/_typing.py @@ -13,17 +13,18 @@ # limitations under the License. """Common types for gRPC Async API""" -from typing import (Any, AnyStr, AsyncIterable, Callable, Iterable, Sequence, - Tuple, TypeVar, Union) +from typing import (Any, AsyncIterable, Callable, Iterable, Sequence, Tuple, + TypeVar, Union) from grpc._cython.cygrpc import EOF +from ._metadata import Metadata, MetadataKey, MetadataValue RequestType = TypeVar('RequestType') ResponseType = TypeVar('ResponseType') SerializingFunction = Callable[[Any], bytes] DeserializingFunction = Callable[[bytes], Any] -MetadatumType = Tuple[str, AnyStr] -MetadataType = Sequence[MetadatumType] +MetadatumType = Tuple[MetadataKey, MetadataValue] +MetadataType = Metadata ChannelArgumentType = Sequence[Tuple[str, Any]] EOFType = type(EOF) DoneCallbackType = Callable[[Any], None] diff --git a/src/python/grpcio_tests/tests_aio/interop/methods.py b/src/python/grpcio_tests/tests_aio/interop/methods.py index 019b3bca894..706f4249be3 100644 --- a/src/python/grpcio_tests/tests_aio/interop/methods.py +++ b/src/python/grpcio_tests/tests_aio/interop/methods.py @@ -287,16 +287,19 @@ async def _unimplemented_service(stub: test_pb2_grpc.UnimplementedServiceStub): async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub): initial_metadata_value = "test_initial_metadata_value" trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b" - metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value), - (_TRAILING_METADATA_KEY, trailing_metadata_value)) + metadata = aio.Metadata( + (_INITIAL_METADATA_KEY, initial_metadata_value), + (_TRAILING_METADATA_KEY, trailing_metadata_value), + ) async def _validate_metadata(call): - initial_metadata = dict(await call.initial_metadata()) + initial_metadata = await call.initial_metadata() if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value: raise ValueError('expected initial metadata %s, got %s' % (initial_metadata_value, initial_metadata[_INITIAL_METADATA_KEY])) - trailing_metadata = dict(await call.trailing_metadata()) + + trailing_metadata = await call.trailing_metadata() if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value: raise ValueError('expected trailing metadata %s, got %s' % (trailing_metadata_value, diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index dab9454c58d..7fdd120e31b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -16,18 +16,20 @@ import asyncio import grpc from typing import AsyncIterable from grpc.experimental import aio -from grpc.experimental.aio._typing import MetadataType, MetadatumType +from grpc.experimental.aio._typing import MetadatumType, MetadataKey, MetadataValue +from grpc.experimental.aio._metadata import Metadata from tests.unit.framework.common import test_constants -def seen_metadata(expected: MetadataType, actual: MetadataType): - return not bool(set(expected) - set(actual)) +def seen_metadata(expected: Metadata, actual: Metadata): + return not bool(set(tuple(expected)) - set(tuple(actual))) -def seen_metadatum(expected: MetadatumType, actual: MetadataType): - metadata_dict = dict(actual) - return metadata_dict.get(expected[0]) == expected[1] +def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue, + actual: Metadata) -> bool: + obtained = actual[expected_key] + return obtained == expected_value async def block_until_certain_state(channel: aio.Channel, @@ -50,7 +52,7 @@ def inject_callbacks(call: aio.Call): second_callback_ran = asyncio.Event() def second_callback(call): - # Validate that all resopnses have been received + # Validate that all responses have been received # and the call is an end state. assert call.done() second_callback_ran.set() diff --git a/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py index dda58c5ed53..c0594cb06ab 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py @@ -119,6 +119,18 @@ class TestTypeMetadata(unittest.TestCase): with self.assertRaises(KeyError): del metadata["other key"] + def test_metadata_from_tuple(self): + scenarios = ( + (None, Metadata()), + (Metadata(), Metadata()), + (self._DEFAULT_DATA, Metadata(*self._DEFAULT_DATA)), + (self._MULTI_ENTRY_DATA, Metadata(*self._MULTI_ENTRY_DATA)), + (Metadata(*self._DEFAULT_DATA), Metadata(*self._DEFAULT_DATA)), + ) + for source, expected in scenarios: + with self.subTest(raw_metadata=source, expected=expected): + self.assertEqual(expected, Metadata.from_tuple(source)) + if __name__ == '__main__': logging.basicConfig() diff --git a/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py index fbcd99b39e3..416c51a7080 100644 --- a/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py @@ -18,11 +18,14 @@ import unittest import grpc +from grpc.experimental import aio from grpc.experimental.aio._call import AioRpcError from tests_aio.unit._test_base import AioTestBase -_TEST_INITIAL_METADATA = ('initial metadata',) -_TEST_TRAILING_METADATA = ('trailing metadata',) +_TEST_INITIAL_METADATA = aio.Metadata( + ('initial metadata key', 'initial metadata value')) +_TEST_TRAILING_METADATA = aio.Metadata( + ('trailing metadata key', 'trailing metadata value')) _TEST_DEBUG_ERROR_STRING = '{This is a debug string}' @@ -30,9 +33,9 @@ class TestAioRpcError(unittest.TestCase): def test_attributes(self): aio_rpc_error = AioRpcError(grpc.StatusCode.CANCELLED, - 'details', initial_metadata=_TEST_INITIAL_METADATA, trailing_metadata=_TEST_TRAILING_METADATA, + details="details", debug_error_string=_TEST_DEBUG_ERROR_STRING) self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED) self.assertEqual(aio_rpc_error.details(), 'details') diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 3ce2a2f7b52..1961226fa6d 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -102,11 +102,11 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): async def test_call_initial_metadata_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - self.assertEqual((), await call.initial_metadata()) + self.assertEqual(aio.Metadata(), await call.initial_metadata()) async def test_call_trailing_metadata_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - self.assertEqual((), await call.trailing_metadata()) + self.assertEqual(aio.Metadata(), await call.trailing_metadata()) async def test_call_initial_metadata_cancelable(self): coro_started = asyncio.Event() @@ -122,7 +122,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): # Test that initial metadata can still be asked thought # a cancellation happened with the previous task - self.assertEqual((), await call.initial_metadata()) + self.assertEqual(aio.Metadata(), await call.initial_metadata()) async def test_call_initial_metadata_multiple_waiters(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -134,8 +134,8 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): task2 = self.loop.create_task(coro()) await call - - self.assertEqual([(), ()], await asyncio.gather(*[task1, task2])) + expected = [aio.Metadata() for _ in range(2)] + self.assertEqual(expected, await asyncio.gather(*[task1, task2])) async def test_call_code_cancelable(self): coro_started = asyncio.Event() diff --git a/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py index 9ab54b39a6b..ce6a7bc04d6 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py @@ -98,8 +98,8 @@ class TestStreamStreamClientInterceptor(AioTestBase): self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) @@ -140,8 +140,8 @@ class TestStreamStreamClientInterceptor(AioTestBase): await call.done_writing() self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) @@ -183,8 +183,8 @@ class TestStreamStreamClientInterceptor(AioTestBase): await call.done_writing() self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py index 2d58cff0e41..b9a04af00dc 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py @@ -92,8 +92,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase): self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, response.aggregated_payload_size) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) @@ -131,8 +131,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase): self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, response.aggregated_payload_size) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) @@ -230,8 +230,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase): self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, response.aggregated_payload_size) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py index 6137538ffca..fd542fd16e9 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py @@ -96,8 +96,8 @@ class TestUnaryStreamClientInterceptor(AioTestBase): self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) self.assertEqual(await call.details(), '') self.assertEqual(await call.debug_error_string(), '') self.assertEqual(call.cancel(), False) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py index 8f5a356ca4a..e64daec7df4 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py @@ -25,7 +25,7 @@ from tests_aio.unit._test_base import AioTestBase from src.proto.grpc.testing import messages_pb2, test_pb2_grpc _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' -_INITIAL_METADATA_TO_INJECT = ( +_INITIAL_METADATA_TO_INJECT = aio.Metadata( (_INITIAL_METADATA_KEY, 'extra info'), (_TRAILING_METADATA_KEY, b'\x13\x37'), ) @@ -162,7 +162,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): async def test_retry(self): class RetryInterceptor(aio.UnaryUnaryClientInterceptor): - """Simulates a Retry Interceptor which ends up by making + """Simulates a Retry Interceptor which ends up by making two RPC calls.""" def __init__(self): @@ -302,8 +302,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(type(response), messages_pb2.SimpleResponse) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.details(), '') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_ok_awaited(self): @@ -331,8 +331,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(type(response), messages_pb2.SimpleResponse) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.details(), '') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_rpc_error(self): @@ -364,8 +364,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) self.assertEqual(await call.details(), 'Deadline Exceeded') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_rpc_error_awaited(self): @@ -398,8 +398,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) self.assertEqual(await call.details(), 'Deadline Exceeded') - self.assertEqual(await call.initial_metadata(), ()) - self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_cancel_before_rpc(self): @@ -541,8 +541,10 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) self.assertEqual(await call.details(), _LOCAL_CANCEL_DETAILS_EXPECTATION) - self.assertEqual(await call.initial_metadata(), tuple()) - self.assertEqual(await call.trailing_metadata(), None) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual( + await call.trailing_metadata(), aio.Metadata(), + "When the raw response is None, empty metadata is returned") async def test_initial_metadata_modification(self): @@ -550,11 +552,12 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): async def intercept_unary_unary(self, continuation, client_call_details, request): + new_metadata = aio.Metadata(*client_call_details.metadata, + *_INITIAL_METADATA_TO_INJECT) new_details = aio.ClientCallDetails( method=client_call_details.method, timeout=client_call_details.timeout, - metadata=client_call_details.metadata + - _INITIAL_METADATA_TO_INJECT, + metadata=new_metadata, credentials=client_call_details.credentials, wait_for_ready=client_call_details.wait_for_ready, ) @@ -568,14 +571,20 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): # Expected to see the echoed initial metadata self.assertTrue( - _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[0], await - call.initial_metadata())) - + _common.seen_metadatum( + expected_key=_INITIAL_METADATA_KEY, + expected_value=_INITIAL_METADATA_TO_INJECT[ + _INITIAL_METADATA_KEY], + actual=await call.initial_metadata(), + )) # Expected to see the echoed trailing metadata self.assertTrue( - _common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[1], await - call.trailing_metadata())) - + _common.seen_metadatum( + expected_key=_TRAILING_METADATA_KEY, + expected_value=_INITIAL_METADATA_TO_INJECT[ + _TRAILING_METADATA_KEY], + actual=await call.trailing_metadata(), + )) self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_add_done_callback_before_finishes(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py b/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py index 1e6e3598b8b..0bb3a3acc89 100644 --- a/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py @@ -255,7 +255,8 @@ class TestCompatibility(AioTestBase): self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary) call = self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST) self.assertTrue( - _common.seen_metadata(metadata, await call.initial_metadata())) + _common.seen_metadata(aio.Metadata(*metadata), await + call.initial_metadata())) async def test_sync_unary_unary_abort(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 6551e4ca084..c1fa97b3e4c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -37,17 +37,20 @@ _TEST_STREAM_STREAM = '/test/TestStreamStream' _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x01\x01\x01' -_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = ( +_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata( ('client-to-server', 'question'), ('client-to-server-bin', b'\x07\x07\x07'), ) -_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = ( +_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata( ('server-to-client', 'answer'), ('server-to-client-bin', b'\x06\x06\x06'), ) -_TRAILING_METADATA = (('a-trailing-metadata', 'stack-trace'), - ('a-trailing-metadata-bin', b'\x05\x05\x05')) -_INITIAL_METADATA_FOR_GENERIC_HANDLER = (('a-must-have-key', 'secret'),) +_TRAILING_METADATA = aio.Metadata( + ('a-trailing-metadata', 'stack-trace'), + ('a-trailing-metadata-bin', b'\x05\x05\x05'), +) +_INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata( + ('a-must-have-key', 'secret'),) _INVALID_METADATA_TEST_CASES = ( ( @@ -60,15 +63,15 @@ _INVALID_METADATA_TEST_CASES = ( ), ( TypeError, - (('normal', object()),), + ((None, {}),), ), ( TypeError, - object(), + (({}, {}),), ), ( TypeError, - (object(),), + (('normal', object()),), ), ) @@ -198,6 +201,7 @@ class TestMetadata(AioTestBase): async def test_from_server_to_client(self): multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT) call = multicallable(_REQUEST) + self.assertEqual(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await call.initial_metadata()) self.assertEqual(_RESPONSE, await call) @@ -213,7 +217,7 @@ class TestMetadata(AioTestBase): async def test_from_client_to_server_with_list(self): multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) call = multicallable( - _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)) + _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)) # pytype: disable=wrong-arg-types self.assertEqual(_RESPONSE, await call) self.assertEqual(grpc.StatusCode.OK, await call.code()) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index f85e46c379a..d891ecdb771 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -198,7 +198,7 @@ class TestServerInterceptor(AioTestBase): request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) - metadata = (('key', 'value'),) + metadata = aio.Metadata(('key', 'value'),) call = multicallable(messages_pb2.SimpleRequest(), metadata=metadata) await call @@ -208,7 +208,7 @@ class TestServerInterceptor(AioTestBase): ], record) record.clear() - metadata = (('key', 'value'), ('secret', '42')) + metadata = aio.Metadata(('key', 'value'), ('secret', '42')) call = multicallable(messages_pb2.SimpleRequest(), metadata=metadata) await call