From c045fd964559eed01e27f8f012cd74776b09ddf3 Mon Sep 17 00:00:00 2001 From: Masood Malekghassemi Date: Mon, 23 May 2016 12:05:54 -0700 Subject: [PATCH] Add a test for grpc/grpc#6522 --- .../grpcio/tests/unit/_cython/cygrpc_test.py | 284 ++++++++---------- .../tests/unit/_cython/test_utilities.py | 34 ++- 2 files changed, 153 insertions(+), 165 deletions(-) diff --git a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py index 0a511101f0d..4039c1b99bd 100644 --- a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py @@ -143,22 +143,60 @@ class TypeSmokeTest(unittest.TestCase): del completion_queue -class InsecureServerInsecureClient(unittest.TestCase): +class ServerClientMixin(object): - def setUp(self): + def setUpMixin(self, server_credentials, client_credentials, host_override): self.server_completion_queue = cygrpc.CompletionQueue() self.server = cygrpc.Server() self.server.register_completion_queue(self.server_completion_queue) - self.port = self.server.add_http2_port('[::]:0') + if server_credentials: + self.port = self.server.add_http2_port('[::]:0', server_credentials) + else: + self.port = self.server.add_http2_port('[::]:0') self.server.start() self.client_completion_queue = cygrpc.CompletionQueue() - self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port)) - - def tearDown(self): + if client_credentials: + client_channel_arguments = cygrpc.ChannelArgs([ + cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override, + host_override)]) + self.client_channel = cygrpc.Channel( + 'localhost:{}'.format(self.port), client_channel_arguments, + client_credentials) + else: + self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port)) + if host_override: + self.host_argument = None # default host + self.expected_host = host_override + else: + # arbitrary host name necessitating no further identification + self.host_argument = b'hostess' + self.expected_host = self.host_argument + + def tearDownMixin(self): del self.server del self.client_completion_queue del self.server_completion_queue + def _perform_operations(self, operations, call, queue, deadline, description): + """Perform the list of operations with given call, queue, and deadline. + + Invocation errors are reported with as an exception with `description` in + the message. Performs the operations asynchronously, returning a future. + """ + def performer(): + tag = object() + try: + call_result = call.start_batch(cygrpc.Operations(operations), tag) + self.assertEqual(cygrpc.CallError.ok, call_result) + event = queue.poll(deadline) + self.assertEqual(cygrpc.CompletionType.operation_complete, event.type) + self.assertTrue(event.success) + self.assertIs(tag, event.tag) + except Exception as error: + raise Exception("Error in '{}': {}".format(description, error.message)) + return event + return test_utilities.SimpleFuture(performer) + def testEcho(self): DEADLINE = time.time()+5 DEADLINE_TOLERANCE = 0.25 @@ -175,7 +213,6 @@ class InsecureServerInsecureClient(unittest.TestCase): REQUEST = b'in death a member of project mayhem has a name' RESPONSE = b'his name is robert paulson' METHOD = b'twinkies' - HOST = b'hostess' cygrpc_deadline = cygrpc.Timespec(DEADLINE) @@ -188,7 +225,8 @@ class InsecureServerInsecureClient(unittest.TestCase): client_call_tag = object() client_call = self.client_channel.create_call( - None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline) + None, 0, self.client_completion_queue, METHOD, self.host_argument, + cygrpc_deadline) client_initial_metadata = cygrpc.Metadata([ cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY, CLIENT_METADATA_ASCII_VALUE), @@ -216,7 +254,8 @@ class InsecureServerInsecureClient(unittest.TestCase): test_common.metadata_transmitted(client_initial_metadata, request_event.request_metadata)) self.assertEqual(METHOD, request_event.request_call_details.method) - self.assertEqual(HOST, request_event.request_call_details.host) + self.assertEqual(self.expected_host, + request_event.request_call_details.host) self.assertLess( abs(DEADLINE - float(request_event.request_call_details.deadline)), DEADLINE_TOLERANCE) @@ -292,172 +331,101 @@ class InsecureServerInsecureClient(unittest.TestCase): del client_call del server_call - -class SecureServerSecureClient(unittest.TestCase): - - def setUp(self): - server_credentials = cygrpc.server_credentials_ssl( - None, [cygrpc.SslPemKeyCertPair(resources.private_key(), - resources.certificate_chain())], False) - channel_credentials = cygrpc.channel_credentials_ssl( - resources.test_root_certificates(), None) - self.server_completion_queue = cygrpc.CompletionQueue() - self.server = cygrpc.Server() - self.server.register_completion_queue(self.server_completion_queue) - self.port = self.server.add_http2_port('[::]:0', server_credentials) - self.server.start() - self.client_completion_queue = cygrpc.CompletionQueue() - client_channel_arguments = cygrpc.ChannelArgs([ - cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override, - _SSL_HOST_OVERRIDE)]) - self.client_channel = cygrpc.Channel( - 'localhost:{}'.format(self.port), client_channel_arguments, - channel_credentials) - - def tearDown(self): - del self.server - del self.client_completion_queue - del self.server_completion_queue - - def testEcho(self): + def test6522(self): DEADLINE = time.time()+5 DEADLINE_TOLERANCE = 0.25 - CLIENT_METADATA_ASCII_KEY = b'key' - CLIENT_METADATA_ASCII_VALUE = b'val' - CLIENT_METADATA_BIN_KEY = b'key-bin' - CLIENT_METADATA_BIN_VALUE = b'\0'*1000 - SERVER_INITIAL_METADATA_KEY = b'init_me_me_me' - SERVER_INITIAL_METADATA_VALUE = b'whodawha?' - SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought' - SERVER_TRAILING_METADATA_VALUE = b'zomg it is' - SERVER_STATUS_CODE = cygrpc.StatusCode.ok - SERVER_STATUS_DETAILS = b'our work is never over' - REQUEST = b'in death a member of project mayhem has a name' - RESPONSE = b'his name is robert paulson' - METHOD = b'/twinkies' - HOST = None # Default host + METHOD = b'twinkies' cygrpc_deadline = cygrpc.Timespec(DEADLINE) + empty_metadata = cygrpc.Metadata([]) server_request_tag = object() - request_call_result = self.server.request_call( + self.server.request_call( self.server_completion_queue, self.server_completion_queue, server_request_tag) + client_call = self.client_channel.create_call( + None, 0, self.client_completion_queue, METHOD, self.host_argument, + cygrpc_deadline) - self.assertEqual(cygrpc.CallError.ok, request_call_result) - - plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '') - call_credentials = cygrpc.call_credentials_metadata_plugin(plugin) + # Prologue + def perform_client_operations(operations, description): + return self._perform_operations( + operations, client_call, + self.client_completion_queue, cygrpc_deadline, description) - client_call_tag = object() - client_call = self.client_channel.create_call( - None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline) - client_call.set_credentials(call_credentials) - client_initial_metadata = cygrpc.Metadata([ - cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY, - CLIENT_METADATA_ASCII_VALUE), - cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)]) - client_start_batch_result = client_call.start_batch(cygrpc.Operations([ - cygrpc.operation_send_initial_metadata(client_initial_metadata, - _EMPTY_FLAGS), - cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS), - cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), - cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), - cygrpc.operation_receive_message(_EMPTY_FLAGS), - cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS) - ]), client_call_tag) - self.assertEqual(cygrpc.CallError.ok, client_start_batch_result) - client_event_future = test_utilities.CompletionQueuePollFuture( - self.client_completion_queue, cygrpc_deadline) + client_event_future = perform_client_operations([ + cygrpc.operation_send_initial_metadata(empty_metadata, + _EMPTY_FLAGS), + cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), + ], "Client prologue") request_event = self.server_completion_queue.poll(cygrpc_deadline) - self.assertEqual(cygrpc.CompletionType.operation_complete, - request_event.type) - self.assertIsInstance(request_event.operation_call, cygrpc.Call) - self.assertIs(server_request_tag, request_event.tag) - self.assertEqual(0, len(request_event.batch_operations)) - client_metadata_with_credentials = list(client_initial_metadata) + [ - (_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE)] - self.assertTrue( - test_common.metadata_transmitted(client_metadata_with_credentials, - request_event.request_metadata)) - self.assertEqual(METHOD, request_event.request_call_details.method) - self.assertEqual(_SSL_HOST_OVERRIDE, - request_event.request_call_details.host) - self.assertLess( - abs(DEADLINE - float(request_event.request_call_details.deadline)), - DEADLINE_TOLERANCE) - - server_call_tag = object() server_call = request_event.operation_call - server_initial_metadata = cygrpc.Metadata([ - cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY, - SERVER_INITIAL_METADATA_VALUE)]) - server_trailing_metadata = cygrpc.Metadata([ - cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY, - SERVER_TRAILING_METADATA_VALUE)]) - server_start_batch_result = server_call.start_batch([ - cygrpc.operation_send_initial_metadata(server_initial_metadata, - _EMPTY_FLAGS), - cygrpc.operation_receive_message(_EMPTY_FLAGS), - cygrpc.operation_send_message(RESPONSE, _EMPTY_FLAGS), - cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), - cygrpc.operation_send_status_from_server( - server_trailing_metadata, SERVER_STATUS_CODE, - SERVER_STATUS_DETAILS, _EMPTY_FLAGS) - ], server_call_tag) - self.assertEqual(cygrpc.CallError.ok, server_start_batch_result) - client_event = client_event_future.result() - server_event = self.server_completion_queue.poll(cygrpc_deadline) + def perform_server_operations(operations, description): + return self._perform_operations( + operations, server_call, + self.server_completion_queue, cygrpc_deadline, description) - self.assertEqual(6, len(client_event.batch_operations)) - found_client_op_types = set() - for client_result in client_event.batch_operations: - # we expect each op type to be unique - self.assertNotIn(client_result.type, found_client_op_types) - found_client_op_types.add(client_result.type) - if client_result.type == cygrpc.OperationType.receive_initial_metadata: - self.assertTrue( - test_common.metadata_transmitted(server_initial_metadata, - client_result.received_metadata)) - elif client_result.type == cygrpc.OperationType.receive_message: - self.assertEqual(RESPONSE, client_result.received_message.bytes()) - elif client_result.type == cygrpc.OperationType.receive_status_on_client: - self.assertTrue( - test_common.metadata_transmitted(server_trailing_metadata, - client_result.received_metadata)) - self.assertEqual(SERVER_STATUS_DETAILS, - client_result.received_status_details) - self.assertEqual(SERVER_STATUS_CODE, client_result.received_status_code) - self.assertEqual(set([ - cygrpc.OperationType.send_initial_metadata, - cygrpc.OperationType.send_message, - cygrpc.OperationType.send_close_from_client, - cygrpc.OperationType.receive_initial_metadata, - cygrpc.OperationType.receive_message, - cygrpc.OperationType.receive_status_on_client - ]), found_client_op_types) + server_event_future = perform_server_operations([ + cygrpc.operation_send_initial_metadata(empty_metadata, + _EMPTY_FLAGS), + ], "Server prologue") - self.assertEqual(5, len(server_event.batch_operations)) - found_server_op_types = set() - for server_result in server_event.batch_operations: - self.assertNotIn(client_result.type, found_server_op_types) - found_server_op_types.add(server_result.type) - if server_result.type == cygrpc.OperationType.receive_message: - self.assertEqual(REQUEST, server_result.received_message.bytes()) - elif server_result.type == cygrpc.OperationType.receive_close_on_server: - self.assertFalse(server_result.received_cancelled) - self.assertEqual(set([ - cygrpc.OperationType.send_initial_metadata, - cygrpc.OperationType.receive_message, - cygrpc.OperationType.send_message, - cygrpc.OperationType.receive_close_on_server, - cygrpc.OperationType.send_status_from_server - ]), found_server_op_types) + client_event_future.result() # force completion + server_event_future.result() - del client_call - del server_call + # Messaging + for _ in range(10): + client_event_future = perform_client_operations([ + cygrpc.operation_send_message(b'', _EMPTY_FLAGS), + cygrpc.operation_receive_message(_EMPTY_FLAGS), + ], "Client message") + server_event_future = perform_server_operations([ + cygrpc.operation_send_message(b'', _EMPTY_FLAGS), + cygrpc.operation_receive_message(_EMPTY_FLAGS), + ], "Server receive") + + client_event_future.result() # force completion + server_event_future.result() + + # Epilogue + client_event_future = perform_client_operations([ + cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS) + ], "Client epilogue") + + server_event_future = perform_server_operations([ + cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), + cygrpc.operation_send_status_from_server( + empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS) + ], "Server epilogue") + + client_event_future.result() # force completion + server_event_future.result() + + +class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin): + + def setUp(self): + self.setUpMixin(None, None, None) + + def tearDown(self): + self.tearDownMixin() + + +class SecureServerSecureClient(unittest.TestCase, ServerClientMixin): + + def setUp(self): + server_credentials = cygrpc.server_credentials_ssl( + None, [cygrpc.SslPemKeyCertPair(resources.private_key(), + resources.certificate_chain())], False) + client_credentials = cygrpc.channel_credentials_ssl( + resources.test_root_certificates(), None) + self.setUpMixin(server_credentials, client_credentials, _SSL_HOST_OVERRIDE) + + def tearDown(self): + self.tearDownMixin() if __name__ == '__main__': diff --git a/src/python/grpcio/tests/unit/_cython/test_utilities.py b/src/python/grpcio/tests/unit/_cython/test_utilities.py index 6a097396431..6280ce74c47 100644 --- a/src/python/grpcio/tests/unit/_cython/test_utilities.py +++ b/src/python/grpcio/tests/unit/_cython/test_utilities.py @@ -32,15 +32,35 @@ import threading from grpc._cython import cygrpc -class CompletionQueuePollFuture: +class SimpleFuture(object): + """A simple future mechanism.""" - def __init__(self, completion_queue, deadline): - def poller_function(): - self._event_result = completion_queue.poll(deadline) - self._event_result = None - self._thread = threading.Thread(target=poller_function) + def __init__(self, function, *args, **kwargs): + def wrapped_function(): + try: + self._result = function(*args, **kwargs) + except Exception as error: + self._error = error + self._result = None + self._error = None + self._thread = threading.Thread(target=wrapped_function) self._thread.start() def result(self): + """The resulting value of this future. + + Re-raises any exceptions. + """ self._thread.join() - return self._event_result + if self._error: + # TODO(atash): re-raise exceptions in a way that preserves tracebacks + raise self._error + return self._result + + +class CompletionQueuePollFuture(SimpleFuture): + + def __init__(self, completion_queue, deadline): + super(CompletionQueuePollFuture, self).__init__( + lambda: completion_queue.poll(deadline)) +