Merge pull request #13688 from nathanielmanistaatgoogle/12531

Elide cygrpc.Operations.
pull/13667/head
Nathaniel Manista 7 years ago committed by GitHub
commit d14035c55a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 41
      src/python/grpcio/grpc/_channel.py
  2. 12
      src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
  3. 2
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
  4. 4
      src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
  5. 16
      src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi
  6. 74
      src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
  7. 5
      src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
  8. 21
      src/python/grpcio/grpc/_server.py
  9. 13
      src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
  10. 13
      src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py
  11. 13
      src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py
  12. 27
      src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
  13. 14
      src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py

@ -202,8 +202,7 @@ def _consume_request_iterator(request_iterator, state, call,
else:
operations = (cygrpc.operation_send_message(
serialized_request, _EMPTY_FLAGS),)
call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call.start_client_batch(operations, event_handler)
state.due.add(cygrpc.OperationType.send_message)
while True:
state.condition.wait()
@ -218,8 +217,7 @@ def _consume_request_iterator(request_iterator, state, call,
if state.code is None:
operations = (
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),)
call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call.start_client_batch(operations, event_handler)
state.due.add(cygrpc.OperationType.send_close_from_client)
def stop_consumption_thread(timeout): # pylint: disable=unused-argument
@ -321,8 +319,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
event_handler = _event_handler(self._state, self._call,
self._response_deserializer)
self._call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_message(_EMPTY_FLAGS),),
event_handler)
self._state.due.add(cygrpc.OperationType.receive_message)
elif self._state.code is grpc.StatusCode.OK:
@ -476,8 +473,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
call_error = call.start_client_batch(
cygrpc.Operations(operations), None)
call_error = call.start_client_batch(operations, None)
_check_call_error(call_error, metadata)
_handle_event(completion_queue.poll(), state,
self._response_deserializer)
@ -506,8 +502,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
event_handler = _event_handler(state, call,
self._response_deserializer)
with state.condition:
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
@ -541,17 +536,15 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._response_deserializer)
with state.condition:
call.start_client_batch(
cygrpc.Operations((
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
)), event_handler)
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),),
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(
metadata, _EMPTY_FLAGS), cygrpc.operation_send_message(
serialized_request, _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
@ -580,15 +573,13 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
call.set_credentials(credentials._credentials)
with state.condition:
call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),),
None)
operations = (
cygrpc.operation_send_initial_metadata(metadata, _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), None)
call_error = call.start_client_batch(operations, None)
_check_call_error(call_error, metadata)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer)
@ -633,15 +624,13 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),),
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(metadata, _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
@ -675,14 +664,12 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),),
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(metadata, _EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)

@ -26,20 +26,16 @@ cdef class Call:
def _start_batch(self, operations, tag, retain_self):
if not self.is_valid:
raise ValueError("invalid call object cannot be used from Python")
cdef grpc_call_error result
cdef Operations cy_operations = Operations(operations)
cdef OperationTag operation_tag = OperationTag(tag)
cdef OperationTag operation_tag = OperationTag(tag, operations)
if retain_self:
operation_tag.operation_call = self
else:
operation_tag.operation_call = None
operation_tag.batch_operations = cy_operations
operation_tag.store_ops()
cpython.Py_INCREF(operation_tag)
with nogil:
result = grpc_call_start_batch(
self.c_call, cy_operations.c_ops, cy_operations.c_nops,
return grpc_call_start_batch(
self.c_call, operation_tag.c_ops, operation_tag.c_nops,
<cpython.PyObject *>operation_tag, NULL)
return result
def start_client_batch(self, operations, tag):
# We don't reference this call in the operations tag because

@ -76,7 +76,7 @@ cdef class Channel:
def watch_connectivity_state(
self, grpc_connectivity_state last_observed_state,
Timespec deadline not None, CompletionQueue queue not None, tag):
cdef OperationTag operation_tag = OperationTag(tag)
cdef OperationTag operation_tag = OperationTag(tag, None)
cpython.Py_INCREF(operation_tag)
with nogil:
grpc_channel_watch_connectivity_state(

@ -42,7 +42,7 @@ cdef class CompletionQueue:
cdef Call operation_call = None
cdef CallDetails request_call_details = None
cdef object request_metadata = None
cdef Operations batch_operations = None
cdef object batch_operations = None
if event.type == GRPC_QUEUE_TIMEOUT:
return Event(
event.type, False, None, None, None, None, False, None)
@ -64,7 +64,7 @@ cdef class CompletionQueue:
if tag.is_new_request:
request_metadata = _metadata(&tag._c_request_metadata)
grpc_metadata_array_destroy(&tag._c_request_metadata)
batch_operations = tag.batch_operations
batch_operations = tag.release_ops()
if tag.is_new_request:
# Stuff in the tag not explicitly handled by us needs to live through
# the life of the call

@ -38,9 +38,14 @@ cdef class OperationTag:
cdef Call operation_call
cdef CallDetails request_call_details
cdef grpc_metadata_array _c_request_metadata
cdef Operations batch_operations
cdef grpc_op *c_ops
cdef size_t c_nops
cdef readonly object _operations
cdef bint is_new_request
cdef void store_ops(self)
cdef object release_ops(self)
cdef class Event:
@ -57,7 +62,7 @@ cdef class Event:
cdef readonly Call operation_call
# For Call.start_batch
cdef readonly Operations batch_operations
cdef readonly object batch_operations
cdef class ByteBuffer:
@ -100,13 +105,6 @@ cdef class Operation:
cdef object references
cdef class Operations:
cdef grpc_op *c_ops
cdef size_t c_nops
cdef list operations
cdef class CompressionOptions:
cdef grpc_compression_options c_options

@ -220,9 +220,26 @@ cdef class CallDetails:
cdef class OperationTag:
def __cinit__(self, user_tag):
def __cinit__(self, user_tag, operations):
self.user_tag = user_tag
self.references = []
self._operations = operations
cdef void store_ops(self):
self.c_nops = 0 if self._operations is None else len(self._operations)
if 0 < self.c_nops:
self.c_ops = <grpc_op *>gpr_malloc(sizeof(grpc_op) * self.c_nops)
for index in range(self.c_nops):
self.c_ops[index] = (<Operation>(self._operations[index])).c_op
cdef object release_ops(self):
if 0 < self.c_nops:
for index, operation in enumerate(self._operations):
(<Operation>operation).c_op = self.c_ops[index]
gpr_free(self.c_ops)
return self._operations
else:
return ()
cdef class Event:
@ -232,7 +249,7 @@ cdef class Event:
CallDetails request_call_details,
object request_metadata,
bint is_new_request,
Operations batch_operations):
object batch_operations):
self.type = type
self.success = success
self.tag = tag
@ -569,59 +586,6 @@ def operation_receive_close_on_server(int flags):
return op
cdef class _OperationsIterator:
cdef size_t i
cdef Operations operations
def __cinit__(self, Operations operations not None):
self.i = 0
self.operations = operations
def __iter__(self):
return self
def __next__(self):
if self.i < len(self.operations):
result = self.operations[self.i]
self.i = self.i + 1
return result
else:
raise StopIteration()
cdef class Operations:
def __cinit__(self, operations):
grpc_init()
self.operations = list(operations) # normalize iterable
self.c_ops = NULL
self.c_nops = 0
for operation in self.operations:
if not isinstance(operation, Operation):
raise TypeError("expected operations to be iterable of Operation")
self.c_nops = len(self.operations)
with nogil:
self.c_ops = <grpc_op *>gpr_malloc(sizeof(grpc_op)*self.c_nops)
for i in range(self.c_nops):
self.c_ops[i] = (<Operation>(self.operations[i])).c_op
def __len__(self):
return self.c_nops
def __getitem__(self, size_t i):
# self.operations is never stale; it's only updated from this file
return self.operations[i]
def __dealloc__(self):
with nogil:
gpr_free(self.c_ops)
grpc_shutdown()
def __iter__(self):
return _OperationsIterator(self)
cdef class CompressionOptions:
def __cinit__(self):

@ -78,13 +78,12 @@ cdef class Server:
raise ValueError("server must be started and not shutting down")
if server_queue not in self.registered_completion_queues:
raise ValueError("server_queue must be a registered completion queue")
cdef OperationTag operation_tag = OperationTag(tag)
cdef OperationTag operation_tag = OperationTag(tag, None)
operation_tag.operation_call = Call()
operation_tag.request_call_details = CallDetails()
grpc_metadata_array_init(&operation_tag._c_request_metadata)
operation_tag.references.extend([self, call_queue, server_queue])
operation_tag.is_new_request = True
operation_tag.batch_operations = Operations([])
cpython.Py_INCREF(operation_tag)
return grpc_server_request_call(
self.c_server, &operation_tag.operation_call.c_call,
@ -132,7 +131,7 @@ cdef class Server:
cdef _c_shutdown(self, CompletionQueue queue, tag):
self.is_shutting_down = True
operation_tag = OperationTag(tag)
operation_tag = OperationTag(tag, None)
operation_tag.shutting_down_server = self
cpython.Py_INCREF(operation_tag)
with nogil:

@ -138,9 +138,8 @@ def _abort(state, call, code, details):
state.trailing_metadata, effective_code, effective_details,
_EMPTY_FLAGS),)
token = _SEND_STATUS_FROM_SERVER_TOKEN
call.start_server_batch(
cygrpc.Operations(operations),
_send_status_from_server(state, token))
call.start_server_batch(operations,
_send_status_from_server(state, token))
state.statused = True
state.due.add(token)
@ -264,8 +263,7 @@ class _Context(grpc.ServicerContext):
operation = cygrpc.operation_send_initial_metadata(
initial_metadata, _EMPTY_FLAGS)
self._rpc_event.operation_call.start_server_batch(
cygrpc.Operations((operation,)),
_send_initial_metadata(self._state))
(operation,), _send_initial_metadata(self._state))
self._state.initial_metadata_allowed = False
self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
else:
@ -298,8 +296,7 @@ class _RequestIterator(object):
raise StopIteration()
else:
self._call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_message(_EMPTY_FLAGS),),
_receive_message(self._state, self._call,
self._request_deserializer))
self._state.due.add(_RECEIVE_MESSAGE_TOKEN)
@ -342,8 +339,7 @@ def _unary_request(rpc_event, state, request_deserializer):
return None
else:
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_message(_EMPTY_FLAGS),),
_receive_message(state, rpc_event.operation_call,
request_deserializer))
state.due.add(_RECEIVE_MESSAGE_TOKEN)
@ -423,7 +419,7 @@ def _send_response(rpc_event, state, serialized_response):
_EMPTY_FLAGS),)
token = _SEND_MESSAGE_TOKEN
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(operations), _send_message(state, token))
operations, _send_message(state, token))
state.due.add(token)
while True:
state.condition.wait()
@ -449,7 +445,7 @@ def _status(rpc_event, state, serialized_response):
cygrpc.operation_send_message(serialized_response,
_EMPTY_FLAGS))
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(operations),
operations,
_send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
state.statused = True
state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
@ -559,8 +555,7 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
state = _RPCState()
with state.condition:
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),),
_receive_close_on_server(state))
state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
if method_handler.request_streaming:

@ -65,12 +65,10 @@ class _Handler(object):
with self._lock:
self._call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),),
_RECEIVE_CLOSE_ON_SERVER_TAG)
self._call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_message(_EMPTY_FLAGS),),
_RECEIVE_MESSAGE_TAG)
first_event = self._completion_queue.poll()
if _is_cancellation_event(first_event):
@ -84,8 +82,8 @@ class _Handler(object):
cygrpc.operation_send_status_from_server(
_EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
_EMPTY_FLAGS),)
self._call.start_server_batch(
cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG)
self._call.start_server_batch(operations,
_SERVER_COMPLETE_CALL_TAG)
self._completion_queue.poll()
self._completion_queue.poll()
@ -179,8 +177,7 @@ class CancelManyCallsTest(unittest.TestCase):
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
tag = 'client_complete_call_{0:04d}_tag'.format(index)
client_call.start_client_batch(
cygrpc.Operations(operations), tag)
client_call.start_client_batch(operations, tag)
client_due.add(tag)
client_calls.append(client_call)

@ -48,20 +48,19 @@ class Test(_common.RpcTest, unittest.TestCase):
client_complete_rpc_tag = 'client_complete_rpc_tag'
with self.client_condition:
client_receive_initial_metadata_start_batch_result = (
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_receive_initial_metadata(
_common.EMPTY_FLAGS),
]), client_receive_initial_metadata_tag))
client_call.start_client_batch([
cygrpc.operation_receive_initial_metadata(
_common.EMPTY_FLAGS),
], client_receive_initial_metadata_tag))
client_complete_rpc_start_batch_result = client_call.start_client_batch(
cygrpc.Operations([
[
cygrpc.operation_send_initial_metadata(
_common.INVOCATION_METADATA, _common.EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(
_common.EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(
_common.EMPTY_FLAGS),
]), client_complete_rpc_tag)
], client_complete_rpc_tag)
self.client_driver.add_due({
client_receive_initial_metadata_tag,
client_complete_rpc_tag,

@ -43,20 +43,19 @@ class Test(_common.RpcTest, unittest.TestCase):
client_complete_rpc_tag = 'client_complete_rpc_tag'
with self.client_condition:
client_receive_initial_metadata_start_batch_result = (
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_receive_initial_metadata(
_common.EMPTY_FLAGS),
]), client_receive_initial_metadata_tag))
client_call.start_client_batch([
cygrpc.operation_receive_initial_metadata(
_common.EMPTY_FLAGS),
], client_receive_initial_metadata_tag))
client_complete_rpc_start_batch_result = client_call.start_client_batch(
cygrpc.Operations([
[
cygrpc.operation_send_initial_metadata(
_common.INVOCATION_METADATA, _common.EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(
_common.EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(
_common.EMPTY_FLAGS),
]), client_complete_rpc_tag)
], client_complete_rpc_tag)
self.client_driver.add_due({
client_receive_initial_metadata_tag,
client_complete_rpc_tag,

@ -157,19 +157,17 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
client_complete_rpc_tag = 'client_complete_rpc_tag'
with client_condition:
client_receive_initial_metadata_start_batch_result = (
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
]), client_receive_initial_metadata_tag))
client_call.start_client_batch([
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
], client_receive_initial_metadata_tag))
client_due.add(client_receive_initial_metadata_tag)
client_complete_rpc_start_batch_result = (
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
]), client_complete_rpc_tag))
client_call.start_client_batch([
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
], client_complete_rpc_tag))
client_due.add(client_complete_rpc_tag)
server_rpc_event = server_driver.first_event()
@ -209,10 +207,9 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
with client_condition:
client_receive_first_message_tag = 'client_receive_first_message_tag'
client_receive_first_message_start_batch_result = (
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_receive_message(_EMPTY_FLAGS),
]), client_receive_first_message_tag))
client_call.start_client_batch([
cygrpc.operation_receive_message(_EMPTY_FLAGS),
], client_receive_first_message_tag))
client_due.add(client_receive_first_message_tag)
client_receive_first_message_event = client_driver.event_with_tag(
client_receive_first_message_tag)

@ -35,17 +35,6 @@ def _metadata_plugin(context, callback):
class TypeSmokeTest(unittest.TestCase):
def testOperationsIteration(self):
operations = cygrpc.Operations(
[cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)])
iterator = iter(operations)
operation = next(iterator)
self.assertIsInstance(operation, cygrpc.Operation)
# `Operation`s are write-only structures; can't directly debug anything out
# of them. Just check that we stop iterating.
with self.assertRaises(StopIteration):
next(iterator)
def testOperationFlags(self):
operation = cygrpc.operation_send_message(b'asdf',
cygrpc.WriteFlag.no_compress)
@ -155,8 +144,7 @@ class ServerClientMixin(object):
def performer():
tag = object()
try:
call_result = call.start_client_batch(
cygrpc.Operations(operations), tag)
call_result = call.start_client_batch(operations, tag)
self.assertEqual(cygrpc.CallError.ok, call_result)
event = queue.poll(deadline)
self.assertEqual(cygrpc.CompletionType.operation_complete,

Loading…
Cancel
Save