Merge pull request #6751 from soltanmm/6522

Add a test for grpc/grpc#6522
pull/6858/head
Jan Tattermusch 9 years ago committed by GitHub
commit 24cb78d8f4
  1. 284
      src/python/grpcio/tests/unit/_cython/cygrpc_test.py
  2. 34
      src/python/grpcio/tests/unit/_cython/test_utilities.py

@ -143,22 +143,60 @@ class TypeSmokeTest(unittest.TestCase):
del completion_queue
class InsecureServerInsecureClient(unittest.TestCase):
class ServerClientMixin(object):
def setUp(self):
def setUpMixin(self, server_credentials, client_credentials, host_override):
self.server_completion_queue = cygrpc.CompletionQueue()
self.server = cygrpc.Server()
self.server.register_completion_queue(self.server_completion_queue)
self.port = self.server.add_http2_port('[::]:0')
if server_credentials:
self.port = self.server.add_http2_port('[::]:0', server_credentials)
else:
self.port = self.server.add_http2_port('[::]:0')
self.server.start()
self.client_completion_queue = cygrpc.CompletionQueue()
self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port))
def tearDown(self):
if client_credentials:
client_channel_arguments = cygrpc.ChannelArgs([
cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
host_override)])
self.client_channel = cygrpc.Channel(
'localhost:{}'.format(self.port), client_channel_arguments,
client_credentials)
else:
self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port))
if host_override:
self.host_argument = None # default host
self.expected_host = host_override
else:
# arbitrary host name necessitating no further identification
self.host_argument = b'hostess'
self.expected_host = self.host_argument
def tearDownMixin(self):
del self.server
del self.client_completion_queue
del self.server_completion_queue
def _perform_operations(self, operations, call, queue, deadline, description):
"""Perform the list of operations with given call, queue, and deadline.
Invocation errors are reported with as an exception with `description` in
the message. Performs the operations asynchronously, returning a future.
"""
def performer():
tag = object()
try:
call_result = call.start_batch(cygrpc.Operations(operations), tag)
self.assertEqual(cygrpc.CallError.ok, call_result)
event = queue.poll(deadline)
self.assertEqual(cygrpc.CompletionType.operation_complete, event.type)
self.assertTrue(event.success)
self.assertIs(tag, event.tag)
except Exception as error:
raise Exception("Error in '{}': {}".format(description, error.message))
return event
return test_utilities.SimpleFuture(performer)
def testEcho(self):
DEADLINE = time.time()+5
DEADLINE_TOLERANCE = 0.25
@ -175,7 +213,6 @@ class InsecureServerInsecureClient(unittest.TestCase):
REQUEST = b'in death a member of project mayhem has a name'
RESPONSE = b'his name is robert paulson'
METHOD = b'twinkies'
HOST = b'hostess'
cygrpc_deadline = cygrpc.Timespec(DEADLINE)
@ -188,7 +225,8 @@ class InsecureServerInsecureClient(unittest.TestCase):
client_call_tag = object()
client_call = self.client_channel.create_call(
None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline)
None, 0, self.client_completion_queue, METHOD, self.host_argument,
cygrpc_deadline)
client_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE),
@ -216,7 +254,8 @@ class InsecureServerInsecureClient(unittest.TestCase):
test_common.metadata_transmitted(client_initial_metadata,
request_event.request_metadata))
self.assertEqual(METHOD, request_event.request_call_details.method)
self.assertEqual(HOST, request_event.request_call_details.host)
self.assertEqual(self.expected_host,
request_event.request_call_details.host)
self.assertLess(
abs(DEADLINE - float(request_event.request_call_details.deadline)),
DEADLINE_TOLERANCE)
@ -292,172 +331,101 @@ class InsecureServerInsecureClient(unittest.TestCase):
del client_call
del server_call
class SecureServerSecureClient(unittest.TestCase):
def setUp(self):
server_credentials = cygrpc.server_credentials_ssl(
None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
resources.certificate_chain())], False)
channel_credentials = cygrpc.channel_credentials_ssl(
resources.test_root_certificates(), None)
self.server_completion_queue = cygrpc.CompletionQueue()
self.server = cygrpc.Server()
self.server.register_completion_queue(self.server_completion_queue)
self.port = self.server.add_http2_port('[::]:0', server_credentials)
self.server.start()
self.client_completion_queue = cygrpc.CompletionQueue()
client_channel_arguments = cygrpc.ChannelArgs([
cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
_SSL_HOST_OVERRIDE)])
self.client_channel = cygrpc.Channel(
'localhost:{}'.format(self.port), client_channel_arguments,
channel_credentials)
def tearDown(self):
del self.server
del self.client_completion_queue
del self.server_completion_queue
def testEcho(self):
def test6522(self):
DEADLINE = time.time()+5
DEADLINE_TOLERANCE = 0.25
CLIENT_METADATA_ASCII_KEY = b'key'
CLIENT_METADATA_ASCII_VALUE = b'val'
CLIENT_METADATA_BIN_KEY = b'key-bin'
CLIENT_METADATA_BIN_VALUE = b'\0'*1000
SERVER_INITIAL_METADATA_KEY = b'init_me_me_me'
SERVER_INITIAL_METADATA_VALUE = b'whodawha?'
SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought'
SERVER_TRAILING_METADATA_VALUE = b'zomg it is'
SERVER_STATUS_CODE = cygrpc.StatusCode.ok
SERVER_STATUS_DETAILS = b'our work is never over'
REQUEST = b'in death a member of project mayhem has a name'
RESPONSE = b'his name is robert paulson'
METHOD = b'/twinkies'
HOST = None # Default host
METHOD = b'twinkies'
cygrpc_deadline = cygrpc.Timespec(DEADLINE)
empty_metadata = cygrpc.Metadata([])
server_request_tag = object()
request_call_result = self.server.request_call(
self.server.request_call(
self.server_completion_queue, self.server_completion_queue,
server_request_tag)
client_call = self.client_channel.create_call(
None, 0, self.client_completion_queue, METHOD, self.host_argument,
cygrpc_deadline)
self.assertEqual(cygrpc.CallError.ok, request_call_result)
plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '')
call_credentials = cygrpc.call_credentials_metadata_plugin(plugin)
# Prologue
def perform_client_operations(operations, description):
return self._perform_operations(
operations, client_call,
self.client_completion_queue, cygrpc_deadline, description)
client_call_tag = object()
client_call = self.client_channel.create_call(
None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline)
client_call.set_credentials(call_credentials)
client_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE),
cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
client_start_batch_result = client_call.start_batch(cygrpc.Operations([
cygrpc.operation_send_initial_metadata(client_initial_metadata,
_EMPTY_FLAGS),
cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
]), client_call_tag)
self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
client_event_future = test_utilities.CompletionQueuePollFuture(
self.client_completion_queue, cygrpc_deadline)
client_event_future = perform_client_operations([
cygrpc.operation_send_initial_metadata(empty_metadata,
_EMPTY_FLAGS),
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
], "Client prologue")
request_event = self.server_completion_queue.poll(cygrpc_deadline)
self.assertEqual(cygrpc.CompletionType.operation_complete,
request_event.type)
self.assertIsInstance(request_event.operation_call, cygrpc.Call)
self.assertIs(server_request_tag, request_event.tag)
self.assertEqual(0, len(request_event.batch_operations))
client_metadata_with_credentials = list(client_initial_metadata) + [
(_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE)]
self.assertTrue(
test_common.metadata_transmitted(client_metadata_with_credentials,
request_event.request_metadata))
self.assertEqual(METHOD, request_event.request_call_details.method)
self.assertEqual(_SSL_HOST_OVERRIDE,
request_event.request_call_details.host)
self.assertLess(
abs(DEADLINE - float(request_event.request_call_details.deadline)),
DEADLINE_TOLERANCE)
server_call_tag = object()
server_call = request_event.operation_call
server_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY,
SERVER_INITIAL_METADATA_VALUE)])
server_trailing_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
SERVER_TRAILING_METADATA_VALUE)])
server_start_batch_result = server_call.start_batch([
cygrpc.operation_send_initial_metadata(server_initial_metadata,
_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_send_message(RESPONSE, _EMPTY_FLAGS),
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
server_trailing_metadata, SERVER_STATUS_CODE,
SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
], server_call_tag)
self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)
client_event = client_event_future.result()
server_event = self.server_completion_queue.poll(cygrpc_deadline)
def perform_server_operations(operations, description):
return self._perform_operations(
operations, server_call,
self.server_completion_queue, cygrpc_deadline, description)
self.assertEqual(6, len(client_event.batch_operations))
found_client_op_types = set()
for client_result in client_event.batch_operations:
# we expect each op type to be unique
self.assertNotIn(client_result.type, found_client_op_types)
found_client_op_types.add(client_result.type)
if client_result.type == cygrpc.OperationType.receive_initial_metadata:
self.assertTrue(
test_common.metadata_transmitted(server_initial_metadata,
client_result.received_metadata))
elif client_result.type == cygrpc.OperationType.receive_message:
self.assertEqual(RESPONSE, client_result.received_message.bytes())
elif client_result.type == cygrpc.OperationType.receive_status_on_client:
self.assertTrue(
test_common.metadata_transmitted(server_trailing_metadata,
client_result.received_metadata))
self.assertEqual(SERVER_STATUS_DETAILS,
client_result.received_status_details)
self.assertEqual(SERVER_STATUS_CODE, client_result.received_status_code)
self.assertEqual(set([
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.send_message,
cygrpc.OperationType.send_close_from_client,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.receive_status_on_client
]), found_client_op_types)
server_event_future = perform_server_operations([
cygrpc.operation_send_initial_metadata(empty_metadata,
_EMPTY_FLAGS),
], "Server prologue")
self.assertEqual(5, len(server_event.batch_operations))
found_server_op_types = set()
for server_result in server_event.batch_operations:
self.assertNotIn(client_result.type, found_server_op_types)
found_server_op_types.add(server_result.type)
if server_result.type == cygrpc.OperationType.receive_message:
self.assertEqual(REQUEST, server_result.received_message.bytes())
elif server_result.type == cygrpc.OperationType.receive_close_on_server:
self.assertFalse(server_result.received_cancelled)
self.assertEqual(set([
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.send_message,
cygrpc.OperationType.receive_close_on_server,
cygrpc.OperationType.send_status_from_server
]), found_server_op_types)
client_event_future.result() # force completion
server_event_future.result()
del client_call
del server_call
# Messaging
for _ in range(10):
client_event_future = perform_client_operations([
cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
], "Client message")
server_event_future = perform_server_operations([
cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
], "Server receive")
client_event_future.result() # force completion
server_event_future.result()
# Epilogue
client_event_future = perform_client_operations([
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
], "Client epilogue")
server_event_future = perform_server_operations([
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
], "Server epilogue")
client_event_future.result() # force completion
server_event_future.result()
class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
def setUp(self):
self.setUpMixin(None, None, None)
def tearDown(self):
self.tearDownMixin()
class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
def setUp(self):
server_credentials = cygrpc.server_credentials_ssl(
None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
resources.certificate_chain())], False)
client_credentials = cygrpc.channel_credentials_ssl(
resources.test_root_certificates(), None)
self.setUpMixin(server_credentials, client_credentials, _SSL_HOST_OVERRIDE)
def tearDown(self):
self.tearDownMixin()
if __name__ == '__main__':

@ -32,15 +32,35 @@ import threading
from grpc._cython import cygrpc
class CompletionQueuePollFuture:
class SimpleFuture(object):
"""A simple future mechanism."""
def __init__(self, completion_queue, deadline):
def poller_function():
self._event_result = completion_queue.poll(deadline)
self._event_result = None
self._thread = threading.Thread(target=poller_function)
def __init__(self, function, *args, **kwargs):
def wrapped_function():
try:
self._result = function(*args, **kwargs)
except Exception as error:
self._error = error
self._result = None
self._error = None
self._thread = threading.Thread(target=wrapped_function)
self._thread.start()
def result(self):
"""The resulting value of this future.
Re-raises any exceptions.
"""
self._thread.join()
return self._event_result
if self._error:
# TODO(atash): re-raise exceptions in a way that preserves tracebacks
raise self._error
return self._result
class CompletionQueuePollFuture(SimpleFuture):
def __init__(self, completion_queue, deadline):
super(CompletionQueuePollFuture, self).__init__(
lambda: completion_queue.poll(deadline))

Loading…
Cancel
Save