Removed cython client-side call tracking

This ensures sync calls get cancelled after
a keyboard interrupt, as well as all calls
getting destroyed before grpc_shutdown()
pull/7275/head
Ken Payson 9 years ago
parent 2cbe754285
commit ea1b16f82f
  1. 27
      src/python/grpcio/grpc/_channel.py
  2. 15
      src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
  3. 6
      src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi
  4. 16
      src/python/grpcio/grpc/_server.py
  5. 8
      src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
  6. 22
      src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
  7. 9
      src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py

@ -195,7 +195,8 @@ def _consume_request_iterator(
cygrpc.operation_send_message(
serialized_request, _EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), event_handler)
call.start_client_batch(cygrpc.Operations(operations),
event_handler)
state.due.add(cygrpc.OperationType.send_message)
while True:
state.condition.wait()
@ -211,7 +212,7 @@ def _consume_request_iterator(
operations = (
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), event_handler)
call.start_client_batch(cygrpc.Operations(operations), event_handler)
state.due.add(cygrpc.OperationType.send_close_from_client)
def stop_consumption_thread(timeout):
@ -312,7 +313,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
if self._state.code is None:
event_handler = _event_handler(
self._state, self._call, self._response_deserializer)
self._call.start_batch(
self._call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
event_handler)
@ -471,7 +472,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
None, 0, completion_queue, self._method, None, deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
call.start_batch(cygrpc.Operations(operations), None)
call.start_client_batch(cygrpc.Operations(operations), None)
_handle_event(completion_queue.poll(), state, self._response_deserializer)
return state, deadline
@ -495,7 +496,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_batch(cygrpc.Operations(operations), event_handler)
call.start_client_batch(cygrpc.Operations(operations), event_handler)
return _Rendezvous(state, call, self._response_deserializer, deadline)
@ -523,7 +524,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_batch(
call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
event_handler)
@ -534,7 +535,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), event_handler)
call.start_client_batch(cygrpc.Operations(operations), event_handler)
return _Rendezvous(state, call, self._response_deserializer, deadline)
@ -558,7 +559,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
if credentials is not None:
call.set_credentials(credentials._credentials)
with state.condition:
call.start_batch(
call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
None)
@ -568,7 +569,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), None)
call.start_client_batch(cygrpc.Operations(operations), None)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
while True:
@ -602,7 +603,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_batch(
call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
event_handler)
@ -612,7 +613,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), event_handler)
call.start_client_batch(cygrpc.Operations(operations), event_handler)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
return _Rendezvous(state, call, self._response_deserializer, deadline)
@ -639,7 +640,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_batch(
call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
event_handler)
@ -648,7 +649,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
_common.cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), event_handler)
call.start_client_batch(cygrpc.Operations(operations), event_handler)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
return _Rendezvous(state, call, self._response_deserializer, deadline)

@ -37,13 +37,16 @@ cdef class Call:
self.c_call = NULL
self.references = []
def start_batch(self, operations, tag):
def _start_batch(self, operations, tag, retain_self):
if not self.is_valid:
raise ValueError("invalid call object cannot be used from Python")
cdef grpc_call_error result
cdef Operations cy_operations = Operations(operations)
cdef OperationTag operation_tag = OperationTag(tag)
operation_tag.operation_call = self
if retain_self:
operation_tag.operation_call = self
else:
operation_tag.operation_call = None
operation_tag.batch_operations = cy_operations
cpython.Py_INCREF(operation_tag)
with nogil:
@ -52,6 +55,14 @@ cdef class Call:
<cpython.PyObject *>operation_tag, NULL)
return result
def start_client_batch(self, operations, tag):
# We don't reference this call in the operations tag because
# it should be cancelled when it goes out of scope
return self._start_batch(operations, tag, False)
def start_server_batch(self, operations, tag):
return self._start_batch(operations, tag, True)
def cancel(
self, grpc_status_code error_code=GRPC_STATUS__DO_NOT_USE,
details=None):

@ -58,14 +58,14 @@ cdef class Event:
cdef readonly bint success
cdef readonly object tag
# For operations with calls
cdef readonly Call operation_call
# For Server.request_call
cdef readonly bint is_new_request
cdef readonly CallDetails request_call_details
cdef readonly Metadata request_metadata
# For server calls
cdef readonly Call operation_call
# For Call.start_batch
cdef readonly Operations batch_operations

@ -157,7 +157,7 @@ def _abort(state, call, code, details):
effective_details, _EMPTY_FLAGS),
)
token = _SEND_STATUS_FROM_SERVER_TOKEN
call.start_batch(
call.start_server_batch(
cygrpc.Operations(operations),
_send_status_from_server(state, token))
state.statused = True
@ -257,7 +257,7 @@ class _Context(grpc.ServicerContext):
if self._state.initial_metadata_allowed:
operation = cygrpc.operation_send_initial_metadata(
_common.cygrpc_metadata(initial_metadata), _EMPTY_FLAGS)
self._rpc_event.operation_call.start_batch(
self._rpc_event.operation_call.start_server_batch(
cygrpc.Operations((operation,)),
_send_initial_metadata(self._state))
self._state.initial_metadata_allowed = False
@ -292,7 +292,7 @@ class _RequestIterator(object):
elif self._state.client is _CLOSED or self._state.statused:
raise StopIteration()
else:
self._call.start_batch(
self._call.start_server_batch(
cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_receive_message(self._state, self._call, self._request_deserializer))
self._state.due.add(_RECEIVE_MESSAGE_TOKEN)
@ -333,7 +333,7 @@ def _unary_request(rpc_event, state, request_deserializer):
if state.client is _CANCELLED or state.statused:
return None
else:
start_batch_result = rpc_event.operation_call.start_batch(
start_server_batch_result = rpc_event.operation_call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_receive_message(
@ -417,7 +417,7 @@ def _send_response(rpc_event, state, serialized_response):
cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS),
)
token = _SEND_MESSAGE_TOKEN
rpc_event.operation_call.start_batch(
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(operations), _send_message(state, token))
state.due.add(token)
while True:
@ -443,7 +443,7 @@ def _status(rpc_event, state, serialized_response):
if serialized_response is not None:
operations.append(cygrpc.operation_send_message(
serialized_response, _EMPTY_FLAGS))
rpc_event.operation_call.start_batch(
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(operations),
_send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
state.statused = True
@ -550,7 +550,7 @@ def _handle_unrecognized_method(rpc_event):
b'Method not found!', _EMPTY_FLAGS),
)
rpc_state = _RPCState()
rpc_event.operation_call.start_batch(
rpc_event.operation_call.start_server_batch(
operations, lambda ignored_event: (rpc_state, (),))
return rpc_state
@ -558,7 +558,7 @@ def _handle_unrecognized_method(rpc_event):
def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
state = _RPCState()
with state.condition:
rpc_event.operation_call.start_batch(
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
_receive_close_on_server(state))

@ -81,11 +81,11 @@ class _Handler(object):
self._state.condition.wait()
with self._lock:
self._call.start_batch(
self._call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
_RECEIVE_CLOSE_ON_SERVER_TAG)
self._call.start_batch(
self._call.start_server_batch(
cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_RECEIVE_MESSAGE_TAG)
first_event = self._completion_queue.poll()
@ -101,7 +101,7 @@ class _Handler(object):
_EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
_EMPTY_FLAGS),
)
self._call.start_batch(
self._call.start_server_batch(
cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG)
self._completion_queue.poll()
self._completion_queue.poll()
@ -193,7 +193,7 @@ class CancelManyCallsTest(unittest.TestCase):
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
tag = 'client_complete_call_{0:04d}_tag'.format(index)
client_call.start_batch(cygrpc.Operations(operations), tag)
client_call.start_client_batch(cygrpc.Operations(operations), tag)
client_due.add(tag)
client_calls.append(client_call)

@ -168,12 +168,12 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
client_complete_rpc_tag = 'client_complete_rpc_tag'
with client_condition:
client_receive_initial_metadata_start_batch_result = (
client_call.start_batch(cygrpc.Operations([
client_call.start_client_batch(cygrpc.Operations([
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
]), client_receive_initial_metadata_tag))
client_due.add(client_receive_initial_metadata_tag)
client_complete_rpc_start_batch_result = (
client_call.start_batch(cygrpc.Operations([
client_call.start_client_batch(cygrpc.Operations([
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
@ -185,30 +185,30 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
with server_call_condition:
server_send_initial_metadata_start_batch_result = (
server_rpc_event.operation_call.start_batch(cygrpc.Operations([
server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
]), server_send_initial_metadata_tag))
], server_send_initial_metadata_tag))
server_send_first_message_start_batch_result = (
server_rpc_event.operation_call.start_batch(cygrpc.Operations([
server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
]), server_send_first_message_tag))
], server_send_first_message_tag))
server_send_initial_metadata_event = server_call_driver.event_with_tag(
server_send_initial_metadata_tag)
server_send_first_message_event = server_call_driver.event_with_tag(
server_send_first_message_tag)
with server_call_condition:
server_send_second_message_start_batch_result = (
server_rpc_event.operation_call.start_batch(cygrpc.Operations([
server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
]), server_send_second_message_tag))
], server_send_second_message_tag))
server_complete_rpc_start_batch_result = (
server_rpc_event.operation_call.start_batch(cygrpc.Operations([
server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
cygrpc.Metadata(()), cygrpc.StatusCode.ok, b'test details',
_EMPTY_FLAGS),
]), server_complete_rpc_tag))
], server_complete_rpc_tag))
server_send_second_message_event = server_call_driver.event_with_tag(
server_send_second_message_tag)
server_complete_rpc_event = server_call_driver.event_with_tag(
@ -218,7 +218,7 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
with client_condition:
client_receive_first_message_tag = 'client_receive_first_message_tag'
client_receive_first_message_start_batch_result = (
client_call.start_batch(cygrpc.Operations([
client_call.start_client_batch(cygrpc.Operations([
cygrpc.operation_receive_message(_EMPTY_FLAGS),
]), client_receive_first_message_tag))
client_due.add(client_receive_first_message_tag)

@ -186,7 +186,8 @@ class ServerClientMixin(object):
def performer():
tag = object()
try:
call_result = call.start_batch(cygrpc.Operations(operations), tag)
call_result = call.start_client_batch(
cygrpc.Operations(operations), tag)
self.assertEqual(cygrpc.CallError.ok, call_result)
event = queue.poll(deadline)
self.assertEqual(cygrpc.CompletionType.operation_complete, event.type)
@ -231,7 +232,7 @@ class ServerClientMixin(object):
cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE),
cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
client_start_batch_result = client_call.start_batch(cygrpc.Operations([
client_start_batch_result = client_call.start_client_batch([
cygrpc.operation_send_initial_metadata(client_initial_metadata,
_EMPTY_FLAGS),
cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS),
@ -239,7 +240,7 @@ class ServerClientMixin(object):
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
]), client_call_tag)
], client_call_tag)
self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
client_event_future = test_utilities.CompletionQueuePollFuture(
self.client_completion_queue, cygrpc_deadline)
@ -268,7 +269,7 @@ class ServerClientMixin(object):
server_trailing_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
SERVER_TRAILING_METADATA_VALUE)])
server_start_batch_result = server_call.start_batch([
server_start_batch_result = server_call.start_server_batch([
cygrpc.operation_send_initial_metadata(server_initial_metadata,
_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),

Loading…
Cancel
Save