implement metadata for aio unary call

pull/21647/head
Zhanghui Mao 5 years ago committed by Lidi Zheng
parent 5a4a5a0088
commit 0b802e0404
  1. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 5
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 4
      src/python/grpcio/grpc/experimental/aio/_call.py
  4. 5
      src/python/grpcio/grpc/experimental/aio/_channel.py
  5. 16
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  6. 18
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  7. 18
      src/python/grpcio_tests/tests_aio/unit/channel_test.py

@ -36,6 +36,7 @@ cdef class _AioCall(GrpcCallWrapper):
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._create_grpc_call(deadline, method, call_credentials) self._create_grpc_call(deadline, method, call_credentials)
self._is_locally_cancelled = False self._is_locally_cancelled = False
self._status_received = asyncio.Event(loop=self._loop)
def __dealloc__(self): def __dealloc__(self):
if self.call: if self.call:
@ -133,7 +134,7 @@ cdef class _AioCall(GrpcCallWrapper):
cdef tuple ops cdef tuple ops
cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation( cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation(
_EMPTY_METADATA, self._initial_metadata,
GRPC_INITIAL_METADATA_USED_MASK) GRPC_INITIAL_METADATA_USED_MASK)
cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS) cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS)
cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS) cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS)

@ -119,7 +119,7 @@ cdef class _ServicerContext:
elif self._rpc_state.metadata_sent: elif self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent') raise RuntimeError('Send initial metadata failed: already sent')
else: else:
_send_initial_metadata(self._rpc_state, self._loop) await _send_initial_metadata(self._rpc_state, metadata, self._loop)
self._rpc_state.metadata_sent = True self._rpc_state.metadata_sent = True
async def abort(self, async def abort(self,
@ -146,6 +146,9 @@ cdef class _ServicerContext:
raise self._rpc_state.abort_exception raise self._rpc_state.abort_exception
def invocation_metadata(self):
return _metadata(&self._rpc_state.request_metadata)
cdef _find_method_handler(str method, list generic_handlers): cdef _find_method_handler(str method, list generic_handlers):
# TODO(lidiz) connects Metadata to call details # TODO(lidiz) connects Metadata to call details

@ -273,12 +273,14 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
Returned when an instance of `UnaryUnaryMultiCallable` object is called. Returned when an instance of `UnaryUnaryMultiCallable` object is called.
""" """
_request: RequestType _request: RequestType
_metadata: Optional[MetadataType]
_request_serializer: SerializingFunction _request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction _response_deserializer: DeserializingFunction
_call: asyncio.Task _call: asyncio.Task
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float], def __init__(self, request: RequestType, deadline: Optional[float],
metadata: Optional[MetadataType],
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
@ -286,6 +288,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
channel.call(method, deadline, credentials) channel.call(method, deadline, credentials)
super().__init__(channel.call(method, deadline, credentials)) super().__init__(channel.call(method, deadline, credentials))
self._request = request self._request = request
self._metadata = metadata
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._call = self._loop.create_task(self._invoke()) self._call = self._loop.create_task(self._invoke())
@ -307,6 +310,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
try: try:
serialized_response = await self._cython_call.unary_unary( serialized_response = await self._cython_call.unary_unary(
serialized_request, serialized_request,
self._metadata,
self._set_initial_metadata, self._set_initial_metadata,
self._set_status, self._set_status,
) )

@ -95,9 +95,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details. metadata, status code, and details.
""" """
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if wait_for_ready: if wait_for_ready:
raise NotImplementedError( raise NotImplementedError(
"TODO: wait_for_ready not implemented yet") "TODO: wait_for_ready not implemented yet")
@ -108,6 +105,7 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
return UnaryUnaryCall( return UnaryUnaryCall(
request, request,
_timeout_to_deadline(timeout), _timeout_to_deadline(timeout),
metadata,
credentials, credentials,
self._channel, self._channel,
self._method, self._method,
@ -119,6 +117,7 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
self._interceptors, self._interceptors,
request, request,
timeout, timeout,
metadata,
credentials, credentials,
self._channel, self._channel,
self._method, self._method,

@ -23,6 +23,22 @@ from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE
_INITIAL_METADATA_KEY = "initial-md-key"
_TRAILING_METADATA_KEY = "trailing-md-key-bin"
async def _maybe_echo_metadata(servicer_context):
"""Copies metadata from request to response if it is present."""
invocation_metadata = dict(servicer_context.invocation_metadata())
if _INITIAL_METADATA_KEY in invocation_metadata:
initial_metadatum = (_INITIAL_METADATA_KEY,
invocation_metadata[_INITIAL_METADATA_KEY])
await servicer_context.send_initial_metadata((initial_metadatum,))
# if _TRAILING_METADATA_KEY in invocation_metadata:
# trailing_metadatum = (_TRAILING_METADATA_KEY,
# invocation_metadata[_TRAILING_METADATA_KEY])
# servicer_context.set_trailing_metadata((trailing_metadatum,))
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):

@ -112,6 +112,24 @@ class TestUnaryUnaryCall(AioTestBase):
call = hi(messages_pb2.SimpleRequest()) call = hi(messages_pb2.SimpleRequest())
self.assertEqual('', await call.details()) self.assertEqual('', await call.details())
async def test_call_initial_metadata_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual((), await call.initial_metadata())
async def test_call_trailing_metadata_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual((), await call.trailing_metadata())
async def test_cancel_unary_unary(self): async def test_cancel_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel: async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary( hi = channel.unary_unary(

@ -31,6 +31,12 @@ from tests_aio.unit._test_server import start_test_server
_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_INVOCATION_METADATA = (
('initial-md-key', 'initial-md-value'),
('trailing-md-key-bin', b'\x00\x02'),
)
_NUM_STREAM_RESPONSES = 5 _NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7 _REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42 _RESPONSE_PAYLOAD_SIZE = 42
@ -97,6 +103,18 @@ class TestChannel(AioTestBase):
timeout=UNARY_CALL_WITH_SLEEP_VALUE * 5) timeout=UNARY_CALL_WITH_SLEEP_VALUE * 5)
self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_unary_call_metadata(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
_UNARY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest(),
metadata=_INVOCATION_METADATA)
initial_metadata = await call.initial_metadata()
self.assertIsInstance(initial_metadata, tuple)
async def test_unary_stream(self): async def test_unary_stream(self):
channel = aio.insecure_channel(self._server_target) channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)

Loading…
Cancel
Save