diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 468a8f42ce7..887e9c83316 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -36,6 +36,7 @@ cdef class _AioCall(GrpcCallWrapper): self._loop = asyncio.get_event_loop() self._create_grpc_call(deadline, method, call_credentials) self._is_locally_cancelled = False + self._status_received = asyncio.Event(loop=self._loop) def __dealloc__(self): if self.call: @@ -133,7 +134,7 @@ cdef class _AioCall(GrpcCallWrapper): cdef tuple ops cdef SendInitialMetadataOperation initial_metadata_op = SendInitialMetadataOperation( - _EMPTY_METADATA, + self._initial_metadata, GRPC_INITIAL_METADATA_USED_MASK) cdef SendMessageOperation send_message_op = SendMessageOperation(request, _EMPTY_FLAGS) cdef SendCloseFromClientOperation send_close_op = SendCloseFromClientOperation(_EMPTY_FLAGS) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index b8c635c4568..35d09537319 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -119,7 +119,7 @@ cdef class _ServicerContext: elif self._rpc_state.metadata_sent: raise RuntimeError('Send initial metadata failed: already sent') else: - _send_initial_metadata(self._rpc_state, self._loop) + await _send_initial_metadata(self._rpc_state, metadata, self._loop) self._rpc_state.metadata_sent = True async def abort(self, @@ -146,6 +146,9 @@ cdef class _ServicerContext: raise self._rpc_state.abort_exception + def invocation_metadata(self): + return _metadata(&self._rpc_state.request_metadata) + cdef _find_method_handler(str method, list generic_handlers): # TODO(lidiz) connects Metadata to call details diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 25ea89ccbcc..2ae4864e422 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -273,12 +273,14 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): Returned when an instance of `UnaryUnaryMultiCallable` object is called. """ _request: RequestType + _metadata: Optional[MetadataType] _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction _call: asyncio.Task # pylint: disable=too-many-arguments def __init__(self, request: RequestType, deadline: Optional[float], + metadata: Optional[MetadataType], credentials: Optional[grpc.CallCredentials], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, @@ -286,6 +288,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): channel.call(method, deadline, credentials) super().__init__(channel.call(method, deadline, credentials)) self._request = request + self._metadata = metadata self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._call = self._loop.create_task(self._invoke()) @@ -307,6 +310,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): try: serialized_response = await self._cython_call.unary_unary( serialized_request, + self._metadata, self._set_initial_metadata, self._set_status, ) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 6d4fe9145b0..e8ad9598473 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -95,9 +95,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): raised RpcError will also be a Call for the RPC affording the RPC's metadata, status code, and details. """ - if metadata: - raise NotImplementedError("TODO: metadata not implemented yet") - if wait_for_ready: raise NotImplementedError( "TODO: wait_for_ready not implemented yet") @@ -108,6 +105,7 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): return UnaryUnaryCall( request, _timeout_to_deadline(timeout), + metadata, credentials, self._channel, self._method, @@ -119,6 +117,7 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): self._interceptors, request, timeout, + metadata, credentials, self._channel, self._method, diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index ccb9f45fe4d..85475c6bc1b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -23,6 +23,22 @@ from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE +_INITIAL_METADATA_KEY = "initial-md-key" +_TRAILING_METADATA_KEY = "trailing-md-key-bin" + + +async def _maybe_echo_metadata(servicer_context): + """Copies metadata from request to response if it is present.""" + invocation_metadata = dict(servicer_context.invocation_metadata()) + if _INITIAL_METADATA_KEY in invocation_metadata: + initial_metadatum = (_INITIAL_METADATA_KEY, + invocation_metadata[_INITIAL_METADATA_KEY]) + await servicer_context.send_initial_metadata((initial_metadatum,)) + # if _TRAILING_METADATA_KEY in invocation_metadata: + # trailing_metadatum = (_TRAILING_METADATA_KEY, + # invocation_metadata[_TRAILING_METADATA_KEY]) + # servicer_context.set_trailing_metadata((trailing_metadatum,)) + class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 209643e52d1..05bd3e56897 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -112,6 +112,24 @@ class TestUnaryUnaryCall(AioTestBase): call = hi(messages_pb2.SimpleRequest()) self.assertEqual('', await call.details()) + async def test_call_initial_metadata_awaitable(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = hi(messages_pb2.SimpleRequest()) + self.assertEqual((), await call.initial_metadata()) + + async def test_call_trailing_metadata_awaitable(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = hi(messages_pb2.SimpleRequest()) + self.assertEqual((), await call.trailing_metadata()) + async def test_cancel_unary_unary(self): async with aio.insecure_channel(self._server_target) as channel: hi = channel.unary_unary( diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 6267862d890..706f09e8ecf 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -31,6 +31,12 @@ from tests_aio.unit._test_server import start_test_server _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' + +_INVOCATION_METADATA = ( + ('initial-md-key', 'initial-md-value'), + ('trailing-md-key-bin', b'\x00\x02'), +) + _NUM_STREAM_RESPONSES = 5 _REQUEST_PAYLOAD_SIZE = 7 _RESPONSE_PAYLOAD_SIZE = 42 @@ -97,6 +103,18 @@ class TestChannel(AioTestBase): timeout=UNARY_CALL_WITH_SLEEP_VALUE * 5) self.assertEqual(await call.code(), grpc.StatusCode.OK) + async def test_unary_call_metadata(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + _UNARY_CALL_METHOD, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = hi(messages_pb2.SimpleRequest(), + metadata=_INVOCATION_METADATA) + initial_metadata = await call.initial_metadata() + + self.assertIsInstance(initial_metadata, tuple) + async def test_unary_stream(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel)