Merge pull request #21647 from lidizheng/aio-metadata

[Aio] Support metadata for unary calls
pull/21679/head
Lidi Zheng 5 years ago committed by GitHub
commit af67aaf031
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 17
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 29
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  3. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  4. 62
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  5. 5
      src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi
  6. 18
      src/python/grpcio/grpc/experimental/aio/_call.py
  7. 28
      src/python/grpcio/grpc/experimental/aio/_channel.py
  8. 28
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  9. 3
      src/python/grpcio/grpc/experimental/aio/_typing.py
  10. 1
      src/python/grpcio_tests/tests_aio/tests.json
  11. 7
      src/python/grpcio_tests/tests_aio/unit/BUILD.bazel
  12. 24
      src/python/grpcio_tests/tests_aio/unit/_common.py
  13. 19
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  14. 18
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  15. 6
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  16. 57
      src/python/grpcio_tests/tests_aio/unit/interceptor_test.py
  17. 274
      src/python/grpcio_tests/tests_aio/unit/metadata_test.py
  18. 10
      src/python/grpcio_tests/tests_aio/unit/server_test.py

@ -117,6 +117,7 @@ cdef class _AioCall(GrpcCallWrapper):
async def unary_unary(self,
bytes request,
tuple outbound_initial_metadata,
object initial_metadata_observer,
object status_observer):
"""Performs a unary unary RPC.
@ -133,7 +134,7 @@ cdef class _AioCall(GrpcCallWrapper):
cdef tuple ops
cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation(
_EMPTY_METADATA,
outbound_initial_metadata,
GRPC_INITIAL_METADATA_USED_MASK)
cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS)
cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS)
@ -151,6 +152,9 @@ cdef class _AioCall(GrpcCallWrapper):
ops,
self._loop)
# Reports received initial metadata.
initial_metadata_observer(receive_initial_metadata_op.initial_metadata())
status = AioRpcStatus(
receive_status_on_client_op.code(),
receive_status_on_client_op.details(),
@ -216,6 +220,7 @@ cdef class _AioCall(GrpcCallWrapper):
async def initiate_unary_stream(self,
bytes request,
tuple outbound_initial_metadata,
object initial_metadata_observer,
object status_observer):
"""Implementation of the start of a unary-stream call."""
@ -225,7 +230,7 @@ cdef class _AioCall(GrpcCallWrapper):
cdef tuple outbound_ops
cdef Operation initial_metadata_op = SendInitialMetadataOperation(
_EMPTY_METADATA,
outbound_initial_metadata,
GRPC_INITIAL_METADATA_USED_MASK)
cdef Operation send_message_op = SendMessageOperation(
request,
@ -251,7 +256,7 @@ cdef class _AioCall(GrpcCallWrapper):
)
async def stream_unary(self,
tuple metadata,
tuple outbound_initial_metadata,
object metadata_sent_observer,
object initial_metadata_observer,
object status_observer):
@ -263,7 +268,7 @@ cdef class _AioCall(GrpcCallWrapper):
"""
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
metadata,
outbound_initial_metadata,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()
@ -300,7 +305,7 @@ cdef class _AioCall(GrpcCallWrapper):
return None
async def initiate_stream_stream(self,
tuple metadata,
tuple outbound_initial_metadata,
object metadata_sent_observer,
object initial_metadata_observer,
object status_observer):
@ -316,7 +321,7 @@ cdef class _AioCall(GrpcCallWrapper):
# Sends out initial_metadata ASAP.
await _send_initial_metadata(self,
metadata,
outbound_initial_metadata,
self._loop)
# Notify upper level that sending messages are allowed now.
metadata_sent_observer()

@ -120,6 +120,15 @@ async def execute_batch(GrpcCallWrapper grpc_call_wrapper,
batch_operation_tag.event(c_event)
cdef prepend_send_initial_metadata_op(tuple ops, tuple metadata):
# Eventually, this function should be the only function that produces
# SendInitialMetadataOperation. So we have more control over the flag.
return (SendInitialMetadataOperation(
metadata,
_EMPTY_FLAG
),) + ops
async def _receive_message(GrpcCallWrapper grpc_call_wrapper,
object loop):
"""Retrives parsed messages from Core.
@ -147,15 +156,9 @@ async def _send_message(GrpcCallWrapper grpc_call_wrapper,
bint metadata_sent,
object loop):
cdef SendMessageOperation op = SendMessageOperation(message, _EMPTY_FLAG)
cdef tuple ops
if metadata_sent:
ops = (op,)
else:
ops = (
# Initial metadata must be sent before first outbound message.
SendInitialMetadataOperation(None, _EMPTY_FLAG),
op,
)
cdef tuple ops = (op,)
if not metadata_sent:
ops = prepend_send_initial_metadata_op(ops, None)
await execute_batch(grpc_call_wrapper, ops, loop)
@ -189,9 +192,7 @@ async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper,
details,
_EMPTY_FLAGS,
)
cdef tuple ops
if metadata_sent:
ops = (op,)
else:
ops = (op, SendInitialMetadataOperation(None, _EMPTY_FLAG))
cdef tuple ops = (op,)
if not metadata_sent:
ops = prepend_send_initial_metadata_op(ops, None)
await execute_batch(grpc_call_wrapper, ops, loop)

@ -28,8 +28,10 @@ cdef class RPCState(GrpcCallWrapper):
cdef object abort_exception
cdef bint metadata_sent
cdef bint status_sent
cdef tuple trailing_metadata
cdef bytes method(self)
cdef tuple invocation_metadata(self)
cdef enum AioServerStatus:

@ -40,9 +40,13 @@ cdef class RPCState:
self.abort_exception = None
self.metadata_sent = False
self.status_sent = False
self.trailing_metadata = _EMPTY_METADATA
cdef bytes method(self):
return _slice_bytes(self.details.method)
return _slice_bytes(self.details.method)
cdef tuple invocation_metadata(self):
return _metadata(&self.request_metadata)
def __dealloc__(self):
"""Cleans the Core objects."""
@ -119,7 +123,7 @@ cdef class _ServicerContext:
elif self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent')
else:
_send_initial_metadata(self._rpc_state, self._loop)
await _send_initial_metadata(self._rpc_state, metadata, self._loop)
self._rpc_state.metadata_sent = True
async def abort(self,
@ -134,6 +138,9 @@ cdef class _ServicerContext:
# could lead to undefined behavior.
self._rpc_state.abort_exception = AbortError('Locally aborted.')
if trailing_metadata == _EMPTY_METADATA and self._rpc_state.trailing_metadata:
trailing_metadata = self._rpc_state.trailing_metadata
self._rpc_state.status_sent = True
await _send_error_status_from_server(
self._rpc_state,
@ -146,11 +153,16 @@ cdef class _ServicerContext:
raise self._rpc_state.abort_exception
def set_trailing_metadata(self, tuple metadata):
self._rpc_state.trailing_metadata = metadata
def invocation_metadata(self):
return self._rpc_state.invocation_metadata()
cdef _find_method_handler(str method, list generic_handlers):
# TODO(lidiz) connects Metadata to call details
cdef _find_method_handler(str method, tuple metadata, list generic_handlers):
cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method,
None)
metadata)
for generic_handler in generic_handlers:
method_handler = generic_handler.service(handler_call_details)
@ -188,24 +200,21 @@ async def _finish_handler_with_unary_response(RPCState rpc_state,
)
# Assembles the batch operations
cdef Operation send_status_op = SendStatusFromServerOperation(
tuple(),
cdef tuple finish_ops
finish_ops = (
SendMessageOperation(response_raw, _EMPTY_FLAGS),
SendStatusFromServerOperation(
rpc_state.trailing_metadata,
StatusCode.ok,
b'',
_EMPTY_FLAGS,
),
)
cdef tuple finish_ops
if not rpc_state.metadata_sent:
finish_ops = (
send_status_op,
SendInitialMetadataOperation(None, _EMPTY_FLAGS),
SendMessageOperation(response_raw, _EMPTY_FLAGS),
)
else:
finish_ops = (
send_status_op,
SendMessageOperation(response_raw, _EMPTY_FLAGS),
)
finish_ops = prepend_send_initial_metadata_op(
finish_ops,
None)
rpc_state.metadata_sent = True
rpc_state.status_sent = True
await execute_batch(rpc_state, finish_ops, loop)
@ -216,7 +225,7 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state,
_ServicerContext servicer_context,
object loop):
"""Finishes server method handler with multiple responses.
This function executes the application handler, and handles response
sending, as well as errors. It is shared between unary-stream and
stream-stream handlers.
@ -254,7 +263,7 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state,
# Sends the final status of this RPC
cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
None,
rpc_state.trailing_metadata,
StatusCode.ok,
b'',
_EMPTY_FLAGS,
@ -262,7 +271,11 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state,
cdef tuple finish_ops = (op,)
if not rpc_state.metadata_sent:
finish_ops = (op, SendInitialMetadataOperation(None, _EMPTY_FLAGS))
finish_ops = prepend_send_initial_metadata_op(
finish_ops,
None
)
rpc_state.metadata_sent = True
rpc_state.status_sent = True
await execute_batch(rpc_state, finish_ops, loop)
@ -411,8 +424,8 @@ async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
await _send_error_status_from_server(
rpc_state,
StatusCode.unknown,
'%s: %s' % (type(e), e),
_EMPTY_METADATA,
'Unexpected %s: %s' % (type(e), e),
rpc_state.trailing_metadata,
rpc_state.metadata_sent,
loop
)
@ -449,6 +462,7 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
# Finds the method handler (application logic)
method_handler = _find_method_handler(
rpc_state.method().decode(),
rpc_state.invocation_metadata(),
generic_handlers,
)
if method_handler is None:
@ -456,7 +470,7 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
await _send_error_status_from_server(
rpc_state,
StatusCode.unimplemented,
b'Method not found!',
'Method not found!',
_EMPTY_METADATA,
rpc_state.metadata_sent,
loop

@ -41,6 +41,11 @@ cdef void _store_c_metadata(
for index, (key, value) in enumerate(metadata):
encoded_key = _encode(key)
encoded_value = value if encoded_key[-4:] == b'-bin' else _encode(value)
if not isinstance(encoded_value, bytes):
raise TypeError('Binary metadata key="%s" expected bytes, got %s' % (
key,
type(encoded_value)
))
c_metadata[0][index].key = _slice_from_bytes(encoded_key)
c_metadata[0][index].value = _slice_from_bytes(encoded_value)

@ -273,19 +273,21 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
Returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
_request: RequestType
_metadata: Optional[MetadataType]
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
channel.call(method, deadline, credentials)
super().__init__(channel.call(method, deadline, credentials))
self._request = request
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._call = self._loop.create_task(self._invoke())
@ -307,6 +309,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
try:
serialized_response = await self._cython_call.unary_unary(
serialized_request,
self._metadata,
self._set_initial_metadata,
self._set_status,
)
@ -343,6 +346,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
Returned when an instance of `UnaryStreamMultiCallable` object is called.
"""
_request: RequestType
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_send_unary_request_task: asyncio.Task
@ -350,12 +354,14 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
self._request = request
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._send_unary_request_task = self._loop.create_task(
@ -374,7 +380,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self._request_serializer)
try:
await self._cython_call.initiate_unary_stream(
serialized_request, self._set_initial_metadata,
serialized_request, self._metadata, self._set_initial_metadata,
self._set_status)
except asyncio.CancelledError:
if not self.cancelled():
@ -442,13 +448,13 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
self._metadata = _EMPTY_METADATA
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
@ -564,13 +570,13 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__(channel.call(method, deadline, credentials))
self._metadata = _EMPTY_METADATA
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer

@ -95,19 +95,20 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
if metadata is None:
metadata = tuple()
if not self._interceptors:
return UnaryUnaryCall(
request,
_timeout_to_deadline(timeout),
metadata,
credentials,
self._channel,
self._method,
@ -119,6 +120,7 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
self._interceptors,
request,
timeout,
metadata,
credentials,
self._channel,
self._method,
@ -157,9 +159,6 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
Returns:
A Call object instance which is an awaitable object.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
@ -168,10 +167,13 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = tuple()
return UnaryStreamCall(
request,
deadline,
metadata,
credentials,
self._channel,
self._method,
@ -214,10 +216,6 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
@ -226,10 +224,13 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = tuple()
return StreamUnaryCall(
request_async_iterator,
deadline,
metadata,
credentials,
self._channel,
self._method,
@ -272,10 +273,6 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
@ -284,10 +281,13 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = tuple()
return StreamStreamCall(
request_async_iterator,
deadline,
metadata,
credentials,
self._channel,
self._method,

@ -103,25 +103,28 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
_intercepted_call_created: asyncio.Event
_interceptors_task: asyncio.Task
def __init__( # pylint: disable=R0913
self, interceptors: Sequence[UnaryUnaryClientInterceptor],
request: RequestType, timeout: Optional[float],
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
request: RequestType, timeout: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
self._channel = channel
self._loop = asyncio.get_event_loop()
self._interceptors_task = asyncio.ensure_future(
self._invoke(interceptors, method, timeout, credentials, request,
request_serializer, response_deserializer))
self._invoke(interceptors, method, timeout, metadata, credentials,
request, request_serializer, response_deserializer))
def __del__(self):
self.cancel()
async def _invoke( # pylint: disable=R0913
# pylint: disable=too-many-arguments
async def _invoke(
self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
credentials: Optional[grpc.CallCredentials], request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> UnaryUnaryCall:
@ -148,11 +151,12 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
else:
return UnaryUnaryCall(
request, _timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials, self._channel,
client_call_details.method, request_serializer,
response_deserializer)
client_call_details = ClientCallDetails(method, timeout, None,
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials)
return await _run_interceptor(iter(interceptors), client_call_details,
request)
@ -287,7 +291,7 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
return None
def __await__(self):
if False: # pylint: disable=W0125
if False: # pylint: disable=using-constant-test
# This code path is never used, but a yield statement is needed
# for telling the interpreter that __await__ is a generator.
yield None

@ -20,6 +20,7 @@ RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType')
SerializingFunction = Callable[[Any], bytes]
DeserializingFunction = Callable[[bytes], Any]
MetadataType = Sequence[Tuple[Text, AnyStr]]
MetadatumType = Tuple[Text, AnyStr]
MetadataType = Sequence[MetadatumType]
ChannelArgumentType = Sequence[Tuple[Text, Any]]
EOFType = type(EOF)

@ -13,5 +13,6 @@
"unit.init_test.TestSecureChannel",
"unit.interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.metadata_test.TestMetadata",
"unit.server_test.TestServer"
]

@ -43,6 +43,12 @@ py_library(
srcs_version = "PY3",
)
py_library(
name = "_common",
srcs = ["_common.py"],
srcs_version = "PY3",
)
[
py_test(
name = test_file_name[:-3],
@ -55,6 +61,7 @@ py_library(
main = test_file_name,
python_version = "PY3",
deps = [
":_common",
":_constants",
":_test_base",
":_test_server",

@ -0,0 +1,24 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from grpc.experimental.aio._typing import MetadataType, MetadatumType
def seen_metadata(expected: MetadataType, actual: MetadataType):
return not bool(set(expected) - set(actual))
def seen_metadatum(expected: MetadatumType, actual: MetadataType):
metadata_dict = dict(actual)
return metadata_dict.get(expected[0]) == expected[1]

@ -23,10 +23,27 @@ from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE
_INITIAL_METADATA_KEY = "initial-md-key"
_TRAILING_METADATA_KEY = "trailing-md-key-bin"
async def _maybe_echo_metadata(servicer_context):
"""Copies metadata from request to response if it is present."""
invocation_metadata = dict(servicer_context.invocation_metadata())
if _INITIAL_METADATA_KEY in invocation_metadata:
initial_metadatum = (_INITIAL_METADATA_KEY,
invocation_metadata[_INITIAL_METADATA_KEY])
await servicer_context.send_initial_metadata((initial_metadatum,))
if _TRAILING_METADATA_KEY in invocation_metadata:
trailing_metadatum = (_TRAILING_METADATA_KEY,
invocation_metadata[_TRAILING_METADATA_KEY])
servicer_context.set_trailing_metadata((trailing_metadatum,))
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
async def UnaryCall(self, unused_request, unused_context):
async def UnaryCall(self, unused_request, context):
await _maybe_echo_metadata(context)
return messages_pb2.SimpleResponse()
async def StreamingOutputCall(

@ -112,6 +112,24 @@ class TestUnaryUnaryCall(AioTestBase):
call = hi(messages_pb2.SimpleRequest())
self.assertEqual('', await call.details())
async def test_call_initial_metadata_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual((), await call.initial_metadata())
async def test_call_trailing_metadata_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual((), await call.trailing_metadata())
async def test_cancel_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(

@ -31,6 +31,12 @@ from tests_aio.unit._test_server import start_test_server
_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_INVOCATION_METADATA = (
('initial-md-key', 'initial-md-value'),
('trailing-md-key-bin', b'\x00\x02'),
)
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42

@ -18,11 +18,17 @@ import unittest
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE
from tests_aio.unit._test_server import start_test_server, _INITIAL_METADATA_KEY, _TRAILING_METADATA_KEY
from tests_aio.unit import _constants
from tests_aio.unit import _common
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_INITIAL_METADATA_TO_INJECT = (
(_INITIAL_METADATA_KEY, 'extra info'),
(_TRAILING_METADATA_KEY, b'\x13\x37'),
)
class TestUnaryUnaryClientInterceptor(AioTestBase):
@ -124,7 +130,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
client_call_details, request):
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
return await continuation(new_client_call_details, request)
@ -165,7 +171,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
@ -342,8 +348,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest(),
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
call = multicallable(
messages_pb2.SimpleRequest(),
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
@ -375,8 +382,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest(),
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
call = multicallable(
messages_pb2.SimpleRequest(),
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
@ -532,6 +540,39 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.initial_metadata(), tuple())
self.assertEqual(await call.trailing_metadata(), None)
async def test_initial_metadata_modification(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
new_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=client_call_details.timeout,
metadata=client_call_details.metadata +
_INITIAL_METADATA_TO_INJECT,
credentials=client_call_details.credentials,
)
return await continuation(new_details, request)
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.UnaryCall(messages_pb2.SimpleRequest())
# Expected to see the echoed initial metadata
self.assertTrue(
_common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[0], await
call.initial_metadata()))
# Expected to see the echoed trailing metadata
self.assertTrue(
_common.seen_metadatum(_INITIAL_METADATA_TO_INJECT[1], await
call.trailing_metadata()))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
if __name__ == '__main__':
logging.basicConfig()

@ -0,0 +1,274 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior around the metadata mechanism."""
import asyncio
import logging
import platform
import random
import unittest
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit import _common
_TEST_CLIENT_TO_SERVER = '/test/TestClientToServer'
_TEST_SERVER_TO_CLIENT = '/test/TestServerToClient'
_TEST_TRAILING_METADATA = '/test/TestTrailingMetadata'
_TEST_ECHO_INITIAL_METADATA = '/test/TestEchoInitialMetadata'
_TEST_GENERIC_HANDLER = '/test/TestGenericHandler'
_TEST_UNARY_STREAM = '/test/TestUnaryStream'
_TEST_STREAM_UNARY = '/test/TestStreamUnary'
_TEST_STREAM_STREAM = '/test/TestStreamStream'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = (
('client-to-server', 'question'),
('client-to-server-bin', b'\x07\x07\x07'),
)
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = (
('server-to-client', 'answer'),
('server-to-client-bin', b'\x06\x06\x06'),
)
_TRAILING_METADATA = (('a-trailing-metadata', 'stack-trace'),
('a-trailing-metadata-bin', b'\x05\x05\x05'))
_INITIAL_METADATA_FOR_GENERIC_HANDLER = (('a-must-have-key', 'secret'),)
_INVALID_METADATA_TEST_CASES = (
(
TypeError,
((42, 42),),
),
(
TypeError,
(({}, {}),),
),
(
TypeError,
(('normal', object()),),
),
(
TypeError,
object(),
),
(
TypeError,
(object(),),
),
)
class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
def __init__(self):
self._routing_table = {
_TEST_CLIENT_TO_SERVER:
grpc.unary_unary_rpc_method_handler(self._test_client_to_server
),
_TEST_SERVER_TO_CLIENT:
grpc.unary_unary_rpc_method_handler(self._test_server_to_client
),
_TEST_TRAILING_METADATA:
grpc.unary_unary_rpc_method_handler(self._test_trailing_metadata
),
_TEST_UNARY_STREAM:
grpc.unary_stream_rpc_method_handler(self._test_unary_stream),
_TEST_STREAM_UNARY:
grpc.stream_unary_rpc_method_handler(self._test_stream_unary),
_TEST_STREAM_STREAM:
grpc.stream_stream_rpc_method_handler(self._test_stream_stream),
}
@staticmethod
async def _test_client_to_server(request, context):
assert _REQUEST == request
assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
return _RESPONSE
@staticmethod
async def _test_server_to_client(request, context):
assert _REQUEST == request
await context.send_initial_metadata(
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
return _RESPONSE
@staticmethod
async def _test_trailing_metadata(request, context):
assert _REQUEST == request
context.set_trailing_metadata(_TRAILING_METADATA)
return _RESPONSE
@staticmethod
async def _test_unary_stream(request, context):
assert _REQUEST == request
assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
await context.send_initial_metadata(
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
yield _RESPONSE
context.set_trailing_metadata(_TRAILING_METADATA)
@staticmethod
async def _test_stream_unary(request_iterator, context):
assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
await context.send_initial_metadata(
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
async for request in request_iterator:
assert _REQUEST == request
context.set_trailing_metadata(_TRAILING_METADATA)
return _RESPONSE
@staticmethod
async def _test_stream_stream(request_iterator, context):
assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
await context.send_initial_metadata(
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
async for request in request_iterator:
assert _REQUEST == request
yield _RESPONSE
context.set_trailing_metadata(_TRAILING_METADATA)
def service(self, handler_call_details):
return self._routing_table.get(handler_call_details.method)
class _TestGenericHandlerItself(grpc.GenericRpcHandler):
@staticmethod
async def _method(request, unused_context):
assert _REQUEST == request
return _RESPONSE
def service(self, handler_call_details):
assert _common.seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER,
handler_call_details.invocation_metadata)
return grpc.unary_unary_rpc_method_handler(self._method)
async def _start_test_server():
server = aio.server()
port = server.add_insecure_port('[::]:0')
server.add_generic_rpc_handlers((
_TestGenericHandlerForMethods(),
_TestGenericHandlerItself(),
))
await server.start()
return 'localhost:%d' % port, server
class TestMetadata(AioTestBase):
async def setUp(self):
address, self._server = await _start_test_server()
self._client = aio.insecure_channel(address)
async def tearDown(self):
await self._client.close()
await self._server.stop(None)
async def test_from_client_to_server(self):
multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
call = multicallable(_REQUEST,
metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
self.assertEqual(_RESPONSE, await call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_from_server_to_client(self):
multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT)
call = multicallable(_REQUEST)
self.assertEqual(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
call.initial_metadata())
self.assertEqual(_RESPONSE, await call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_trailing_metadata(self):
multicallable = self._client.unary_unary(_TEST_TRAILING_METADATA)
call = multicallable(_REQUEST)
self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
self.assertEqual(_RESPONSE, await call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_invalid_metadata(self):
multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
for exception_type, metadata in _INVALID_METADATA_TEST_CASES:
with self.subTest(metadata=metadata):
call = multicallable(_REQUEST, metadata=metadata)
with self.assertRaises(exception_type):
await call
async def test_generic_handler(self):
multicallable = self._client.unary_unary(_TEST_GENERIC_HANDLER)
call = multicallable(_REQUEST,
metadata=_INITIAL_METADATA_FOR_GENERIC_HANDLER)
self.assertEqual(_RESPONSE, await call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_unary_stream(self):
multicallable = self._client.unary_stream(_TEST_UNARY_STREAM)
call = multicallable(_REQUEST,
metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
self.assertTrue(
_common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
call.initial_metadata()))
self.assertSequenceEqual([_RESPONSE],
[request async for request in call])
self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_stream_unary(self):
multicallable = self._client.stream_unary(_TEST_STREAM_UNARY)
call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
await call.write(_REQUEST)
await call.done_writing()
self.assertTrue(
_common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
call.initial_metadata()))
self.assertEqual(_RESPONSE, await call)
self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_stream_stream(self):
multicallable = self._client.stream_stream(_TEST_STREAM_STREAM)
call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
await call.write(_REQUEST)
await call.done_writing()
self.assertTrue(
_common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
call.initial_metadata()))
self.assertSequenceEqual([_RESPONSE],
[request async for request in call])
self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
self.assertEqual(grpc.StatusCode.OK, await call.code())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

@ -36,6 +36,7 @@ _STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed'
_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen'
_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter'
_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed'
_UNIMPLEMENTED_METHOD = '/test/UnimplementedMethod'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
@ -159,7 +160,7 @@ class _GenericHandler(grpc.GenericRpcHandler):
def service(self, handler_details):
self._called.set_result(None)
return self._routing_table[handler_details.method]
return self._routing_table.get(handler_details.method)
async def wait_for_call(self):
await self._called
@ -393,6 +394,13 @@ class TestServer(AioTestBase):
async with aio.insecure_channel('localhost:%d' % port) as channel:
await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
async def test_unimplemented(self):
call = self._channel.unary_unary(_UNIMPLEMENTED_METHOD)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save