Merge pull request #23045 from Skyscanner/issue-21953_use-metadata-type

[Aio] Use Metadata type
pull/23201/head
Lidi Zheng 5 years ago committed by GitHub
commit c47bc1a709
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 12
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  2. 8
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  3. 13
      src/python/grpcio/grpc/experimental/aio/_base_channel.py
  4. 13
      src/python/grpcio/grpc/experimental/aio/_base_server.py
  5. 55
      src/python/grpcio/grpc/experimental/aio/_call.py
  6. 37
      src/python/grpcio/grpc/experimental/aio/_channel.py
  7. 37
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  8. 10
      src/python/grpcio/grpc/experimental/aio/_metadata.py
  9. 9
      src/python/grpcio/grpc/experimental/aio/_typing.py
  10. 11
      src/python/grpcio_tests/tests_aio/interop/methods.py
  11. 16
      src/python/grpcio_tests/tests_aio/unit/_common.py
  12. 12
      src/python/grpcio_tests/tests_aio/unit/_metadata_test.py
  13. 9
      src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py
  14. 10
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  15. 12
      src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py
  16. 12
      src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py
  17. 4
      src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py
  18. 47
      src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py
  19. 3
      src/python/grpcio_tests/tests_aio/unit/compatibility_test.py
  20. 22
      src/python/grpcio_tests/tests_aio/unit/metadata_test.py
  21. 4
      src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py

@ -143,7 +143,7 @@ cdef class _ServicerContext:
self._loop)
self._rpc_state.metadata_sent = True
async def send_initial_metadata(self, tuple metadata):
async def send_initial_metadata(self, object metadata):
self._rpc_state.raise_for_termination()
if self._rpc_state.metadata_sent:
@ -151,7 +151,7 @@ cdef class _ServicerContext:
else:
await _send_initial_metadata(
self._rpc_state,
_augment_metadata(metadata, self._rpc_state.compression_algorithm),
_augment_metadata(tuple(metadata), self._rpc_state.compression_algorithm),
_EMPTY_FLAG,
self._loop
)
@ -192,8 +192,8 @@ cdef class _ServicerContext:
async def abort_with_status(self, object status):
await self.abort(status.code, status.details, status.trailing_metadata)
def set_trailing_metadata(self, tuple metadata):
self._rpc_state.trailing_metadata = metadata
def set_trailing_metadata(self, object metadata):
self._rpc_state.trailing_metadata = tuple(metadata)
def invocation_metadata(self):
return self._rpc_state.invocation_metadata()
@ -233,13 +233,13 @@ cdef class _SyncServicerContext:
# Abort should raise an AbortError
future.exception()
def send_initial_metadata(self, tuple metadata):
def send_initial_metadata(self, object metadata):
future = asyncio.run_coroutine_threadsafe(
self._context.send_initial_metadata(metadata),
self._loop)
future.result()
def set_trailing_metadata(self, tuple metadata):
def set_trailing_metadata(self, object metadata):
self._context.set_trailing_metadata(metadata)
def invocation_metadata(self):

@ -23,8 +23,8 @@ from typing import AsyncIterable, Awaitable, Generic, Optional, Union
import grpc
from ._typing import (DoneCallbackType, EOFType, MetadataType, RequestType,
ResponseType)
from ._typing import (DoneCallbackType, EOFType, RequestType, ResponseType)
from ._metadata import Metadata
__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -86,7 +86,7 @@ class Call(RpcContext, metaclass=ABCMeta):
"""The abstract base class of an RPC on the client-side."""
@abstractmethod
async def initial_metadata(self) -> MetadataType:
async def initial_metadata(self) -> Metadata:
"""Accesses the initial metadata sent by the server.
Returns:
@ -94,7 +94,7 @@ class Call(RpcContext, metaclass=ABCMeta):
"""
@abstractmethod
async def trailing_metadata(self) -> MetadataType:
async def trailing_metadata(self) -> Metadata:
"""Accesses the trailing metadata sent by the server.
Returns:

@ -19,10 +19,9 @@ from typing import Any, Optional
import grpc
from . import _base_call
from ._typing import (DeserializingFunction, MetadataType, RequestIterableType,
from ._typing import (DeserializingFunction, RequestIterableType,
SerializingFunction)
_IMMUTABLE_EMPTY_TUPLE = tuple()
from ._metadata import Metadata
class UnaryUnaryMultiCallable(abc.ABC):
@ -33,7 +32,7 @@ class UnaryUnaryMultiCallable(abc.ABC):
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
@ -71,7 +70,7 @@ class UnaryStreamMultiCallable(abc.ABC):
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
@ -108,7 +107,7 @@ class StreamUnaryMultiCallable(abc.ABC):
def __call__(self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
@ -146,7 +145,7 @@ class StreamStreamMultiCallable(abc.ABC):
def __call__(self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None

@ -18,7 +18,8 @@ from typing import Generic, Optional, Sequence
import grpc
from ._typing import MetadataType, RequestType, ResponseType
from ._typing import RequestType, ResponseType
from ._metadata import Metadata
class Server(abc.ABC):
@ -157,8 +158,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC):
"""
@abc.abstractmethod
async def send_initial_metadata(self,
initial_metadata: MetadataType) -> None:
async def send_initial_metadata(self, initial_metadata: Metadata) -> None:
"""Sends the initial metadata value to the client.
This method need not be called by implementations if they have no
@ -170,7 +170,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC):
@abc.abstractmethod
async def abort(self, code: grpc.StatusCode, details: str,
trailing_metadata: MetadataType) -> None:
trailing_metadata: Metadata) -> None:
"""Raises an exception to terminate the RPC with a non-OK status.
The code and details passed as arguments will supercede any existing
@ -190,8 +190,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC):
"""
@abc.abstractmethod
async def set_trailing_metadata(self,
trailing_metadata: MetadataType) -> None:
async def set_trailing_metadata(self, trailing_metadata: Metadata) -> None:
"""Sends the trailing metadata for the RPC.
This method need not be called by implementations if they have no
@ -202,7 +201,7 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC):
"""
@abc.abstractmethod
def invocation_metadata(self) -> Optional[MetadataType]:
def invocation_metadata(self) -> Optional[Metadata]:
"""Accesses the metadata from the sent by the client.
Returns:

@ -25,9 +25,10 @@ from grpc import _common
from grpc._cython import cygrpc
from . import _base_call
from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType,
MetadatumType, RequestIterableType, RequestType,
ResponseType, SerializingFunction)
from ._metadata import Metadata
from ._typing import (DeserializingFunction, DoneCallbackType, MetadatumType,
RequestIterableType, RequestType, ResponseType,
SerializingFunction)
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -58,22 +59,17 @@ class AioRpcError(grpc.RpcError):
determined. Hence, its methods no longer needs to be coroutines.
"""
# TODO(https://github.com/grpc/grpc/issues/20144) Metadata
# type returned by `initial_metadata` and `trailing_metadata`
# and also taken in the constructor needs to be revisit and make
# it more specific.
_code: grpc.StatusCode
_details: Optional[str]
_initial_metadata: Optional[MetadataType]
_trailing_metadata: Optional[MetadataType]
_initial_metadata: Optional[Metadata]
_trailing_metadata: Optional[Metadata]
_debug_error_string: Optional[str]
def __init__(self,
code: grpc.StatusCode,
initial_metadata: Metadata,
trailing_metadata: Metadata,
details: Optional[str] = None,
initial_metadata: Optional[MetadataType] = None,
trailing_metadata: Optional[MetadataType] = None,
debug_error_string: Optional[str] = None) -> None:
"""Constructor.
@ -108,7 +104,7 @@ class AioRpcError(grpc.RpcError):
"""
return self._details
def initial_metadata(self) -> Optional[MetadataType]:
def initial_metadata(self) -> Metadata:
"""Accesses the initial metadata sent by the server.
Returns:
@ -116,7 +112,7 @@ class AioRpcError(grpc.RpcError):
"""
return self._initial_metadata
def trailing_metadata(self) -> Optional[MetadataType]:
def trailing_metadata(self) -> Metadata:
"""Accesses the trailing metadata sent by the server.
Returns:
@ -145,14 +141,14 @@ class AioRpcError(grpc.RpcError):
return self._repr()
def _create_rpc_error(initial_metadata: Optional[MetadataType],
def _create_rpc_error(initial_metadata: Metadata,
status: cygrpc.AioRpcStatus) -> AioRpcError:
return AioRpcError(
_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()],
status.details(),
initial_metadata,
status.trailing_metadata(),
status.debug_error_string(),
Metadata.from_tuple(initial_metadata),
Metadata.from_tuple(status.trailing_metadata()),
details=status.details(),
debug_error_string=status.debug_error_string(),
)
@ -168,7 +164,7 @@ class Call:
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
def __init__(self, cython_call: cygrpc._AioCall, metadata: MetadataType,
def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
@ -208,11 +204,14 @@ class Call:
def time_remaining(self) -> Optional[float]:
return self._cython_call.time_remaining()
async def initial_metadata(self) -> MetadataType:
return await self._cython_call.initial_metadata()
async def initial_metadata(self) -> Metadata:
raw_metadata_tuple = await self._cython_call.initial_metadata()
return Metadata.from_tuple(raw_metadata_tuple)
async def trailing_metadata(self) -> MetadataType:
return (await self._cython_call.status()).trailing_metadata()
async def trailing_metadata(self) -> Metadata:
raw_metadata_tuple = (await
self._cython_call.status()).trailing_metadata()
return Metadata.from_tuple(raw_metadata_tuple)
async def code(self) -> grpc.StatusCode:
cygrpc_code = (await self._cython_call.status()).code()
@ -475,7 +474,7 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
@ -524,7 +523,7 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
@ -564,7 +563,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
# pylint: disable=too-many-arguments
def __init__(self, request_iterator: Optional[RequestIterableType],
deadline: Optional[float], metadata: MetadataType,
deadline: Optional[float], metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
@ -602,7 +601,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
# pylint: disable=too-many-arguments
def __init__(self, request_iterator: Optional[RequestIterableType],
deadline: Optional[float], metadata: MetadataType,
deadline: Optional[float], metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,

@ -29,11 +29,11 @@ from ._interceptor import (
InterceptedStreamUnaryCall, InterceptedStreamStreamCall, ClientInterceptor,
UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor, StreamStreamClientInterceptor)
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
from ._metadata import Metadata
from ._typing import (ChannelArgumentType, DeserializingFunction,
SerializingFunction, RequestIterableType)
from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple()
_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)
if sys.version_info[1] < 7:
@ -88,6 +88,19 @@ class _BaseMultiCallable:
self._response_deserializer = response_deserializer
self._interceptors = interceptors
@staticmethod
def _init_metadata(metadata: Optional[Metadata] = None,
compression: Optional[grpc.Compression] = None
) -> Metadata:
"""Based on the provided values for <metadata> or <compression> initialise the final
metadata, as it should be used for the current call.
"""
metadata = metadata or Metadata()
if compression:
metadata = Metadata(
*_compression.augment_metadata(metadata, compression))
return metadata
class UnaryUnaryMultiCallable(_BaseMultiCallable,
_base_channel.UnaryUnaryMultiCallable):
@ -96,14 +109,13 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
if compression:
metadata = _compression.augment_metadata(metadata, compression)
metadata = self._init_metadata(metadata, compression)
if not self._interceptors:
call = UnaryUnaryCall(request, _timeout_to_deadline(timeout),
metadata, credentials, wait_for_ready,
@ -127,14 +139,13 @@ class UnaryStreamMultiCallable(_BaseMultiCallable,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
if compression:
metadata = _compression.augment_metadata(metadata, compression)
metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if not self._interceptors:
@ -158,14 +169,13 @@ class StreamUnaryMultiCallable(_BaseMultiCallable,
def __call__(self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamUnaryCall:
if compression:
metadata = _compression.augment_metadata(metadata, compression)
metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if not self._interceptors:
@ -189,14 +199,13 @@ class StreamStreamMultiCallable(_BaseMultiCallable,
def __call__(self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamStreamCall:
if compression:
metadata = _compression.augment_metadata(metadata, compression)
metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if not self._interceptors:

@ -27,8 +27,9 @@ from ._call import _RPC_ALREADY_FINISHED_DETAILS, _RPC_HALF_CLOSED_DETAILS
from ._call import _API_STYLE_ERROR
from ._utils import _timeout_to_deadline
from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
MetadataType, ResponseType, DoneCallbackType,
RequestIterableType, ResponseIterableType)
ResponseType, DoneCallbackType, RequestIterableType,
ResponseIterableType)
from ._metadata import Metadata
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
@ -82,7 +83,7 @@ class ClientCallDetails(
method: str
timeout: Optional[float]
metadata: Optional[MetadataType]
metadata: Optional[Metadata]
credentials: Optional[grpc.CallCredentials]
wait_for_ready: Optional[bool]
@ -248,7 +249,7 @@ class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
class InterceptedCall:
"""Base implementation for all intecepted call arities.
"""Base implementation for all intercepted call arities.
Interceptors might have some work to do before the RPC invocation with
the capacity of changing the invocation parameters, and some work to do
@ -370,7 +371,7 @@ class InterceptedCall:
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[MetadataType]:
async def initial_metadata(self) -> Optional[Metadata]:
try:
call = await self._interceptors_task
except AioRpcError as err:
@ -380,7 +381,7 @@ class InterceptedCall:
return await call.initial_metadata()
async def trailing_metadata(self) -> Optional[MetadataType]:
async def trailing_metadata(self) -> Optional[Metadata]:
try:
call = await self._interceptors_task
except AioRpcError as err:
@ -556,7 +557,7 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall,
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
request: RequestType, timeout: Optional[float],
metadata: MetadataType,
metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
@ -573,7 +574,7 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall,
# pylint: disable=too-many-arguments
async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
metadata: Optional[Metadata],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], request: RequestType,
request_serializer: SerializingFunction,
@ -628,7 +629,7 @@ class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin,
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor],
request: RequestType, timeout: Optional[float],
metadata: MetadataType,
metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
@ -647,7 +648,7 @@ class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin,
# pylint: disable=too-many-arguments
async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
metadata: Optional[Metadata],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], request: RequestType,
request_serializer: SerializingFunction,
@ -712,7 +713,7 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin,
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[StreamUnaryClientInterceptor],
request_iterator: Optional[RequestIterableType],
timeout: Optional[float], metadata: MetadataType,
timeout: Optional[float], metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
@ -731,7 +732,7 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin,
async def _invoke(
self, interceptors: Sequence[StreamUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
metadata: Optional[Metadata],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
request_iterator: RequestIterableType,
@ -783,7 +784,7 @@ class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin,
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[StreamStreamClientInterceptor],
request_iterator: Optional[RequestIterableType],
timeout: Optional[float], metadata: MetadataType,
timeout: Optional[float], metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
@ -804,7 +805,7 @@ class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin,
async def _invoke(
self, interceptors: Sequence[StreamStreamClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
metadata: Optional[Metadata],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
request_iterator: RequestIterableType,
@ -876,10 +877,10 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[MetadataType]:
async def initial_metadata(self) -> Optional[Metadata]:
return None
async def trailing_metadata(self) -> Optional[MetadataType]:
async def trailing_metadata(self) -> Optional[Metadata]:
return None
async def code(self) -> grpc.StatusCode:
@ -928,10 +929,10 @@ class _StreamCallResponseIterator:
def time_remaining(self) -> Optional[float]:
return self._call.time_remaining()
async def initial_metadata(self) -> Optional[MetadataType]:
async def initial_metadata(self) -> Optional[Metadata]:
return await self._call.initial_metadata()
async def trailing_metadata(self) -> Optional[MetadataType]:
async def trailing_metadata(self) -> Optional[Metadata]:
return await self._call.trailing_metadata()
async def code(self) -> grpc.StatusCode:

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the metadata abstraction for gRPC Asyncio Python."""
from typing import List, Tuple, Iterator, Any, Text, Union
from typing import List, Tuple, Iterator, Any, Union
from collections import abc, OrderedDict
MetadataKey = Text
MetadataKey = str
MetadataValue = Union[str, bytes]
@ -37,6 +37,12 @@ class Metadata(abc.Mapping):
for md_key, md_value in args:
self.add(md_key, md_value)
@classmethod
def from_tuple(cls, raw_metadata: tuple):
if raw_metadata:
return cls(*raw_metadata)
return cls()
def add(self, key: MetadataKey, value: MetadataValue) -> None:
self._metadata.setdefault(key, [])
self._metadata[key].append(value)

@ -13,17 +13,18 @@
# limitations under the License.
"""Common types for gRPC Async API"""
from typing import (Any, AnyStr, AsyncIterable, Callable, Iterable, Sequence,
Tuple, TypeVar, Union)
from typing import (Any, AsyncIterable, Callable, Iterable, Sequence, Tuple,
TypeVar, Union)
from grpc._cython.cygrpc import EOF
from ._metadata import Metadata, MetadataKey, MetadataValue
RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType')
SerializingFunction = Callable[[Any], bytes]
DeserializingFunction = Callable[[bytes], Any]
MetadatumType = Tuple[str, AnyStr]
MetadataType = Sequence[MetadatumType]
MetadatumType = Tuple[MetadataKey, MetadataValue]
MetadataType = Metadata
ChannelArgumentType = Sequence[Tuple[str, Any]]
EOFType = type(EOF)
DoneCallbackType = Callable[[Any], None]

@ -287,16 +287,19 @@ async def _unimplemented_service(stub: test_pb2_grpc.UnimplementedServiceStub):
async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub):
initial_metadata_value = "test_initial_metadata_value"
trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
(_TRAILING_METADATA_KEY, trailing_metadata_value))
metadata = aio.Metadata(
(_INITIAL_METADATA_KEY, initial_metadata_value),
(_TRAILING_METADATA_KEY, trailing_metadata_value),
)
async def _validate_metadata(call):
initial_metadata = dict(await call.initial_metadata())
initial_metadata = await call.initial_metadata()
if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
raise ValueError('expected initial metadata %s, got %s' %
(initial_metadata_value,
initial_metadata[_INITIAL_METADATA_KEY]))
trailing_metadata = dict(await call.trailing_metadata())
trailing_metadata = await call.trailing_metadata()
if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
raise ValueError('expected trailing metadata %s, got %s' %
(trailing_metadata_value,

@ -16,18 +16,20 @@ import asyncio
import grpc
from typing import AsyncIterable
from grpc.experimental import aio
from grpc.experimental.aio._typing import MetadataType, MetadatumType
from grpc.experimental.aio._typing import MetadatumType, MetadataKey, MetadataValue
from grpc.experimental.aio._metadata import Metadata
from tests.unit.framework.common import test_constants
def seen_metadata(expected: MetadataType, actual: MetadataType):
return not bool(set(expected) - set(actual))
def seen_metadata(expected: Metadata, actual: Metadata):
return not bool(set(tuple(expected)) - set(tuple(actual)))
def seen_metadatum(expected: MetadatumType, actual: MetadataType):
metadata_dict = dict(actual)
return metadata_dict.get(expected[0]) == expected[1]
def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue,
actual: Metadata) -> bool:
obtained = actual[expected_key]
return obtained == expected_value
async def block_until_certain_state(channel: aio.Channel,
@ -50,7 +52,7 @@ def inject_callbacks(call: aio.Call):
second_callback_ran = asyncio.Event()
def second_callback(call):
# Validate that all resopnses have been received
# Validate that all responses have been received
# and the call is an end state.
assert call.done()
second_callback_ran.set()

@ -119,6 +119,18 @@ class TestTypeMetadata(unittest.TestCase):
with self.assertRaises(KeyError):
del metadata["other key"]
def test_metadata_from_tuple(self):
scenarios = (
(None, Metadata()),
(Metadata(), Metadata()),
(self._DEFAULT_DATA, Metadata(*self._DEFAULT_DATA)),
(self._MULTI_ENTRY_DATA, Metadata(*self._MULTI_ENTRY_DATA)),
(Metadata(*self._DEFAULT_DATA), Metadata(*self._DEFAULT_DATA)),
)
for source, expected in scenarios:
with self.subTest(raw_metadata=source, expected=expected):
self.assertEqual(expected, Metadata.from_tuple(source))
if __name__ == '__main__':
logging.basicConfig()

@ -18,11 +18,14 @@ import unittest
import grpc
from grpc.experimental import aio
from grpc.experimental.aio._call import AioRpcError
from tests_aio.unit._test_base import AioTestBase
_TEST_INITIAL_METADATA = ('initial metadata',)
_TEST_TRAILING_METADATA = ('trailing metadata',)
_TEST_INITIAL_METADATA = aio.Metadata(
('initial metadata key', 'initial metadata value'))
_TEST_TRAILING_METADATA = aio.Metadata(
('trailing metadata key', 'trailing metadata value'))
_TEST_DEBUG_ERROR_STRING = '{This is a debug string}'
@ -30,9 +33,9 @@ class TestAioRpcError(unittest.TestCase):
def test_attributes(self):
aio_rpc_error = AioRpcError(grpc.StatusCode.CANCELLED,
'details',
initial_metadata=_TEST_INITIAL_METADATA,
trailing_metadata=_TEST_TRAILING_METADATA,
details="details",
debug_error_string=_TEST_DEBUG_ERROR_STRING)
self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(aio_rpc_error.details(), 'details')

@ -102,11 +102,11 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
async def test_call_initial_metadata_awaitable(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual((), await call.initial_metadata())
self.assertEqual(aio.Metadata(), await call.initial_metadata())
async def test_call_trailing_metadata_awaitable(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual((), await call.trailing_metadata())
self.assertEqual(aio.Metadata(), await call.trailing_metadata())
async def test_call_initial_metadata_cancelable(self):
coro_started = asyncio.Event()
@ -122,7 +122,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
# Test that initial metadata can still be asked thought
# a cancellation happened with the previous task
self.assertEqual((), await call.initial_metadata())
self.assertEqual(aio.Metadata(), await call.initial_metadata())
async def test_call_initial_metadata_multiple_waiters(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
@ -134,8 +134,8 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
task2 = self.loop.create_task(coro())
await call
self.assertEqual([(), ()], await asyncio.gather(*[task1, task2]))
expected = [aio.Metadata() for _ in range(2)]
self.assertEqual(expected, await asyncio.gather(*[task1, task2]))
async def test_call_code_cancelable(self):
coro_started = asyncio.Event()

@ -98,8 +98,8 @@ class TestStreamStreamClientInterceptor(AioTestBase):
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
@ -140,8 +140,8 @@ class TestStreamStreamClientInterceptor(AioTestBase):
await call.done_writing()
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
@ -183,8 +183,8 @@ class TestStreamStreamClientInterceptor(AioTestBase):
await call.done_writing()
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)

@ -92,8 +92,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase):
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
@ -131,8 +131,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase):
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
@ -230,8 +230,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase):
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)

@ -96,8 +96,8 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)

@ -25,7 +25,7 @@ from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_INITIAL_METADATA_TO_INJECT = (
_INITIAL_METADATA_TO_INJECT = aio.Metadata(
(_INITIAL_METADATA_KEY, 'extra info'),
(_TRAILING_METADATA_KEY, b'\x13\x37'),
)
@ -302,8 +302,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(type(response), messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_call_ok_awaited(self):
@ -331,8 +331,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(type(response), messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_call_rpc_error(self):
@ -364,8 +364,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_call_rpc_error_awaited(self):
@ -398,8 +398,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_cancel_before_rpc(self):
@ -541,8 +541,10 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), tuple())
self.assertEqual(await call.trailing_metadata(), None)
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(
await call.trailing_metadata(), aio.Metadata(),
"When the raw response is None, empty metadata is returned")
async def test_initial_metadata_modification(self):
@ -550,11 +552,12 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
new_metadata = aio.Metadata(*client_call_details.metadata,
*_INITIAL_METADATA_TO_INJECT)
new_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=client_call_details.timeout,
metadata=client_call_details.metadata +
_INITIAL_METADATA_TO_INJECT,
metadata=new_metadata,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready,
)
@ -568,14 +571,20 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
# Expected to see the echoed initial metadata
self.assertTrue(
_common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[0], await
call.initial_metadata()))
_common.seen_metadatum(
expected_key=_INITIAL_METADATA_KEY,
expected_value=_INITIAL_METADATA_TO_INJECT[
_INITIAL_METADATA_KEY],
actual=await call.initial_metadata(),
))
# Expected to see the echoed trailing metadata
self.assertTrue(
_common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[1], await
call.trailing_metadata()))
_common.seen_metadatum(
expected_key=_TRAILING_METADATA_KEY,
expected_value=_INITIAL_METADATA_TO_INJECT[
_TRAILING_METADATA_KEY],
actual=await call.trailing_metadata(),
))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_add_done_callback_before_finishes(self):

@ -255,7 +255,8 @@ class TestCompatibility(AioTestBase):
self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary)
call = self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
self.assertTrue(
_common.seen_metadata(metadata, await call.initial_metadata()))
_common.seen_metadata(aio.Metadata(*metadata), await
call.initial_metadata()))
async def test_sync_unary_unary_abort(self):

@ -37,17 +37,20 @@ _TEST_STREAM_STREAM = '/test/TestStreamStream'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = (
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata(
('client-to-server', 'question'),
('client-to-server-bin', b'\x07\x07\x07'),
)
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = (
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata(
('server-to-client', 'answer'),
('server-to-client-bin', b'\x06\x06\x06'),
)
_TRAILING_METADATA = (('a-trailing-metadata', 'stack-trace'),
('a-trailing-metadata-bin', b'\x05\x05\x05'))
_INITIAL_METADATA_FOR_GENERIC_HANDLER = (('a-must-have-key', 'secret'),)
_TRAILING_METADATA = aio.Metadata(
('a-trailing-metadata', 'stack-trace'),
('a-trailing-metadata-bin', b'\x05\x05\x05'),
)
_INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata(
('a-must-have-key', 'secret'),)
_INVALID_METADATA_TEST_CASES = (
(
@ -60,15 +63,15 @@ _INVALID_METADATA_TEST_CASES = (
),
(
TypeError,
(('normal', object()),),
((None, {}),),
),
(
TypeError,
object(),
(({}, {}),),
),
(
TypeError,
(object(),),
(('normal', object()),),
),
)
@ -198,6 +201,7 @@ class TestMetadata(AioTestBase):
async def test_from_server_to_client(self):
multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT)
call = multicallable(_REQUEST)
self.assertEqual(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
call.initial_metadata())
self.assertEqual(_RESPONSE, await call)
@ -213,7 +217,7 @@ class TestMetadata(AioTestBase):
async def test_from_client_to_server_with_list(self):
multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
call = multicallable(
_REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER))
_REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)) # pytype: disable=wrong-arg-types
self.assertEqual(_RESPONSE, await call)
self.assertEqual(grpc.StatusCode.OK, await call.code())

@ -198,7 +198,7 @@ class TestServerInterceptor(AioTestBase):
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
metadata = (('key', 'value'),)
metadata = aio.Metadata(('key', 'value'),)
call = multicallable(messages_pb2.SimpleRequest(),
metadata=metadata)
await call
@ -208,7 +208,7 @@ class TestServerInterceptor(AioTestBase):
], record)
record.clear()
metadata = (('key', 'value'), ('secret', '42'))
metadata = aio.Metadata(('key', 'value'), ('secret', '42'))
call = multicallable(messages_pb2.SimpleRequest(),
metadata=metadata)
await call

Loading…
Cancel
Save