[issue-21953] Improvements from review

* Replace ``MetadataType`` by ``Metadata`` in all places
* Fix annotations
* Use the new ``Metadata.from_tuple`` to create Metadata objects
pull/23045/head
Mariano Anaya 5 years ago
parent e9dadf46bf
commit 8fcc77a310
  1. 2
      src/python/grpcio/grpc/_compression.py
  2. 42
      src/python/grpcio/grpc/experimental/aio/_call.py
  3. 10
      src/python/grpcio/grpc/experimental/aio/_metadata.py
  4. 12
      src/python/grpcio_tests/tests_aio/unit/_metadata_test.py
  5. 8
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  6. 4
      src/python/grpcio_tests/tests_aio/unit/metadata_test.py

@ -39,7 +39,7 @@ def create_channel_option(compression):
int(compression)),) if compression else ()
def augment_metadata(metadata, compression) -> tuple:
def augment_metadata(metadata, compression):
if not metadata and not compression:
return None
base_metadata = tuple(metadata) if metadata else ()

@ -26,7 +26,7 @@ from grpc._cython import cygrpc
from . import _base_call
from ._metadata import Metadata
from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType,
from ._typing import (DeserializingFunction, DoneCallbackType,
MetadatumType, RequestIterableType, RequestType,
ResponseType, SerializingFunction)
@ -61,15 +61,15 @@ class AioRpcError(grpc.RpcError):
_code: grpc.StatusCode
_details: Optional[str]
_initial_metadata: Optional[MetadataType]
_trailing_metadata: Optional[MetadataType]
_initial_metadata: Optional[Metadata]
_trailing_metadata: Optional[Metadata]
_debug_error_string: Optional[str]
def __init__(self,
code: grpc.StatusCode,
details: Optional[str] = None,
initial_metadata: Optional[MetadataType] = None,
trailing_metadata: Optional[MetadataType] = None,
initial_metadata: Optional[Metadata] = None,
trailing_metadata: Optional[Metadata] = None,
debug_error_string: Optional[str] = None) -> None:
"""Constructor.
@ -84,8 +84,8 @@ class AioRpcError(grpc.RpcError):
super().__init__(self)
self._code = code
self._details = details
self._initial_metadata = Metadata(*(initial_metadata or ()))
self._trailing_metadata = Metadata(*(trailing_metadata or ()))
self._initial_metadata = initial_metadata
self._trailing_metadata = trailing_metadata
self._debug_error_string = debug_error_string
def code(self) -> grpc.StatusCode:
@ -104,7 +104,7 @@ class AioRpcError(grpc.RpcError):
"""
return self._details
def initial_metadata(self) -> Optional[MetadataType]:
def initial_metadata(self) -> Metadata:
"""Accesses the initial metadata sent by the server.
Returns:
@ -112,7 +112,7 @@ class AioRpcError(grpc.RpcError):
"""
return self._initial_metadata
def trailing_metadata(self) -> Optional[MetadataType]:
def trailing_metadata(self) -> Metadata:
"""Accesses the trailing metadata sent by the server.
Returns:
@ -141,13 +141,13 @@ class AioRpcError(grpc.RpcError):
return self._repr()
def _create_rpc_error(initial_metadata: Optional[MetadataType],
def _create_rpc_error(initial_metadata: Metadata,
status: cygrpc.AioRpcStatus) -> AioRpcError:
return AioRpcError(
_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()],
status.details(),
initial_metadata,
status.trailing_metadata(),
Metadata.from_tuple(initial_metadata),
Metadata.from_tuple(status.trailing_metadata()),
status.debug_error_string(),
)
@ -164,7 +164,7 @@ class Call:
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
def __init__(self, cython_call: cygrpc._AioCall, metadata: MetadataType,
def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
@ -204,14 +204,14 @@ class Call:
def time_remaining(self) -> Optional[float]:
return self._cython_call.time_remaining()
async def initial_metadata(self) -> MetadataType:
async def initial_metadata(self) -> Metadata:
raw_metadata_tuple = await self._cython_call.initial_metadata()
return Metadata(*(raw_metadata_tuple or ()))
return Metadata.from_tuple(raw_metadata_tuple)
async def trailing_metadata(self) -> MetadataType:
async def trailing_metadata(self) -> Metadata:
raw_metadata_tuple = (await
self._cython_call.status()).trailing_metadata()
return Metadata(*(raw_metadata_tuple or ()))
return Metadata.from_tuple(raw_metadata_tuple)
async def code(self) -> grpc.StatusCode:
cygrpc_code = (await self._cython_call.status()).code()
@ -474,7 +474,7 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
@ -523,7 +523,7 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
@ -563,7 +563,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
# pylint: disable=too-many-arguments
def __init__(self, request_iterator: Optional[RequestIterableType],
deadline: Optional[float], metadata: MetadataType,
deadline: Optional[float], metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
@ -601,7 +601,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
# pylint: disable=too-many-arguments
def __init__(self, request_iterator: Optional[RequestIterableType],
deadline: Optional[float], metadata: MetadataType,
deadline: Optional[float], metadata: Metadata,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the metadata abstraction for gRPC Asyncio Python."""
from typing import List, Tuple, Iterator, Any, Text, Union
from typing import List, Tuple, Iterator, Any, Union
from collections import abc, OrderedDict
MetadataKey = Text
MetadataKey = str
MetadataValue = Union[str, bytes]
@ -37,6 +37,12 @@ class Metadata(abc.Mapping):
for md_key, md_value in args:
self.add(md_key, md_value)
@classmethod
def from_tuple(cls, raw_metadata: tuple):
if raw_metadata:
return cls(*raw_metadata)
return cls()
def add(self, key: MetadataKey, value: MetadataValue) -> None:
self._metadata.setdefault(key, [])
self._metadata[key].append(value)

@ -119,6 +119,18 @@ class TestTypeMetadata(unittest.TestCase):
with self.assertRaises(KeyError):
del metadata["other key"]
def test_metadata_from_tuple(self):
scenarios = (
(None, Metadata()),
(Metadata(), Metadata()),
(self._DEFAULT_DATA, Metadata(*self._DEFAULT_DATA)),
(self._MULTI_ENTRY_DATA, Metadata(*self._MULTI_ENTRY_DATA)),
(Metadata(*self._DEFAULT_DATA), Metadata(*self._DEFAULT_DATA)),
)
for source, expected in scenarios:
with self.subTest(raw_metadata=source, expected=expected):
self.assertEqual(expected, Metadata.from_tuple(source))
if __name__ == '__main__':
logging.basicConfig()

@ -102,11 +102,11 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
async def test_call_initial_metadata_awaitable(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(aio.Metadata(), await call.initial_metadata())
async def test_call_trailing_metadata_awaitable(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(aio.Metadata(), await call.trailing_metadata())
async def test_call_initial_metadata_cancelable(self):
coro_started = asyncio.Event()
@ -122,7 +122,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
# Test that initial metadata can still be asked thought
# a cancellation happened with the previous task
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(aio.Metadata(), await call.initial_metadata())
async def test_call_initial_metadata_multiple_waiters(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
@ -135,7 +135,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
await call
expected = [aio.Metadata() for _ in range(2)]
self.assertEqual(await asyncio.gather(*[task1, task2]), expected)
self.assertEqual(expected, await asyncio.gather(*[task1, task2]))
async def test_call_code_cancelable(self):
coro_started = asyncio.Event()

@ -57,6 +57,10 @@ _INVALID_METADATA_TEST_CASES = (
TypeError,
((42, 42),),
),
(
TypeError,
(({}, {}),),
),
(
TypeError,
((None, {}),),

Loading…
Cancel
Save