[issue-21953] Use the Metadata type

In all places where a tuple was used for metadata (in the aio version),
replace it by the new ``Metadata`` class.
pull/23045/head
Mariano Anaya 5 years ago
parent 48e7c3d275
commit e04fcd2998
  1. 2
      src/python/grpcio/grpc/_compression.py
  2. 10
      src/python/grpcio/grpc/experimental/aio/_call.py
  3. 35
      src/python/grpcio/grpc/experimental/aio/_channel.py
  4. 2
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  5. 9
      src/python/grpcio/grpc/experimental/aio/_typing.py
  6. 6
      src/python/grpcio_tests/tests_aio/interop/methods.py
  7. 13
      src/python/grpcio_tests/tests_aio/unit/_common.py
  8. 7
      src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py
  9. 25
      src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py
  10. 27
      src/python/grpcio_tests/tests_aio/unit/metadata_test.py

@ -39,7 +39,7 @@ def create_channel_option(compression):
int(compression)),) if compression else ()
def augment_metadata(metadata, compression):
def augment_metadata(metadata, compression) -> tuple:
if not metadata and not compression:
return None
base_metadata = tuple(metadata) if metadata else ()

@ -28,6 +28,7 @@ from . import _base_call
from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType,
MetadatumType, RequestIterableType, RequestType,
ResponseType, SerializingFunction)
from ._metadata import Metadata
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -58,11 +59,6 @@ class AioRpcError(grpc.RpcError):
determined. Hence, its methods no longer needs to be coroutines.
"""
# TODO(https://github.com/grpc/grpc/issues/20144) Metadata
# type returned by `initial_metadata` and `trailing_metadata`
# and also taken in the constructor needs to be revisit and make
# it more specific.
_code: grpc.StatusCode
_details: Optional[str]
_initial_metadata: Optional[MetadataType]
@ -88,8 +84,8 @@ class AioRpcError(grpc.RpcError):
super().__init__(self)
self._code = code
self._details = details
self._initial_metadata = initial_metadata
self._trailing_metadata = trailing_metadata
self._initial_metadata = initial_metadata or Metadata()
self._trailing_metadata = trailing_metadata or Metadata()
self._debug_error_string = debug_error_string
def code(self) -> grpc.StatusCode:

@ -32,8 +32,8 @@ from ._interceptor import (
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction, RequestIterableType)
from ._utils import _timeout_to_deadline
from ._metadata import Metadata
_IMMUTABLE_EMPTY_TUPLE = tuple()
_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)
if sys.version_info[1] < 7:
@ -88,6 +88,19 @@ class _BaseMultiCallable:
self._response_deserializer = response_deserializer
self._interceptors = interceptors
@staticmethod
def _init_metadata(metadata: Optional[Metadata] = None,
compression: Optional[grpc.Compression] = None
) -> Metadata:
"""Based on the provided values for <metadata> or <compression> initialise the final
metadata, as it should be used for the current call.
"""
metadata = metadata or Metadata()
if compression:
metadata = Metadata(
*_compression.augment_metadata(metadata, compression))
return metadata
class UnaryUnaryMultiCallable(_BaseMultiCallable,
_base_channel.UnaryUnaryMultiCallable):
@ -96,14 +109,13 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
if compression:
metadata = _compression.augment_metadata(metadata, compression)
metadata = self._init_metadata(metadata, compression)
if not self._interceptors:
call = UnaryUnaryCall(request, _timeout_to_deadline(timeout),
metadata, credentials, wait_for_ready,
@ -127,14 +139,13 @@ class UnaryStreamMultiCallable(_BaseMultiCallable,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
if compression:
metadata = _compression.augment_metadata(metadata, compression)
metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if not self._interceptors:
@ -158,14 +169,13 @@ class StreamUnaryMultiCallable(_BaseMultiCallable,
def __call__(self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamUnaryCall:
if compression:
metadata = _compression.augment_metadata(metadata, compression)
metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if not self._interceptors:
@ -189,14 +199,13 @@ class StreamStreamMultiCallable(_BaseMultiCallable,
def __call__(self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamStreamCall:
if compression:
metadata = _compression.augment_metadata(metadata, compression)
metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if not self._interceptors:

@ -248,7 +248,7 @@ class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
class InterceptedCall:
"""Base implementation for all intecepted call arities.
"""Base implementation for all intercepted call arities.
Interceptors might have some work to do before the RPC invocation with
the capacity of changing the invocation parameters, and some work to do

@ -13,17 +13,18 @@
# limitations under the License.
"""Common types for gRPC Async API"""
from typing import (Any, AnyStr, AsyncIterable, Callable, Iterable, Sequence,
Tuple, TypeVar, Union)
from typing import (Any, AsyncIterable, Callable, Iterable, Sequence, Tuple,
TypeVar, Union)
from grpc._cython.cygrpc import EOF
from ._metadata import Metadata, MetadataKey, MetadataValue
RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType')
SerializingFunction = Callable[[Any], bytes]
DeserializingFunction = Callable[[bytes], Any]
MetadatumType = Tuple[str, AnyStr]
MetadataType = Sequence[MetadatumType]
MetadatumType = Tuple[MetadataKey, MetadataValue]
MetadataType = Metadata
ChannelArgumentType = Sequence[Tuple[str, Any]]
EOFType = type(EOF)
DoneCallbackType = Callable[[Any], None]

@ -287,8 +287,10 @@ async def _unimplemented_service(stub: test_pb2_grpc.UnimplementedServiceStub):
async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub):
initial_metadata_value = "test_initial_metadata_value"
trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
(_TRAILING_METADATA_KEY, trailing_metadata_value))
metadata = aio.Metadata(
(_INITIAL_METADATA_KEY, initial_metadata_value),
(_TRAILING_METADATA_KEY, trailing_metadata_value),
)
async def _validate_metadata(call):
initial_metadata = dict(await call.initial_metadata())

@ -16,18 +16,19 @@ import asyncio
import grpc
from typing import AsyncIterable
from grpc.experimental import aio
from grpc.experimental.aio._typing import MetadataType, MetadatumType
from grpc.experimental.aio._typing import MetadataType, MetadatumType, MetadataKey, MetadataValue
from tests.unit.framework.common import test_constants
def seen_metadata(expected: MetadataType, actual: MetadataType):
return not bool(set(expected) - set(actual))
return not bool(set(tuple(expected)) - set(tuple(actual)))
def seen_metadatum(expected: MetadatumType, actual: MetadataType):
metadata_dict = dict(actual)
return metadata_dict.get(expected[0]) == expected[1]
def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue,
actual: MetadataType) -> bool:
obtained = actual[expected_key]
assert obtained == expected_value
async def block_until_certain_state(channel: aio.Channel,
@ -50,7 +51,7 @@ def inject_callbacks(call: aio.Call):
second_callback_ran = asyncio.Event()
def second_callback(call):
# Validate that all resopnses have been received
# Validate that all responses have been received
# and the call is an end state.
assert call.done()
second_callback_ran.set()

@ -18,11 +18,14 @@ import unittest
import grpc
from grpc.experimental import aio
from grpc.experimental.aio._call import AioRpcError
from tests_aio.unit._test_base import AioTestBase
_TEST_INITIAL_METADATA = ('initial metadata',)
_TEST_TRAILING_METADATA = ('trailing metadata',)
_TEST_INITIAL_METADATA = aio.Metadata(
('initial metadata key', 'initial metadata value'))
_TEST_TRAILING_METADATA = aio.Metadata(
('trailing metadata key', 'trailing metadata value'))
_TEST_DEBUG_ERROR_STRING = '{This is a debug string}'

@ -25,7 +25,7 @@ from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_INITIAL_METADATA_TO_INJECT = (
_INITIAL_METADATA_TO_INJECT = aio.Metadata(
(_INITIAL_METADATA_KEY, 'extra info'),
(_TRAILING_METADATA_KEY, b'\x13\x37'),
)
@ -550,11 +550,12 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
new_metadata = aio.Metadata(*client_call_details.metadata,
*_INITIAL_METADATA_TO_INJECT)
new_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=client_call_details.timeout,
metadata=client_call_details.metadata +
_INITIAL_METADATA_TO_INJECT,
metadata=new_metadata,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready,
)
@ -568,14 +569,20 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
# Expected to see the echoed initial metadata
self.assertTrue(
_common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[0], await
call.initial_metadata()))
_common.seen_metadatum(
expected_key=_INITIAL_METADATA_KEY,
expected_value=_INITIAL_METADATA_TO_INJECT[
_INITIAL_METADATA_KEY],
actual=await call.initial_metadata(),
))
# Expected to see the echoed trailing metadata
self.assertTrue(
_common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[1], await
call.trailing_metadata()))
_common.seen_metadatum(
expected_key=_TRAILING_METADATA_KEY,
expected_value=_INITIAL_METADATA_TO_INJECT[
_TRAILING_METADATA_KEY],
actual=await call.trailing_metadata(),
))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_add_done_callback_before_finishes(self):

@ -37,38 +37,33 @@ _TEST_STREAM_STREAM = '/test/TestStreamStream'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = (
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata(
('client-to-server', 'question'),
('client-to-server-bin', b'\x07\x07\x07'),
)
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = (
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata(
('server-to-client', 'answer'),
('server-to-client-bin', b'\x06\x06\x06'),
)
_TRAILING_METADATA = (('a-trailing-metadata', 'stack-trace'),
('a-trailing-metadata-bin', b'\x05\x05\x05'))
_INITIAL_METADATA_FOR_GENERIC_HANDLER = (('a-must-have-key', 'secret'),)
_TRAILING_METADATA = aio.Metadata(
('a-trailing-metadata', 'stack-trace'),
('a-trailing-metadata-bin', b'\x05\x05\x05'),
)
_INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata(
('a-must-have-key', 'secret'),)
_INVALID_METADATA_TEST_CASES = (
(
TypeError,
((42, 42),),
),
(
TypeError,
(({}, {}),),
),
(
TypeError,
(('normal', object()),),
aio.Metadata((42, 42),),
),
(
TypeError,
object(),
aio.Metadata(({}, {}),),
),
(
TypeError,
(object(),),
aio.Metadata(('normal', object()),),
),
)

Loading…
Cancel
Save