[issue-24953] Fix tests, format, & types

Fixes https://github.com/grpc/grpc/issues/21953
pull/23045/head
Mariano Anaya 5 years ago
parent e04fcd2998
commit e9dadf46bf
  1. 13
      src/python/grpcio/grpc/experimental/aio/_call.py
  2. 2
      src/python/grpcio/grpc/experimental/aio/_channel.py
  3. 5
      src/python/grpcio_tests/tests_aio/interop/methods.py
  4. 2
      src/python/grpcio_tests/tests_aio/unit/_common.py
  5. 10
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  6. 12
      src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py
  7. 4
      src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py
  8. 22
      src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py
  9. 3
      src/python/grpcio_tests/tests_aio/unit/compatibility_test.py
  10. 27
      src/python/grpcio_tests/tests_aio/unit/metadata_test.py
  11. 4
      src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py

@ -25,10 +25,10 @@ from grpc import _common
from grpc._cython import cygrpc
from . import _base_call
from ._metadata import Metadata
from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType,
MetadatumType, RequestIterableType, RequestType,
ResponseType, SerializingFunction)
from ._metadata import Metadata
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -84,8 +84,8 @@ class AioRpcError(grpc.RpcError):
super().__init__(self)
self._code = code
self._details = details
self._initial_metadata = initial_metadata or Metadata()
self._trailing_metadata = trailing_metadata or Metadata()
self._initial_metadata = Metadata(*(initial_metadata or ()))
self._trailing_metadata = Metadata(*(trailing_metadata or ()))
self._debug_error_string = debug_error_string
def code(self) -> grpc.StatusCode:
@ -205,10 +205,13 @@ class Call:
return self._cython_call.time_remaining()
async def initial_metadata(self) -> MetadataType:
return await self._cython_call.initial_metadata()
raw_metadata_tuple = await self._cython_call.initial_metadata()
return Metadata(*(raw_metadata_tuple or ()))
async def trailing_metadata(self) -> MetadataType:
return (await self._cython_call.status()).trailing_metadata()
raw_metadata_tuple = (await
self._cython_call.status()).trailing_metadata()
return Metadata(*(raw_metadata_tuple or ()))
async def code(self) -> grpc.StatusCode:
cygrpc_code = (await self._cython_call.status()).code()

@ -29,10 +29,10 @@ from ._interceptor import (
InterceptedStreamUnaryCall, InterceptedStreamStreamCall, ClientInterceptor,
UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor, StreamStreamClientInterceptor)
from ._metadata import Metadata
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction, RequestIterableType)
from ._utils import _timeout_to_deadline
from ._metadata import Metadata
_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)

@ -293,12 +293,13 @@ async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub):
)
async def _validate_metadata(call):
initial_metadata = dict(await call.initial_metadata())
initial_metadata = await call.initial_metadata()
if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
raise ValueError('expected initial metadata %s, got %s' %
(initial_metadata_value,
initial_metadata[_INITIAL_METADATA_KEY]))
trailing_metadata = dict(await call.trailing_metadata())
trailing_metadata = await call.trailing_metadata()
if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
raise ValueError('expected trailing metadata %s, got %s' %
(trailing_metadata_value,

@ -28,7 +28,7 @@ def seen_metadata(expected: MetadataType, actual: MetadataType):
def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue,
actual: MetadataType) -> bool:
obtained = actual[expected_key]
assert obtained == expected_value
return obtained == expected_value
async def block_until_certain_state(channel: aio.Channel,

@ -102,11 +102,11 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
async def test_call_initial_metadata_awaitable(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual((), await call.initial_metadata())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
async def test_call_trailing_metadata_awaitable(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual((), await call.trailing_metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_call_initial_metadata_cancelable(self):
coro_started = asyncio.Event()
@ -122,7 +122,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
# Test that initial metadata can still be asked thought
# a cancellation happened with the previous task
self.assertEqual((), await call.initial_metadata())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
async def test_call_initial_metadata_multiple_waiters(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
@ -134,8 +134,8 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
task2 = self.loop.create_task(coro())
await call
self.assertEqual([(), ()], await asyncio.gather(*[task1, task2]))
expected = [aio.Metadata() for _ in range(2)]
self.assertEqual(await asyncio.gather(*[task1, task2]), expected)
async def test_call_code_cancelable(self):
coro_started = asyncio.Event()

@ -92,8 +92,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase):
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
@ -131,8 +131,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase):
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
@ -230,8 +230,8 @@ class TestStreamUnaryClientInterceptor(AioTestBase):
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)

@ -96,8 +96,8 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)

@ -302,8 +302,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(type(response), messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_call_ok_awaited(self):
@ -331,8 +331,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(type(response), messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_call_rpc_error(self):
@ -364,8 +364,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_call_rpc_error_awaited(self):
@ -398,8 +398,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_cancel_before_rpc(self):
@ -541,8 +541,10 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), tuple())
self.assertEqual(await call.trailing_metadata(), None)
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(
await call.trailing_metadata(), aio.Metadata(),
"When the raw response is None, empty metadata is returned")
async def test_initial_metadata_modification(self):

@ -255,7 +255,8 @@ class TestCompatibility(AioTestBase):
self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary)
call = self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
self.assertTrue(
_common.seen_metadata(metadata, await call.initial_metadata()))
_common.seen_metadata(aio.Metadata(*metadata), await
call.initial_metadata()))
async def test_sync_unary_unary_abort(self):

@ -55,15 +55,15 @@ _INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata(
_INVALID_METADATA_TEST_CASES = (
(
TypeError,
aio.Metadata((42, 42),),
((42, 42),),
),
(
TypeError,
aio.Metadata(({}, {}),),
((None, {}),),
),
(
TypeError,
aio.Metadata(('normal', object()),),
(('normal', object()),),
),
)
@ -100,13 +100,13 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
async def _test_server_to_client(request, context):
assert _REQUEST == request
await context.send_initial_metadata(
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT))
return _RESPONSE
@staticmethod
async def _test_trailing_metadata(request, context):
assert _REQUEST == request
context.set_trailing_metadata(_TRAILING_METADATA)
context.set_trailing_metadata(tuple(_TRAILING_METADATA))
return _RESPONSE
@staticmethod
@ -115,21 +115,21 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
await context.send_initial_metadata(
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT))
yield _RESPONSE
context.set_trailing_metadata(_TRAILING_METADATA)
context.set_trailing_metadata(tuple(_TRAILING_METADATA))
@staticmethod
async def _test_stream_unary(request_iterator, context):
assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
await context.send_initial_metadata(
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT))
async for request in request_iterator:
assert _REQUEST == request
context.set_trailing_metadata(_TRAILING_METADATA)
context.set_trailing_metadata(tuple(_TRAILING_METADATA))
return _RESPONSE
@staticmethod
@ -137,13 +137,13 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
await context.send_initial_metadata(
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
tuple(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT))
async for request in request_iterator:
assert _REQUEST == request
yield _RESPONSE
context.set_trailing_metadata(_TRAILING_METADATA)
context.set_trailing_metadata(tuple(_TRAILING_METADATA))
def service(self, handler_call_details):
return self._routing_table.get(handler_call_details.method)
@ -193,6 +193,7 @@ class TestMetadata(AioTestBase):
async def test_from_server_to_client(self):
multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT)
call = multicallable(_REQUEST)
self.assertEqual(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
call.initial_metadata())
self.assertEqual(_RESPONSE, await call)
@ -207,8 +208,8 @@ class TestMetadata(AioTestBase):
async def test_from_client_to_server_with_list(self):
multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
call = multicallable(
_REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER))
call = multicallable(_REQUEST,
metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
self.assertEqual(_RESPONSE, await call)
self.assertEqual(grpc.StatusCode.OK, await call.code())

@ -198,7 +198,7 @@ class TestServerInterceptor(AioTestBase):
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
metadata = (('key', 'value'),)
metadata = aio.Metadata(('key', 'value'),)
call = multicallable(messages_pb2.SimpleRequest(),
metadata=metadata)
await call
@ -208,7 +208,7 @@ class TestServerInterceptor(AioTestBase):
], record)
record.clear()
metadata = (('key', 'value'), ('secret', '42'))
metadata = aio.Metadata(('key', 'value'), ('secret', '42'))
call = multicallable(messages_pb2.SimpleRequest(),
metadata=metadata)
await call

Loading…
Cancel
Save