Add cancel_all_calls to Python server

Also format _low_test.py to fit within the 80 character fill-limit and
re-style test assertions.
pull/2969/head
Masood Malekghassemi 10 years ago
parent 083b4d3de3
commit 41a166f97b
  1. 2
      src/python/grpcio/grpc/_adapter/_c/types.h
  2. 16
      src/python/grpcio/grpc/_adapter/_c/types/server.c
  3. 3
      src/python/grpcio/grpc/_adapter/_low.py
  4. 199
      src/python/grpcio_test/grpc_test/_adapter/_low_test.py

@ -146,6 +146,7 @@ typedef struct Server {
PyObject_HEAD
grpc_server *c_serv;
CompletionQueue *cq;
int shutdown_called;
} Server;
Server *pygrpc_Server_new(PyTypeObject *type, PyObject *args, PyObject *kwargs);
void pygrpc_Server_dealloc(Server *self);
@ -156,6 +157,7 @@ PyObject *pygrpc_Server_add_http2_port(
PyObject *pygrpc_Server_start(Server *self, PyObject *ignored);
PyObject *pygrpc_Server_shutdown(
Server *self, PyObject *args, PyObject *kwargs);
PyObject *pygrpc_Server_cancel_all_calls(Server *self, PyObject *unused);
extern PyTypeObject pygrpc_Server_type;
/*=========*/

@ -45,6 +45,8 @@ PyMethodDef pygrpc_Server_methods[] = {
METH_KEYWORDS, ""},
{"start", (PyCFunction)pygrpc_Server_start, METH_NOARGS, ""},
{"shutdown", (PyCFunction)pygrpc_Server_shutdown, METH_KEYWORDS, ""},
{"cancel_all_calls", (PyCFunction)pygrpc_Server_cancel_all_calls,
METH_NOARGS, ""},
{NULL}
};
const char pygrpc_Server_doc[] = "See grpc._adapter._types.Server.";
@ -109,6 +111,7 @@ Server *pygrpc_Server_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
pygrpc_discard_channel_args(c_args);
self->cq = cq;
Py_INCREF(self->cq);
self->shutdown_called = 0;
return self;
}
@ -163,6 +166,7 @@ PyObject *pygrpc_Server_add_http2_port(
PyObject *pygrpc_Server_start(Server *self, PyObject *ignored) {
grpc_server_start(self->c_serv);
self->shutdown_called = 0;
Py_RETURN_NONE;
}
@ -176,5 +180,17 @@ PyObject *pygrpc_Server_shutdown(
}
tag = pygrpc_produce_server_shutdown_tag(user_tag);
grpc_server_shutdown_and_notify(self->c_serv, self->cq->c_cq, tag);
self->shutdown_called = 1;
Py_RETURN_NONE;
}
PyObject *pygrpc_Server_cancel_all_calls(Server *self, PyObject *unused) {
if (!self->shutdown_called) {
PyErr_SetString(
PyExc_RuntimeError,
"shutdown must have been called prior to calling cancel_all_calls!");
return NULL;
}
grpc_server_cancel_all_calls(self->c_serv);
Py_RETURN_NONE;
}

@ -124,3 +124,6 @@ class Server(_types.Server):
def request_call(self, completion_queue, tag):
return self.server.request_call(completion_queue.completion_queue, tag)
def cancel_all_calls(self):
return self.server.cancel_all_calls()

@ -52,7 +52,6 @@ def wait_for_events(completion_queues, deadline):
def set_ith_result(i, completion_queue):
result = completion_queue.next(deadline)
with lock:
print i, completion_queue, result, time.time() - deadline
results[i] = result
for i, completion_queue in enumerate(completion_queues):
thread = threading.Thread(target=set_ith_result,
@ -80,10 +79,12 @@ class InsecureServerInsecureClient(unittest.TestCase):
del self.client_channel
self.client_completion_queue.shutdown()
while self.client_completion_queue.next().type != _types.EventType.QUEUE_SHUTDOWN:
while (self.client_completion_queue.next().type !=
_types.EventType.QUEUE_SHUTDOWN):
pass
self.server_completion_queue.shutdown()
while self.server_completion_queue.next().type != _types.EventType.QUEUE_SHUTDOWN:
while (self.server_completion_queue.next().type !=
_types.EventType.QUEUE_SHUTDOWN):
pass
del self.client_completion_queue
@ -91,58 +92,68 @@ class InsecureServerInsecureClient(unittest.TestCase):
del self.server
def testEcho(self):
DEADLINE = time.time()+5
DEADLINE_TOLERANCE = 0.25
CLIENT_METADATA_ASCII_KEY = 'key'
CLIENT_METADATA_ASCII_VALUE = 'val'
CLIENT_METADATA_BIN_KEY = 'key-bin'
CLIENT_METADATA_BIN_VALUE = b'\0'*1000
SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
SERVER_STATUS_CODE = _types.StatusCode.OK
SERVER_STATUS_DETAILS = 'our work is never over'
REQUEST = 'in death a member of project mayhem has a name'
RESPONSE = 'his name is robert paulson'
METHOD = 'twinkies'
HOST = 'hostess'
deadline = time.time() + 5
event_time_tolerance = 2
deadline_tolerance = 0.25
client_metadata_ascii_key = 'key'
client_metadata_ascii_value = 'val'
client_metadata_bin_key = 'key-bin'
client_metadata_bin_value = b'\0'*1000
server_initial_metadata_key = 'init_me_me_me'
server_initial_metadata_value = 'whodawha?'
server_trailing_metadata_key = 'california_is_in_a_drought'
server_trailing_metadata_value = 'zomg it is'
server_status_code = _types.StatusCode.OK
server_status_details = 'our work is never over'
request = 'blarghaflargh'
response = 'his name is robert paulson'
method = 'twinkies'
host = 'hostess'
server_request_tag = object()
request_call_result = self.server.request_call(self.server_completion_queue, server_request_tag)
request_call_result = self.server.request_call(self.server_completion_queue,
server_request_tag)
self.assertEquals(_types.CallError.OK, request_call_result)
self.assertEqual(_types.CallError.OK, request_call_result)
client_call_tag = object()
client_call = self.client_channel.create_call(self.client_completion_queue, METHOD, HOST, DEADLINE)
client_initial_metadata = [(CLIENT_METADATA_ASCII_KEY, CLIENT_METADATA_ASCII_VALUE), (CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)]
client_call = self.client_channel.create_call(
self.client_completion_queue, method, host, deadline)
client_initial_metadata = [
(client_metadata_ascii_key, client_metadata_ascii_value),
(client_metadata_bin_key, client_metadata_bin_value)
]
client_start_batch_result = client_call.start_batch([
_types.OpArgs.send_initial_metadata(client_initial_metadata),
_types.OpArgs.send_message(REQUEST, 0),
_types.OpArgs.send_message(request, 0),
_types.OpArgs.send_close_from_client(),
_types.OpArgs.recv_initial_metadata(),
_types.OpArgs.recv_message(),
_types.OpArgs.recv_status_on_client()
], client_call_tag)
self.assertEquals(_types.CallError.OK, client_start_batch_result)
self.assertEqual(_types.CallError.OK, client_start_batch_result)
client_no_event, request_event, = wait_for_events([self.client_completion_queue, self.server_completion_queue], time.time() + 2)
self.assertEquals(client_no_event, None)
self.assertEquals(_types.EventType.OP_COMPLETE, request_event.type)
client_no_event, request_event, = wait_for_events(
[self.client_completion_queue, self.server_completion_queue],
time.time() + event_time_tolerance)
self.assertEqual(client_no_event, None)
self.assertEqual(_types.EventType.OP_COMPLETE, request_event.type)
self.assertIsInstance(request_event.call, _low.Call)
self.assertIs(server_request_tag, request_event.tag)
self.assertEquals(1, len(request_event.results))
self.assertEqual(1, len(request_event.results))
received_initial_metadata = dict(request_event.results[0].initial_metadata)
# Check that our metadata were transmitted
self.assertEquals(
self.assertEqual(
dict(client_initial_metadata),
dict((x, received_initial_metadata[x]) for x in zip(*client_initial_metadata)[0]))
dict((x, received_initial_metadata[x])
for x in zip(*client_initial_metadata)[0]))
# Check that Python's user agent string is a part of the full user agent
# string
self.assertIn('Python-gRPC-{}'.format(_grpcio_metadata.__version__),
received_initial_metadata['user-agent'])
self.assertEquals(METHOD, request_event.call_details.method)
self.assertEquals(HOST, request_event.call_details.host)
self.assertLess(abs(DEADLINE - request_event.call_details.deadline), DEADLINE_TOLERANCE)
self.assertEqual(method, request_event.call_details.method)
self.assertEqual(host, request_event.call_details.host)
self.assertLess(abs(deadline - request_event.call_details.deadline),
deadline_tolerance)
# Check that the channel is connected, and that both it and the call have
# the proper target and peer; do this after the first flurry of messages to
@ -155,33 +166,43 @@ class InsecureServerInsecureClient(unittest.TestCase):
server_call_tag = object()
server_call = request_event.call
server_initial_metadata = [(SERVER_INITIAL_METADATA_KEY, SERVER_INITIAL_METADATA_VALUE)]
server_trailing_metadata = [(SERVER_TRAILING_METADATA_KEY, SERVER_TRAILING_METADATA_VALUE)]
server_initial_metadata = [
(server_initial_metadata_key, server_initial_metadata_value)
]
server_trailing_metadata = [
(server_trailing_metadata_key, server_trailing_metadata_value)
]
server_start_batch_result = server_call.start_batch([
_types.OpArgs.send_initial_metadata(server_initial_metadata),
_types.OpArgs.recv_message(),
_types.OpArgs.send_message(RESPONSE, 0),
_types.OpArgs.send_message(response, 0),
_types.OpArgs.recv_close_on_server(),
_types.OpArgs.send_status_from_server(server_trailing_metadata, SERVER_STATUS_CODE, SERVER_STATUS_DETAILS)
_types.OpArgs.send_status_from_server(
server_trailing_metadata, server_status_code, server_status_details)
], server_call_tag)
self.assertEquals(_types.CallError.OK, server_start_batch_result)
self.assertEqual(_types.CallError.OK, server_start_batch_result)
client_event, server_event, = wait_for_events([self.client_completion_queue, self.server_completion_queue], time.time() + 1)
client_event, server_event, = wait_for_events(
[self.client_completion_queue, self.server_completion_queue],
time.time() + event_time_tolerance)
self.assertEquals(6, len(client_event.results))
self.assertEqual(6, len(client_event.results))
found_client_op_types = set()
for client_result in client_event.results:
self.assertNotIn(client_result.type, found_client_op_types) # we expect each op type to be unique
# we expect each op type to be unique
self.assertNotIn(client_result.type, found_client_op_types)
found_client_op_types.add(client_result.type)
if client_result.type == _types.OpType.RECV_INITIAL_METADATA:
self.assertEquals(dict(server_initial_metadata), dict(client_result.initial_metadata))
self.assertEqual(dict(server_initial_metadata),
dict(client_result.initial_metadata))
elif client_result.type == _types.OpType.RECV_MESSAGE:
self.assertEquals(RESPONSE, client_result.message)
self.assertEqual(response, client_result.message)
elif client_result.type == _types.OpType.RECV_STATUS_ON_CLIENT:
self.assertEquals(dict(server_trailing_metadata), dict(client_result.trailing_metadata))
self.assertEquals(SERVER_STATUS_DETAILS, client_result.status.details)
self.assertEquals(SERVER_STATUS_CODE, client_result.status.code)
self.assertEquals(set([
self.assertEqual(dict(server_trailing_metadata),
dict(client_result.trailing_metadata))
self.assertEqual(server_status_details, client_result.status.details)
self.assertEqual(server_status_code, client_result.status.code)
self.assertEqual(set([
_types.OpType.SEND_INITIAL_METADATA,
_types.OpType.SEND_MESSAGE,
_types.OpType.SEND_CLOSE_FROM_CLIENT,
@ -190,16 +211,16 @@ class InsecureServerInsecureClient(unittest.TestCase):
_types.OpType.RECV_STATUS_ON_CLIENT
]), found_client_op_types)
self.assertEquals(5, len(server_event.results))
self.assertEqual(5, len(server_event.results))
found_server_op_types = set()
for server_result in server_event.results:
self.assertNotIn(client_result.type, found_server_op_types)
found_server_op_types.add(server_result.type)
if server_result.type == _types.OpType.RECV_MESSAGE:
self.assertEquals(REQUEST, server_result.message)
self.assertEqual(request, server_result.message)
elif server_result.type == _types.OpType.RECV_CLOSE_ON_SERVER:
self.assertFalse(server_result.cancelled)
self.assertEquals(set([
self.assertEqual(set([
_types.OpType.SEND_INITIAL_METADATA,
_types.OpType.RECV_MESSAGE,
_types.OpType.SEND_MESSAGE,
@ -211,5 +232,81 @@ class InsecureServerInsecureClient(unittest.TestCase):
del server_call
class HangingServerShutdown(unittest.TestCase):
def setUp(self):
self.server_completion_queue = _low.CompletionQueue()
self.server = _low.Server(self.server_completion_queue, [])
self.port = self.server.add_http2_port('[::]:0')
self.client_completion_queue = _low.CompletionQueue()
self.client_channel = _low.Channel('localhost:%d'%self.port, [])
self.server.start()
def tearDown(self):
self.server.shutdown()
del self.client_channel
self.client_completion_queue.shutdown()
self.server_completion_queue.shutdown()
while True:
client_event, server_event = wait_for_events(
[self.client_completion_queue, self.server_completion_queue],
float("+inf"))
if (client_event.type == _types.EventType.QUEUE_SHUTDOWN and
server_event.type == _types.EventType.QUEUE_SHUTDOWN):
break
del self.client_completion_queue
del self.server_completion_queue
del self.server
def testHangingServerCall(self):
deadline = time.time() + 5
deadline_tolerance = 0.25
event_time_tolerance = 2
cancel_all_calls_time_tolerance = 0.5
request = 'blarghaflargh'
method = 'twinkies'
host = 'hostess'
server_request_tag = object()
request_call_result = self.server.request_call(self.server_completion_queue,
server_request_tag)
client_call_tag = object()
client_call = self.client_channel.create_call(self.client_completion_queue,
method, host, deadline)
client_start_batch_result = client_call.start_batch([
_types.OpArgs.send_initial_metadata([]),
_types.OpArgs.send_message(request, 0),
_types.OpArgs.send_close_from_client(),
_types.OpArgs.recv_initial_metadata(),
_types.OpArgs.recv_message(),
_types.OpArgs.recv_status_on_client()
], client_call_tag)
client_no_event, request_event, = wait_for_events(
[self.client_completion_queue, self.server_completion_queue],
time.time() + event_time_tolerance)
# Now try to shutdown the server and expect that we see server shutdown
# almost immediately after calling cancel_all_calls.
with self.assertRaises(RuntimeError):
self.server.cancel_all_calls()
shutdown_tag = object()
self.server.shutdown(shutdown_tag)
pre_cancel_timestamp = time.time()
self.server.cancel_all_calls()
finish_shutdown_timestamp = None
client_call_event, server_shutdown_event = wait_for_events(
[self.client_completion_queue, self.server_completion_queue],
time.time() + event_time_tolerance)
self.assertIs(shutdown_tag, server_shutdown_event.tag)
self.assertGreater(pre_cancel_timestamp + cancel_all_calls_time_tolerance,
time.time())
del client_call
if __name__ == '__main__':
unittest.main(verbosity=2)

Loading…
Cancel
Save