Support echo status for interop test server

pull/21714/head
Lidi Zheng 5 years ago
parent 4842e23e9c
commit 9b26b410fb
  1. 13
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  2. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 2
      src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi
  4. 2
      src/python/grpcio_tests/commands.py
  5. 4
      src/python/grpcio_tests/tests_aio/interop/local_interop_test.py
  6. 9
      src/python/grpcio_tests/tests_aio/interop/methods.py
  7. 9
      src/python/grpcio_tests/tests_aio/unit/_test_server.py

@ -13,6 +13,19 @@
# limitations under the License. # limitations under the License.
cdef int get_status_code(object code) except *:
if isinstance(code, int):
if code >=0 and code < 15:
return code
else:
return StatusCode.unknown
else:
try:
return code.value[0]
except (KeyError, AttributeError):
return StatusCode.unknown
cdef object deserialize(object deserializer, bytes raw_message): cdef object deserialize(object deserializer, bytes raw_message):
"""Perform deserialization on raw bytes. """Perform deserialization on raw bytes.

@ -143,10 +143,12 @@ cdef class _ServicerContext:
if trailing_metadata == _IMMUTABLE_EMPTY_METADATA and self._rpc_state.trailing_metadata: if trailing_metadata == _IMMUTABLE_EMPTY_METADATA and self._rpc_state.trailing_metadata:
trailing_metadata = self._rpc_state.trailing_metadata trailing_metadata = self._rpc_state.trailing_metadata
actual_code = get_status_code(code)
self._rpc_state.status_sent = True self._rpc_state.status_sent = True
await _send_error_status_from_server( await _send_error_status_from_server(
self._rpc_state, self._rpc_state,
code.value[0], actual_code,
details, details,
trailing_metadata, trailing_metadata,
self._rpc_state.metadata_sent, self._rpc_state.metadata_sent,

@ -35,7 +35,7 @@ cdef bytes _encode(object string_or_none):
elif isinstance(string_or_none, (bytes,)): elif isinstance(string_or_none, (bytes,)):
return <bytes>string_or_none return <bytes>string_or_none
elif isinstance(string_or_none, (unicode,)): elif isinstance(string_or_none, (unicode,)):
return string_or_none.encode('ascii') return string_or_none.encode('utf8')
else: else:
raise TypeError('Expected str, not {}'.format(type(string_or_none))) raise TypeError('Expected str, not {}'.format(type(string_or_none)))

@ -237,7 +237,7 @@ class RunInterop(test.test):
('args=', None, 'pass-thru arguments for the client/server'), ('args=', None, 'pass-thru arguments for the client/server'),
('client', None, 'flag indicating to run the client'), ('client', None, 'flag indicating to run the client'),
('server', None, 'flag indicating to run the server'), ('server', None, 'flag indicating to run the server'),
('use_asyncio', None, 'flag indicating to run the asyncio stack') ('use-asyncio', None, 'flag indicating to run the asyncio stack')
] ]
def initialize_options(self): def initialize_options(self):

@ -67,6 +67,10 @@ class InteropTestCaseMixin:
await methods.test_interoperability( await methods.test_interoperability(
methods.TestCase.TIMEOUT_ON_SLEEPING_SERVER, self._stub, None) methods.TestCase.TIMEOUT_ON_SLEEPING_SERVER, self._stub, None)
async def test_special_status_message(self):
await methods.test_interoperability(
methods.TestCase.SPECIAL_STATUS_MESSAGE, self._stub, None)
class InsecureLocalInteropTest(InteropTestCaseMixin, AioTestBase): class InsecureLocalInteropTest(InteropTestCaseMixin, AioTestBase):

@ -167,6 +167,8 @@ async def _ping_pong(stub: test_pb2_grpc.TestServiceStub) -> None:
response = await call.read() response = await call.read()
_validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
response_size) response_size)
await call.done_writing()
await _validate_status_code_and_details(call, grpc.StatusCode.OK, '')
async def _cancel_after_begin(stub: test_pb2_grpc.TestServiceStub): async def _cancel_after_begin(stub: test_pb2_grpc.TestServiceStub):
@ -362,11 +364,9 @@ async def _per_rpc_creds(stub: test_pb2_grpc.TestServiceStub,
(wanted_email, response.username)) (wanted_email, response.username))
async def _special_status_message(stub: test_pb2_grpc.TestServiceStub, async def _special_status_message(stub: test_pb2_grpc.TestServiceStub):
args: argparse.Namespace):
details = b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'.decode( details = b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'.decode(
'utf-8') 'utf-8')
code = 2
status = grpc.StatusCode.UNKNOWN # code = 2 status = grpc.StatusCode.UNKNOWN # code = 2
# Test with a UnaryCall # Test with a UnaryCall
@ -374,7 +374,8 @@ async def _special_status_message(stub: test_pb2_grpc.TestServiceStub,
response_type=messages_pb2.COMPRESSABLE, response_type=messages_pb2.COMPRESSABLE,
response_size=1, response_size=1,
payload=messages_pb2.Payload(body=b'\x00'), payload=messages_pb2.Payload(body=b'\x00'),
response_status=messages_pb2.EchoStatus(code=code, message=details)) response_status=messages_pb2.EchoStatus(code=status.value[0],
message=details))
call = stub.UnaryCall(request) call = stub.UnaryCall(request)
await _validate_status_code_and_details(call, status, details) await _validate_status_code_and_details(call, status, details)

@ -39,10 +39,19 @@ async def _maybe_echo_metadata(servicer_context):
servicer_context.set_trailing_metadata((trailing_metadatum,)) servicer_context.set_trailing_metadata((trailing_metadatum,))
async def _maybe_echo_status(request: messages_pb2.SimpleRequest,
servicer_context):
"""Echos the RPC status if demanded by the request."""
if request.HasField('response_status'):
await servicer_context.abort(request.response_status.code,
request.response_status.message)
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
async def UnaryCall(self, request, context): async def UnaryCall(self, request, context):
await _maybe_echo_metadata(context) await _maybe_echo_metadata(context)
await _maybe_echo_status(request, context)
return messages_pb2.SimpleResponse( return messages_pb2.SimpleResponse(
payload=messages_pb2.Payload(type=messages_pb2.COMPRESSABLE, payload=messages_pb2.Payload(type=messages_pb2.COMPRESSABLE,
body=b'\x00' * request.response_size)) body=b'\x00' * request.response_size))

Loading…
Cancel
Save