Revert "Revert "Expose code and details from context on the server side (#25457)" (#26112)" (#26143)

This reverts commit ff9ece1588.
reviewable/pr26166/r1
Lidi Zheng 4 years ago committed by GitHub
parent cc4b5569c0
commit ff79a925ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 30
      src/python/grpcio/grpc/__init__.py
  2. 9
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 9
      src/python/grpcio/grpc/_server.py
  4. 30
      src/python/grpcio/grpc/aio/_base_server.py
  5. 1
      src/python/grpcio_tests/tests/tests.json
  6. 58
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
  7. 28
      src/python/grpcio_tests/tests_aio/unit/metadata_test.py

@ -1174,6 +1174,16 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
""" """
raise NotImplementedError() raise NotImplementedError()
def trailing_metadata(self):
"""Access value to be used as trailing metadata upon RPC completion.
This is an EXPERIMENTAL API.
Returns:
The trailing :term:`metadata` for the RPC.
"""
raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def abort(self, code, details): def abort(self, code, details):
"""Raises an exception to terminate the RPC with a non-OK status. """Raises an exception to terminate the RPC with a non-OK status.
@ -1237,6 +1247,26 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
""" """
raise NotImplementedError() raise NotImplementedError()
def code(self):
"""Accesses the value to be used as status code upon RPC completion.
This is an EXPERIMENTAL API.
Returns:
The StatusCode value for the RPC.
"""
raise NotImplementedError()
def details(self):
"""Accesses the value to be used as detail string upon RPC completion.
This is an EXPERIMENTAL API.
Returns:
The details string of the RPC.
"""
raise NotImplementedError()
def disable_next_message_compression(self): def disable_next_message_compression(self):
"""Disables compression for the next response message. """Disables compression for the next response message.

@ -197,15 +197,24 @@ cdef class _ServicerContext:
def set_trailing_metadata(self, object metadata): def set_trailing_metadata(self, object metadata):
self._rpc_state.trailing_metadata = tuple(metadata) self._rpc_state.trailing_metadata = tuple(metadata)
def trailing_metadata(self):
return self._rpc_state.trailing_metadata
def invocation_metadata(self): def invocation_metadata(self):
return self._rpc_state.invocation_metadata() return self._rpc_state.invocation_metadata()
def set_code(self, object code): def set_code(self, object code):
self._rpc_state.status_code = get_status_code(code) self._rpc_state.status_code = get_status_code(code)
def code(self):
return self._rpc_state.status_code
def set_details(self, str details): def set_details(self, str details):
self._rpc_state.status_details = details self._rpc_state.status_details = details
def details(self):
return self._rpc_state.status_details
def set_compression(self, object compression): def set_compression(self, object compression):
if self._rpc_state.metadata_sent: if self._rpc_state.metadata_sent:
raise RuntimeError('Compression setting must be specified before sending initial metadata') raise RuntimeError('Compression setting must be specified before sending initial metadata')

@ -305,6 +305,9 @@ class _Context(grpc.ServicerContext):
with self._state.condition: with self._state.condition:
self._state.trailing_metadata = trailing_metadata self._state.trailing_metadata = trailing_metadata
def trailing_metadata(self):
return self._state.trailing_metadata
def abort(self, code, details): def abort(self, code, details):
# treat OK like other invalid arguments: fail the RPC # treat OK like other invalid arguments: fail the RPC
if code == grpc.StatusCode.OK: if code == grpc.StatusCode.OK:
@ -326,10 +329,16 @@ class _Context(grpc.ServicerContext):
with self._state.condition: with self._state.condition:
self._state.code = code self._state.code = code
def code(self):
return self._state.code
def set_details(self, details): def set_details(self, details):
with self._state.condition: with self._state.condition:
self._state.details = _common.encode(details) self._state.details = _common.encode(details)
def details(self):
return self._state.details
def _finalize_state(self): def _finalize_state(self):
pass pass

@ -304,3 +304,33 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC):
remaining for the RPC to complete before it is considered to have remaining for the RPC to complete before it is considered to have
timed out, or None if no deadline was specified for the RPC. timed out, or None if no deadline was specified for the RPC.
""" """
def trailing_metadata(self):
"""Access value to be used as trailing metadata upon RPC completion.
This is an EXPERIMENTAL API.
Returns:
The trailing :term:`metadata` for the RPC.
"""
raise NotImplementedError()
def code(self):
"""Accesses the value to be used as status code upon RPC completion.
This is an EXPERIMENTAL API.
Returns:
The StatusCode value for the RPC.
"""
raise NotImplementedError()
def details(self):
"""Accesses the value to be used as detail string upon RPC completion.
This is an EXPERIMENTAL API.
Returns:
The details string of the RPC.
"""
raise NotImplementedError()

@ -60,6 +60,7 @@
"unit._invocation_defects_test.InvocationDefectsTest", "unit._invocation_defects_test.InvocationDefectsTest",
"unit._local_credentials_test.LocalCredentialsTest", "unit._local_credentials_test.LocalCredentialsTest",
"unit._logging_test.LoggingTest", "unit._logging_test.LoggingTest",
"unit._metadata_code_details_test.InspectContextTest",
"unit._metadata_code_details_test.MetadataCodeDetailsTest", "unit._metadata_code_details_test.MetadataCodeDetailsTest",
"unit._metadata_flags_test.MetadataFlagsTest", "unit._metadata_flags_test.MetadataFlagsTest",
"unit._metadata_test.MetadataTest", "unit._metadata_test.MetadataTest",

@ -658,6 +658,64 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self.assertEqual(_DETAILS, exception_context.exception.details()) self.assertEqual(_DETAILS, exception_context.exception.details())
class _InspectServicer(_Servicer):
def __init__(self):
super(_InspectServicer, self).__init__()
self.actual_code = None
self.actual_details = None
self.actual_trailing_metadata = None
def unary_unary(self, request, context):
super(_InspectServicer, self).unary_unary(request, context)
self.actual_code = context.code()
self.actual_details = context.details()
self.actual_trailing_metadata = context.trailing_metadata()
class InspectContextTest(unittest.TestCase):
def setUp(self):
self._servicer = _InspectServicer()
self._server = test_common.test_server()
self._server.add_generic_rpc_handlers(
(_generic_handler(self._servicer),))
port = self._server.add_insecure_port('[::]:0')
self._server.start()
self._channel = grpc.insecure_channel('localhost:{}'.format(port))
self._unary_unary = self._channel.unary_unary(
'/'.join((
'',
_SERVICE,
_UNARY_UNARY,
)),
request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER,
)
def tearDown(self):
self._server.stop(None)
self._channel.close()
def testCodeDetailsInContext(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
with self.assertRaises(grpc.RpcError) as exc_info:
self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
err = exc_info.exception
self.assertEqual(_NON_OK_CODE, err.code())
self.assertEqual(self._servicer.actual_code, _NON_OK_CODE)
self.assertEqual(self._servicer.actual_details.decode('utf-8'),
_DETAILS)
self.assertEqual(self._servicer.actual_trailing_metadata,
_SERVER_TRAILING_METADATA)
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig() logging.basicConfig()
unittest.main(verbosity=2) unittest.main(verbosity=2)

@ -33,6 +33,7 @@ _TEST_GENERIC_HANDLER = '/test/TestGenericHandler'
_TEST_UNARY_STREAM = '/test/TestUnaryStream' _TEST_UNARY_STREAM = '/test/TestUnaryStream'
_TEST_STREAM_UNARY = '/test/TestStreamUnary' _TEST_STREAM_UNARY = '/test/TestStreamUnary'
_TEST_STREAM_STREAM = '/test/TestStreamStream' _TEST_STREAM_STREAM = '/test/TestStreamStream'
_TEST_INSPECT_CONTEXT = '/test/TestInspectContext'
_REQUEST = b'\x00\x00\x00' _REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01' _RESPONSE = b'\x01\x01\x01'
@ -75,6 +76,9 @@ _INVALID_METADATA_TEST_CASES = (
), ),
) )
_NON_OK_CODE = grpc.StatusCode.NOT_FOUND
_DETAILS = 'Test details!'
class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
@ -95,6 +99,8 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
grpc.stream_unary_rpc_method_handler(self._test_stream_unary), grpc.stream_unary_rpc_method_handler(self._test_stream_unary),
_TEST_STREAM_STREAM: _TEST_STREAM_STREAM:
grpc.stream_stream_rpc_method_handler(self._test_stream_stream), grpc.stream_stream_rpc_method_handler(self._test_stream_stream),
_TEST_INSPECT_CONTEXT:
grpc.unary_unary_rpc_method_handler(self._test_inspect_context),
} }
@staticmethod @staticmethod
@ -153,6 +159,19 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
yield _RESPONSE yield _RESPONSE
context.set_trailing_metadata(_TRAILING_METADATA) context.set_trailing_metadata(_TRAILING_METADATA)
@staticmethod
async def _test_inspect_context(request, context):
assert _REQUEST == request
context.set_code(_NON_OK_CODE)
context.set_details(_DETAILS)
context.set_trailing_metadata(_TRAILING_METADATA)
# ensure that we can read back the data we set on the context
assert context.get_code() == _NON_OK_CODE
assert context.get_details() == _DETAILS
assert context.get_trailing_metadata() == _TRAILING_METADATA
return _RESPONSE
def service(self, handler_call_details): def service(self, handler_call_details):
return self._routing_table.get(handler_call_details.method) return self._routing_table.get(handler_call_details.method)
@ -291,6 +310,15 @@ class TestMetadata(AioTestBase):
self.assertEqual(expected_sum, metadata_obj + aio.Metadata( self.assertEqual(expected_sum, metadata_obj + aio.Metadata(
('third', '3'))) ('third', '3')))
async def test_inspect_context(self):
multicallable = self._client.unary_unary(_TEST_INSPECT_CONTEXT)
call = multicallable(_REQUEST)
with self.assertRaises(grpc.RpcError) as exc_data:
await call
err = exc_data.exception
self.assertEqual(_NON_OK_CODE, err.code())
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save