fix metadata

pull/21647/head
Zhanghui Mao 5 years ago committed by Lidi Zheng
parent 0b802e0404
commit 6d556914d0
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  2. 19
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 11
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  4. 4
      src/python/grpcio_tests/tests_aio/unit/channel_test.py

@ -28,8 +28,10 @@ cdef class RPCState(GrpcCallWrapper):
cdef object abort_exception cdef object abort_exception
cdef bint metadata_sent cdef bint metadata_sent
cdef bint status_sent cdef bint status_sent
cdef tuple trailing_metadata
cdef bytes method(self) cdef bytes method(self)
cdef tuple invocation_metadata(self)
cdef enum AioServerStatus: cdef enum AioServerStatus:

@ -40,9 +40,13 @@ cdef class RPCState:
self.abort_exception = None self.abort_exception = None
self.metadata_sent = False self.metadata_sent = False
self.status_sent = False self.status_sent = False
self.trailing_metadata = tuple()
cdef bytes method(self): cdef bytes method(self):
return _slice_bytes(self.details.method) return _slice_bytes(self.details.method)
cdef tuple invocation_metadata(self):
return _metadata(&self.request_metadata)
def __dealloc__(self): def __dealloc__(self):
"""Cleans the Core objects.""" """Cleans the Core objects."""
@ -146,8 +150,11 @@ cdef class _ServicerContext:
raise self._rpc_state.abort_exception raise self._rpc_state.abort_exception
def set_trailing_metadata(self, tuple metadata):
self._rpc_state.trailing_metadata = metadata
def invocation_metadata(self): def invocation_metadata(self):
return _metadata(&self._rpc_state.request_metadata) return self._rpc_state.invocation_metadata()
cdef _find_method_handler(str method, list generic_handlers): cdef _find_method_handler(str method, list generic_handlers):
@ -192,10 +199,10 @@ async def _finish_handler_with_unary_response(RPCState rpc_state,
# Assembles the batch operations # Assembles the batch operations
cdef Operation send_status_op = SendStatusFromServerOperation( cdef Operation send_status_op = SendStatusFromServerOperation(
tuple(), rpc_state.trailing_metadata,
StatusCode.ok, StatusCode.ok,
b'', b'',
_EMPTY_FLAGS, _EMPTY_FLAGS,
) )
cdef tuple finish_ops cdef tuple finish_ops
if not rpc_state.metadata_sent: if not rpc_state.metadata_sent:

@ -34,15 +34,16 @@ async def _maybe_echo_metadata(servicer_context):
initial_metadatum = (_INITIAL_METADATA_KEY, initial_metadatum = (_INITIAL_METADATA_KEY,
invocation_metadata[_INITIAL_METADATA_KEY]) invocation_metadata[_INITIAL_METADATA_KEY])
await servicer_context.send_initial_metadata((initial_metadatum,)) await servicer_context.send_initial_metadata((initial_metadatum,))
# if _TRAILING_METADATA_KEY in invocation_metadata: if _TRAILING_METADATA_KEY in invocation_metadata:
# trailing_metadatum = (_TRAILING_METADATA_KEY, trailing_metadatum = (_TRAILING_METADATA_KEY,
# invocation_metadata[_TRAILING_METADATA_KEY]) invocation_metadata[_TRAILING_METADATA_KEY])
# servicer_context.set_trailing_metadata((trailing_metadatum,)) servicer_context.set_trailing_metadata((trailing_metadatum,))
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
async def UnaryCall(self, unused_request, unused_context): async def UnaryCall(self, unused_request, context):
await _maybe_echo_metadata(context)
return messages_pb2.SimpleResponse() return messages_pb2.SimpleResponse()
async def StreamingOutputCall( async def StreamingOutputCall(

@ -112,8 +112,12 @@ class TestChannel(AioTestBase):
call = hi(messages_pb2.SimpleRequest(), call = hi(messages_pb2.SimpleRequest(),
metadata=_INVOCATION_METADATA) metadata=_INVOCATION_METADATA)
initial_metadata = await call.initial_metadata() initial_metadata = await call.initial_metadata()
trailing_metadata = await call.trailing_metadata()
self.assertIsInstance(initial_metadata, tuple) self.assertIsInstance(initial_metadata, tuple)
self.assertEqual(_INVOCATION_METADATA[0], initial_metadata[0])
self.assertIsInstance(trailing_metadata, tuple)
self.assertEqual(_INVOCATION_METADATA[1], trailing_metadata[0])
async def test_unary_stream(self): async def test_unary_stream(self):
channel = aio.insecure_channel(self._server_target) channel = aio.insecure_channel(self._server_target)

Loading…
Cancel
Save