diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi index 15f6bba0a80..4703337b60c 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi @@ -28,8 +28,10 @@ cdef class RPCState(GrpcCallWrapper): cdef object abort_exception cdef bint metadata_sent cdef bint status_sent + cdef tuple trailing_metadata cdef bytes method(self) + cdef tuple invocation_metadata(self) cdef enum AioServerStatus: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index 35d09537319..c21a0d0eed1 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -40,9 +40,13 @@ cdef class RPCState: self.abort_exception = None self.metadata_sent = False self.status_sent = False + self.trailing_metadata = tuple() cdef bytes method(self): - return _slice_bytes(self.details.method) + return _slice_bytes(self.details.method) + + cdef tuple invocation_metadata(self): + return _metadata(&self.request_metadata) def __dealloc__(self): """Cleans the Core objects.""" @@ -146,8 +150,11 @@ cdef class _ServicerContext: raise self._rpc_state.abort_exception + def set_trailing_metadata(self, tuple metadata): + self._rpc_state.trailing_metadata = metadata + def invocation_metadata(self): - return _metadata(&self._rpc_state.request_metadata) + return self._rpc_state.invocation_metadata() cdef _find_method_handler(str method, list generic_handlers): @@ -192,10 +199,10 @@ async def _finish_handler_with_unary_response(RPCState rpc_state, # Assembles the batch operations cdef Operation send_status_op = SendStatusFromServerOperation( - tuple(), - StatusCode.ok, - b'', - _EMPTY_FLAGS, + rpc_state.trailing_metadata, + StatusCode.ok, + b'', + _EMPTY_FLAGS, ) cdef tuple finish_ops if not rpc_state.metadata_sent: diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 85475c6bc1b..03aea81ec91 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -34,15 +34,16 @@ async def _maybe_echo_metadata(servicer_context): initial_metadatum = (_INITIAL_METADATA_KEY, invocation_metadata[_INITIAL_METADATA_KEY]) await servicer_context.send_initial_metadata((initial_metadatum,)) - # if _TRAILING_METADATA_KEY in invocation_metadata: - # trailing_metadatum = (_TRAILING_METADATA_KEY, - # invocation_metadata[_TRAILING_METADATA_KEY]) - # servicer_context.set_trailing_metadata((trailing_metadatum,)) + if _TRAILING_METADATA_KEY in invocation_metadata: + trailing_metadatum = (_TRAILING_METADATA_KEY, + invocation_metadata[_TRAILING_METADATA_KEY]) + servicer_context.set_trailing_metadata((trailing_metadatum,)) class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): - async def UnaryCall(self, unused_request, unused_context): + async def UnaryCall(self, unused_request, context): + await _maybe_echo_metadata(context) return messages_pb2.SimpleResponse() async def StreamingOutputCall( diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 706f09e8ecf..1a1598c2a60 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -112,8 +112,12 @@ class TestChannel(AioTestBase): call = hi(messages_pb2.SimpleRequest(), metadata=_INVOCATION_METADATA) initial_metadata = await call.initial_metadata() + trailing_metadata = await call.trailing_metadata() self.assertIsInstance(initial_metadata, tuple) + self.assertEqual(_INVOCATION_METADATA[0], initial_metadata[0]) + self.assertIsInstance(trailing_metadata, tuple) + self.assertEqual(_INVOCATION_METADATA[1], trailing_metadata[0]) async def test_unary_stream(self): channel = aio.insecure_channel(self._server_target)