Support metadata for streaming RPCs

pull/21647/head
Lidi Zheng 5 years ago
parent f912ddf7d4
commit 613f64f12e
  1. 11
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 14
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 13
      src/python/grpcio/grpc/experimental/aio/_call.py
  4. 20
      src/python/grpcio/grpc/experimental/aio/_channel.py
  5. 20
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  6. 2
      src/python/grpcio_tests/tests_aio/unit/_common.py
  7. 125
      src/python/grpcio_tests/tests_aio/unit/metadata_test.py

@ -220,6 +220,7 @@ cdef class _AioCall(GrpcCallWrapper):
async def initiate_unary_stream(self,
bytes request,
tuple outbound_initial_metadata,
object initial_metadata_observer,
object status_observer):
"""Implementation of the start of a unary-stream call."""
@ -229,7 +230,7 @@ cdef class _AioCall(GrpcCallWrapper):
cdef tuple outbound_ops
cdef Operation initial_metadata_op = SendInitialMetadataOperation(
_EMPTY_METADATA,
outbound_initial_metadata,
GRPC_INITIAL_METADATA_USED_MASK)
cdef Operation send_message_op = SendMessageOperation(
request,
@ -255,7 +256,7 @@ cdef class _AioCall(GrpcCallWrapper):
)
async def stream_unary(self,
tuple metadata,
tuple outbound_initial_metadata,
object metadata_sent_observer,
object initial_metadata_observer,
object status_observer):
@ -267,7 +268,7 @@ cdef class _AioCall(GrpcCallWrapper):
"""
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
metadata,
outbound_initial_metadata,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
@ -304,7 +305,7 @@ cdef class _AioCall(GrpcCallWrapper):
return None
async def initiate_stream_stream(self,
tuple metadata,
tuple outbound_initial_metadata,
object metadata_sent_observer,
object initial_metadata_observer,
object status_observer):
@ -320,7 +321,7 @@ cdef class _AioCall(GrpcCallWrapper):
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
metadata,
outbound_initial_metadata,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()

@ -138,6 +138,9 @@ cdef class _ServicerContext:
# could lead to undefined behavior.
self._rpc_state.abort_exception = AbortError('Locally aborted.')
if trailing_metadata == _EMPTY_METADATA and self._rpc_state.trailing_metadata:
trailing_metadata = self._rpc_state.trailing_metadata
self._rpc_state.status_sent = True
await _send_error_status_from_server(
self._rpc_state,
@ -210,8 +213,7 @@ async def _finish_handler_with_unary_response(RPCState rpc_state,
if not rpc_state.metadata_sent:
finish_ops = prepend_send_initial_metadata_op(
finish_ops,
None
)
None)
rpc_state.metadata_sent = True
rpc_state.status_sent = True
await execute_batch(rpc_state, finish_ops, loop)
@ -223,7 +225,7 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state,
_ServicerContext servicer_context,
object loop):
"""Finishes server method handler with multiple responses.
This function executes the application handler, and handles response
sending, as well as errors. It is shared between unary-stream and
stream-stream handlers.
@ -261,7 +263,7 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state,
# Sends the final status of this RPC
cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
None,
rpc_state.trailing_metadata,
StatusCode.ok,
b'',
_EMPTY_FLAGS,
@ -422,8 +424,8 @@ async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
await _send_error_status_from_server(
rpc_state,
StatusCode.unknown,
'%s: %s' % (type(e), e),
_EMPTY_METADATA,
'Unexpected %s: %s' % (type(e), e),
rpc_state.trailing_metadata,
rpc_state.metadata_sent,
loop
)

@ -346,6 +346,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
Returned when an instance of `UnaryStreamMultiCallable` object is called.
"""
_request: RequestType
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_send_unary_request_task: asyncio.Task
@ -353,12 +354,14 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
self._request = request
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._send_unary_request_task = self._loop.create_task(
@ -377,7 +380,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self._request_serializer)
try:
await self._cython_call.initiate_unary_stream(
serialized_request, self._set_initial_metadata,
serialized_request, self._metadata, self._set_initial_metadata,
self._set_status)
except asyncio.CancelledError:
if not self.cancelled():
@ -445,13 +448,13 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
self._metadata = _EMPTY_METADATA
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
@ -567,13 +570,13 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
self._metadata = _EMPTY_METADATA
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer

@ -159,9 +159,6 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
Returns:
A Call object instance which is an awaitable object.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
@ -170,10 +167,13 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = tuple()
return UnaryStreamCall(
request,
deadline,
metadata,
credentials,
self._channel,
self._method,
@ -216,10 +216,6 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
@ -228,10 +224,13 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = tuple()
return StreamUnaryCall(
request_async_iterator,
deadline,
metadata,
credentials,
self._channel,
self._method,
@ -274,10 +273,6 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
@ -286,10 +281,13 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = tuple()
return StreamStreamCall(
request_async_iterator,
deadline,
metadata,
credentials,
self._channel,
self._method,

@ -103,13 +103,14 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
_intercepted_call_created: asyncio.Event
_interceptors_task: asyncio.Task
def __init__( # pylint: disable=R0913
self, interceptors: Sequence[UnaryUnaryClientInterceptor],
request: RequestType, timeout: Optional[float],
metadata: MetadataType, credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
request: RequestType, timeout: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
self._channel = channel
self._loop = asyncio.get_event_loop()
self._interceptors_task = asyncio.ensure_future(
@ -119,7 +120,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
def __del__(self):
self.cancel()
async def _invoke( # pylint: disable=R0913
# pylint: disable=too-many-arguments
async def _invoke(
self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
@ -289,7 +291,7 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
return None
def __await__(self):
if False: # pylint: disable=W0125
if False: # pylint: disable=using-constant-test
# This code path is never used, but a yield statement is needed
# for telling the interpreter that __await__ is a generator.
yield None

@ -16,7 +16,7 @@ from grpc.experimental.aio._typing import MetadataType, MetadatumType
def seen_metadata(expected: MetadataType, actual: MetadataType):
return bool(set(expected) - set(actual))
return not bool(set(expected) - set(actual))
def seen_metadatum(expected: MetadatumType, actual: MetadataType):

@ -30,6 +30,9 @@ _TEST_SERVER_TO_CLIENT = '/test/TestServerToClient'
_TEST_TRAILING_METADATA = '/test/TestTrailingMetadata'
_TEST_ECHO_INITIAL_METADATA = '/test/TestEchoInitialMetadata'
_TEST_GENERIC_HANDLER = '/test/TestGenericHandler'
_TEST_UNARY_STREAM = '/test/TestUnaryStream'
_TEST_STREAM_UNARY = '/test/TestStreamUnary'
_TEST_STREAM_STREAM = '/test/TestStreamStream'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
@ -72,6 +75,25 @@ _INVALID_METADATA_TEST_CASES = (
class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
def __init__(self):
self._routing_table = {
_TEST_CLIENT_TO_SERVER:
grpc.unary_unary_rpc_method_handler(self._test_client_to_server
),
_TEST_SERVER_TO_CLIENT:
grpc.unary_unary_rpc_method_handler(self._test_server_to_client
),
_TEST_TRAILING_METADATA:
grpc.unary_unary_rpc_method_handler(self._test_trailing_metadata
),
_TEST_UNARY_STREAM:
grpc.unary_stream_rpc_method_handler(self._test_unary_stream),
_TEST_STREAM_UNARY:
grpc.stream_unary_rpc_method_handler(self._test_stream_unary),
_TEST_STREAM_STREAM:
grpc.stream_stream_rpc_method_handler(self._test_stream_stream),
}
@staticmethod
async def _test_client_to_server(request, context):
assert _REQUEST == request
@ -92,17 +114,44 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
context.set_trailing_metadata(_TRAILING_METADATA)
return _RESPONSE
def service(self, handler_details):
if handler_details.method == _TEST_CLIENT_TO_SERVER:
return grpc.unary_unary_rpc_method_handler(
self._test_client_to_server)
if handler_details.method == _TEST_SERVER_TO_CLIENT:
return grpc.unary_unary_rpc_method_handler(
self._test_server_to_client)
if handler_details.method == _TEST_TRAILING_METADATA:
return grpc.unary_unary_rpc_method_handler(
self._test_trailing_metadata)
return None
@staticmethod
async def _test_unary_stream(request, context):
assert _REQUEST == request
assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
await context.send_initial_metadata(
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
yield _RESPONSE
context.set_trailing_metadata(_TRAILING_METADATA)
@staticmethod
async def _test_stream_unary(request_iterator, context):
assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
await context.send_initial_metadata(
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
async for request in request_iterator:
assert _REQUEST == request
context.set_trailing_metadata(_TRAILING_METADATA)
return _RESPONSE
@staticmethod
async def _test_stream_stream(request_iterator, context):
assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
await context.send_initial_metadata(
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
async for request in request_iterator:
assert _REQUEST == request
yield _RESPONSE
context.set_trailing_metadata(_TRAILING_METADATA)
def service(self, handler_call_details):
return self._routing_table.get(handler_call_details.method)
class _TestGenericHandlerItself(grpc.GenericRpcHandler):
@ -112,9 +161,9 @@ class _TestGenericHandlerItself(grpc.GenericRpcHandler):
assert _REQUEST == request
return _RESPONSE
def service(self, handler_details):
def service(self, handler_call_details):
assert _common.seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER,
handler_details.invocation_metadata)
handler_call_details.invocation_metadata)
return grpc.unary_unary_rpc_method_handler(self._method)
@ -164,9 +213,10 @@ class TestMetadata(AioTestBase):
async def test_invalid_metadata(self):
multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
for exception_type, metadata in _INVALID_METADATA_TEST_CASES:
call = multicallable(_REQUEST, metadata=metadata)
with self.assertRaises(exception_type):
await call
with self.subTest(metadata=metadata):
call = multicallable(_REQUEST, metadata=metadata)
with self.assertRaises(exception_type):
await call
async def test_generic_handler(self):
multicallable = self._client.unary_unary(_TEST_GENERIC_HANDLER)
@ -175,6 +225,49 @@ class TestMetadata(AioTestBase):
self.assertEqual(_RESPONSE, await call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_unary_stream(self):
multicallable = self._client.unary_stream(_TEST_UNARY_STREAM)
call = multicallable(_REQUEST,
metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
self.assertTrue(
_common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
call.initial_metadata()))
self.assertSequenceEqual([_RESPONSE],
[request async for request in call])
self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_stream_unary(self):
multicallable = self._client.stream_unary(_TEST_STREAM_UNARY)
call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
await call.write(_REQUEST)
await call.done_writing()
self.assertTrue(
_common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
call.initial_metadata()))
self.assertEqual(_RESPONSE, await call)
self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_stream_stream(self):
multicallable = self._client.stream_stream(_TEST_STREAM_STREAM)
call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
await call.write(_REQUEST)
await call.done_writing()
self.assertTrue(
_common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
call.initial_metadata()))
self.assertSequenceEqual([_RESPONSE],
[request async for request in call])
self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
self.assertEqual(grpc.StatusCode.OK, await call.code())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save