Merge pull request #15132 from nathanielmanistaatgoogle/12531

Keep Core memory inside cygrpc.Channel objects.
pull/15258/head
Nathaniel Manista 7 years ago committed by GitHub
commit c955125c32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 381
      src/python/grpcio/grpc/_channel.py
  2. 56
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi
  3. 477
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
  4. 8
      src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi
  5. 93
      src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
  6. 50
      src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
  7. 28
      src/python/grpcio_tests/tests/unit/_cython/_channel_test.py
  8. 3
      src/python/grpcio_tests/tests/unit/_cython/_common.py
  9. 54
      src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py
  10. 55
      src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py
  11. 73
      src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
  12. 112
      src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
  13. 49
      src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py

@ -79,27 +79,6 @@ def _wait_once_until(condition, until):
condition.wait(timeout=remaining)
_INTERNAL_CALL_ERROR_MESSAGE_FORMAT = (
'Internal gRPC call error %d. ' +
'Please report to https://github.com/grpc/grpc/issues')
def _check_call_error(call_error, metadata):
if call_error == cygrpc.CallError.invalid_metadata:
raise ValueError('metadata was invalid: %s' % metadata)
elif call_error != cygrpc.CallError.ok:
raise ValueError(_INTERNAL_CALL_ERROR_MESSAGE_FORMAT % call_error)
def _call_error_set_RPCstate(state, call_error, metadata):
if call_error == cygrpc.CallError.invalid_metadata:
_abort(state, grpc.StatusCode.INTERNAL,
'metadata was invalid: %s' % metadata)
else:
_abort(state, grpc.StatusCode.INTERNAL,
_INTERNAL_CALL_ERROR_MESSAGE_FORMAT % call_error)
class _RPCState(object):
def __init__(self, due, initial_metadata, trailing_metadata, code, details):
@ -163,7 +142,7 @@ def _handle_event(event, state, response_deserializer):
return callbacks
def _event_handler(state, call, response_deserializer):
def _event_handler(state, response_deserializer):
def handle_event(event):
with state.condition:
@ -172,40 +151,47 @@ def _event_handler(state, call, response_deserializer):
done = not state.due
for callback in callbacks:
callback()
return call if done else None
return done
return handle_event
def _consume_request_iterator(request_iterator, state, call,
request_serializer):
event_handler = _event_handler(state, call, None)
def _consume_request_iterator(request_iterator, state, call, request_serializer,
event_handler):
def consume_request_iterator():
def consume_request_iterator(): # pylint: disable=too-many-branches
while True:
try:
request = next(request_iterator)
except StopIteration:
break
except Exception: # pylint: disable=broad-except
logging.exception("Exception iterating requests!")
call.cancel()
_abort(state, grpc.StatusCode.UNKNOWN,
"Exception iterating requests!")
code = grpc.StatusCode.UNKNOWN
details = 'Exception iterating requests!'
logging.exception(details)
call.cancel(_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code],
details)
_abort(state, code, details)
return
serialized_request = _common.serialize(request, request_serializer)
with state.condition:
if state.code is None and not state.cancelled:
if serialized_request is None:
call.cancel()
code = grpc.StatusCode.INTERNAL # pylint: disable=redefined-variable-type
details = 'Exception serializing request!'
_abort(state, grpc.StatusCode.INTERNAL, details)
call.cancel(
_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code],
details)
_abort(state, code, details)
return
else:
operations = (cygrpc.SendMessageOperation(
serialized_request, _EMPTY_FLAGS),)
call.start_client_batch(operations, event_handler)
state.due.add(cygrpc.OperationType.send_message)
operating = call.operate(operations, event_handler)
if operating:
state.due.add(cygrpc.OperationType.send_message)
else:
return
while True:
state.condition.wait()
if state.code is None:
@ -219,15 +205,19 @@ def _consume_request_iterator(request_iterator, state, call,
if state.code is None:
operations = (
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),)
call.start_client_batch(operations, event_handler)
state.due.add(cygrpc.OperationType.send_close_from_client)
operating = call.operate(operations, event_handler)
if operating:
state.due.add(cygrpc.OperationType.send_close_from_client)
def stop_consumption_thread(timeout): # pylint: disable=unused-argument
with state.condition:
if state.code is None:
call.cancel()
code = grpc.StatusCode.CANCELLED
details = 'Consumption thread cleaned up!'
call.cancel(_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code],
details)
state.cancelled = True
_abort(state, grpc.StatusCode.CANCELLED, 'Cancelled!')
_abort(state, code, details)
state.condition.notify_all()
consumption_thread = _common.CleanupThread(
@ -247,9 +237,12 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
def cancel(self):
with self._state.condition:
if self._state.code is None:
self._call.cancel()
code = grpc.StatusCode.CANCELLED
details = 'Locally cancelled by application!'
self._call.cancel(
_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], details)
self._state.cancelled = True
_abort(self._state, grpc.StatusCode.CANCELLED, 'Cancelled!')
_abort(self._state, code, details)
self._state.condition.notify_all()
return False
@ -318,12 +311,13 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
def _next(self):
with self._state.condition:
if self._state.code is None:
event_handler = _event_handler(self._state, self._call,
event_handler = _event_handler(self._state,
self._response_deserializer)
self._call.start_client_batch(
operating = self._call.operate(
(cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
event_handler)
self._state.due.add(cygrpc.OperationType.receive_message)
if operating:
self._state.due.add(cygrpc.OperationType.receive_message)
elif self._state.code is grpc.StatusCode.OK:
raise StopIteration()
else:
@ -408,9 +402,12 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
def __del__(self):
with self._state.condition:
if self._state.code is None:
self._call.cancel()
self._state.cancelled = True
self._state.code = grpc.StatusCode.CANCELLED
self._state.details = 'Cancelled upon garbage collection!'
self._state.cancelled = True
self._call.cancel(
_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[self._state.code],
self._state.details)
self._state.condition.notify_all()
@ -437,6 +434,24 @@ def _end_unary_response_blocking(state, call, with_call, deadline):
raise _Rendezvous(state, None, None, deadline)
def _stream_unary_invocation_operationses(metadata):
return (
(
cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
),
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
)
def _stream_unary_invocation_operationses_and_tags(metadata):
return tuple((
operations,
None,
) for operations in _stream_unary_invocation_operationses(metadata))
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
def __init__(self, channel, managed_call, method, request_serializer,
@ -448,8 +463,8 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
self._response_deserializer = response_deserializer
def _prepare(self, request, timeout, metadata):
deadline, serialized_request, rendezvous = (_start_unary_request(
request, timeout, self._request_serializer))
deadline, serialized_request, rendezvous = _start_unary_request(
request, timeout, self._request_serializer)
if serialized_request is None:
return None, None, None, rendezvous
else:
@ -467,48 +482,38 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
def _blocking(self, request, timeout, metadata, credentials):
state, operations, deadline, rendezvous = self._prepare(
request, timeout, metadata)
if rendezvous:
if state is None:
raise rendezvous
else:
completion_queue = cygrpc.CompletionQueue()
call = self._channel.create_call(None, 0, completion_queue,
self._method, None, deadline)
if credentials is not None:
call.set_credentials(credentials._credentials)
call_error = call.start_client_batch(operations, None)
_check_call_error(call_error, metadata)
_handle_event(completion_queue.poll(), state,
self._response_deserializer)
return state, call, deadline
call = self._channel.segregated_call(
0, self._method, None, deadline, metadata, None
if credentials is None else credentials._credentials, ((
operations,
None,
),))
event = call.next_event()
_handle_event(event, state, self._response_deserializer)
return state, call,
def __call__(self, request, timeout=None, metadata=None, credentials=None):
state, call, deadline = self._blocking(request, timeout, metadata,
credentials)
return _end_unary_response_blocking(state, call, False, deadline)
state, call, = self._blocking(request, timeout, metadata, credentials)
return _end_unary_response_blocking(state, call, False, None)
def with_call(self, request, timeout=None, metadata=None, credentials=None):
state, call, deadline = self._blocking(request, timeout, metadata,
credentials)
return _end_unary_response_blocking(state, call, True, deadline)
state, call, = self._blocking(request, timeout, metadata, credentials)
return _end_unary_response_blocking(state, call, True, None)
def future(self, request, timeout=None, metadata=None, credentials=None):
state, operations, deadline, rendezvous = self._prepare(
request, timeout, metadata)
if rendezvous:
return rendezvous
if state is None:
raise rendezvous
else:
call, drive_call = self._managed_call(None, 0, self._method, None,
deadline)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call,
self._response_deserializer)
with state.condition:
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
drive_call()
event_handler = _event_handler(state, self._response_deserializer)
call = self._managed_call(
0, self._method, None, deadline, metadata, None
if credentials is None else credentials._credentials,
(operations,), event_handler)
return _Rendezvous(state, call, self._response_deserializer,
deadline)
@ -524,34 +529,27 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._response_deserializer = response_deserializer
def __call__(self, request, timeout=None, metadata=None, credentials=None):
deadline, serialized_request, rendezvous = (_start_unary_request(
request, timeout, self._request_serializer))
deadline, serialized_request, rendezvous = _start_unary_request(
request, timeout, self._request_serializer)
if serialized_request is None:
raise rendezvous
else:
state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None)
call, drive_call = self._managed_call(None, 0, self._method, None,
deadline)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call,
self._response_deserializer)
with state.condition:
call.start_client_batch(
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
event_handler)
operations = (
operationses = (
(
cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
cygrpc.SendMessageOperation(serialized_request,
_EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
drive_call()
),
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
)
event_handler = _event_handler(state, self._response_deserializer)
call = self._managed_call(
0, self._method, None, deadline, metadata, None
if credentials is None else credentials._credentials,
operationses, event_handler)
return _Rendezvous(state, call, self._response_deserializer,
deadline)
@ -569,49 +567,38 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
def _blocking(self, request_iterator, timeout, metadata, credentials):
deadline = _deadline(timeout)
state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
completion_queue = cygrpc.CompletionQueue()
call = self._channel.create_call(None, 0, completion_queue,
self._method, None, deadline)
if credentials is not None:
call.set_credentials(credentials._credentials)
with state.condition:
call.start_client_batch(
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), None)
operations = (
cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(operations, None)
_check_call_error(call_error, metadata)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer)
call = self._channel.segregated_call(
0, self._method, None, deadline, metadata, None
if credentials is None else credentials._credentials,
_stream_unary_invocation_operationses_and_tags(metadata))
_consume_request_iterator(request_iterator, state, call,
self._request_serializer, None)
while True:
event = completion_queue.poll()
event = call.next_event()
with state.condition:
_handle_event(event, state, self._response_deserializer)
state.condition.notify_all()
if not state.due:
break
return state, call, deadline
return state, call,
def __call__(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
state, call, deadline = self._blocking(request_iterator, timeout,
metadata, credentials)
return _end_unary_response_blocking(state, call, False, deadline)
state, call, = self._blocking(request_iterator, timeout, metadata,
credentials)
return _end_unary_response_blocking(state, call, False, None)
def with_call(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
state, call, deadline = self._blocking(request_iterator, timeout,
metadata, credentials)
return _end_unary_response_blocking(state, call, True, deadline)
state, call, = self._blocking(request_iterator, timeout, metadata,
credentials)
return _end_unary_response_blocking(state, call, True, None)
def future(self,
request_iterator,
@ -620,27 +607,13 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
credentials=None):
deadline = _deadline(timeout)
state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
call, drive_call = self._managed_call(None, 0, self._method, None,
deadline)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_client_batch(
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
event_handler)
operations = (
cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
drive_call()
_consume_request_iterator(request_iterator, state, call,
self._request_serializer)
event_handler = _event_handler(state, self._response_deserializer)
call = self._managed_call(
0, self._method, None, deadline, metadata, None
if credentials is None else credentials._credentials,
_stream_unary_invocation_operationses(metadata), event_handler)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer, event_handler)
return _Rendezvous(state, call, self._response_deserializer, deadline)
@ -661,26 +634,20 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
credentials=None):
deadline = _deadline(timeout)
state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None)
call, drive_call = self._managed_call(None, 0, self._method, None,
deadline)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_client_batch(
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
event_handler)
operations = (
operationses = (
(
cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
drive_call()
_consume_request_iterator(request_iterator, state, call,
self._request_serializer)
),
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
)
event_handler = _event_handler(state, self._response_deserializer)
call = self._managed_call(
0, self._method, None, deadline, metadata, None
if credentials is None else credentials._credentials, operationses,
event_handler)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer, event_handler)
return _Rendezvous(state, call, self._response_deserializer, deadline)
@ -689,28 +656,25 @@ class _ChannelCallState(object):
def __init__(self, channel):
self.lock = threading.Lock()
self.channel = channel
self.completion_queue = cygrpc.CompletionQueue()
self.managed_calls = None
self.managed_calls = 0
def _run_channel_spin_thread(state):
def channel_spin():
while True:
event = state.completion_queue.poll()
completed_call = event.tag(event)
if completed_call is not None:
event = state.channel.next_call_event()
call_completed = event.tag(event)
if call_completed:
with state.lock:
state.managed_calls.remove(completed_call)
if not state.managed_calls:
state.managed_calls = None
state.managed_calls -= 1
if state.managed_calls == 0:
return
def stop_channel_spin(timeout): # pylint: disable=unused-argument
with state.lock:
if state.managed_calls is not None:
for call in state.managed_calls:
call.cancel()
state.channel.close(cygrpc.StatusCode.cancelled,
'Channel spin thread cleaned up!')
channel_spin_thread = _common.CleanupThread(
stop_channel_spin, target=channel_spin)
@ -719,37 +683,41 @@ def _run_channel_spin_thread(state):
def _channel_managed_call_management(state):
def create(parent, flags, method, host, deadline):
"""Creates a managed cygrpc.Call and a function to call to drive it.
If operations are successfully added to the returned cygrpc.Call, the
returned function must be called. If operations are not successfully added
to the returned cygrpc.Call, the returned function must not be called.
Args:
parent: A cygrpc.Call to be used as the parent of the created call.
flags: An integer bitfield of call flags.
method: The RPC method.
host: A host string for the created call.
deadline: A float to be the deadline of the created call or None if the
call is to have an infinite deadline.
Returns:
A cygrpc.Call with which to conduct an RPC and a function to call if
operations are successfully started on the call.
"""
call = state.channel.create_call(parent, flags, state.completion_queue,
method, host, deadline)
def drive():
with state.lock:
if state.managed_calls is None:
state.managed_calls = set((call,))
_run_channel_spin_thread(state)
else:
state.managed_calls.add(call)
# pylint: disable=too-many-arguments
def create(flags, method, host, deadline, metadata, credentials,
operationses, event_handler):
"""Creates a cygrpc.IntegratedCall.
return call, drive
Args:
flags: An integer bitfield of call flags.
method: The RPC method.
host: A host string for the created call.
deadline: A float to be the deadline of the created call or None if
the call is to have an infinite deadline.
metadata: The metadata for the call or None.
credentials: A cygrpc.CallCredentials or None.
operationses: An iterable of iterables of cygrpc.Operations to be
started on the call.
event_handler: A behavior to call to handle the events resultant from
the operations on the call.
Returns:
A cygrpc.IntegratedCall with which to conduct an RPC.
"""
operationses_and_tags = tuple((
operations,
event_handler,
) for operations in operationses)
with state.lock:
call = state.channel.integrated_call(flags, method, host, deadline,
metadata, credentials,
operationses_and_tags)
if state.managed_calls == 0:
state.managed_calls = 1
_run_channel_spin_thread(state)
else:
state.managed_calls += 1
return call
return create
@ -819,12 +787,9 @@ def _poll_connectivity(state, channel, initial_try_to_connect):
callback_and_connectivity[1] = state.connectivity
if callbacks:
_spawn_delivery(state, callbacks)
completion_queue = cygrpc.CompletionQueue()
while True:
channel.watch_connectivity_state(connectivity,
time.time() + 0.2, completion_queue,
None)
event = completion_queue.poll()
event = channel.watch_connectivity_state(connectivity,
time.time() + 0.2)
with state.lock:
if not state.callbacks_and_connectivities and not state.try_to_connect:
state.polling = False

@ -13,9 +13,59 @@
# limitations under the License.
cdef _check_call_error_no_metadata(c_call_error)
cdef _check_and_raise_call_error_no_metadata(c_call_error)
cdef _check_call_error(c_call_error, metadata)
cdef class _CallState:
cdef grpc_call *c_call
cdef set due
cdef class _ChannelState:
cdef object condition
cdef grpc_channel *c_channel
# A boolean field indicating that the channel is open (if True) or is being
# closed (i.e. a call to close is currently executing) or is closed (if
# False).
# TODO(https://github.com/grpc/grpc/issues/3064): Eliminate "is being closed"
# a state in which condition may be acquired by any thread, eliminate this
# field and just use the NULLness of c_channel as an indication that the
# channel is closed.
cdef object open
# A dict from _BatchOperationTag to _CallState
cdef dict integrated_call_states
cdef grpc_completion_queue *c_call_completion_queue
# A set of _CallState
cdef set segregated_call_states
cdef set connectivity_due
cdef grpc_completion_queue *c_connectivity_completion_queue
cdef class IntegratedCall:
cdef _ChannelState _channel_state
cdef _CallState _call_state
cdef class SegregatedCall:
cdef _ChannelState _channel_state
cdef _CallState _call_state
cdef grpc_completion_queue *_c_completion_queue
cdef class Channel:
cdef grpc_arg_pointer_vtable _vtable
cdef grpc_channel *c_channel
cdef list references
cdef readonly _ArgumentsProcessor _arguments_processor
cdef _ChannelState _state

@ -14,82 +14,439 @@
cimport cpython
import threading
_INTERNAL_CALL_ERROR_MESSAGE_FORMAT = (
'Internal gRPC call error %d. ' +
'Please report to https://github.com/grpc/grpc/issues')
cdef str _call_error_metadata(metadata):
return 'metadata was invalid: %s' % metadata
cdef str _call_error_no_metadata(c_call_error):
return _INTERNAL_CALL_ERROR_MESSAGE_FORMAT % c_call_error
cdef str _call_error(c_call_error, metadata):
if c_call_error == GRPC_CALL_ERROR_INVALID_METADATA:
return _call_error_metadata(metadata)
else:
return _call_error_no_metadata(c_call_error)
cdef _check_call_error_no_metadata(c_call_error):
if c_call_error != GRPC_CALL_OK:
return _INTERNAL_CALL_ERROR_MESSAGE_FORMAT % c_call_error
else:
return None
cdef _check_and_raise_call_error_no_metadata(c_call_error):
error = _check_call_error_no_metadata(c_call_error)
if error is not None:
raise ValueError(error)
cdef _check_call_error(c_call_error, metadata):
if c_call_error == GRPC_CALL_ERROR_INVALID_METADATA:
return _call_error_metadata(metadata)
else:
return _check_call_error_no_metadata(c_call_error)
cdef void _raise_call_error_no_metadata(c_call_error) except *:
raise ValueError(_call_error_no_metadata(c_call_error))
cdef void _raise_call_error(c_call_error, metadata) except *:
raise ValueError(_call_error(c_call_error, metadata))
cdef _destroy_c_completion_queue(grpc_completion_queue *c_completion_queue):
grpc_completion_queue_shutdown(c_completion_queue)
grpc_completion_queue_destroy(c_completion_queue)
cdef class _CallState:
def __cinit__(self):
self.due = set()
cdef class _ChannelState:
def __cinit__(self):
self.condition = threading.Condition()
self.open = True
self.integrated_call_states = {}
self.segregated_call_states = set()
self.connectivity_due = set()
cdef tuple _operate(grpc_call *c_call, object operations, object user_tag):
cdef grpc_call_error c_call_error
cdef _BatchOperationTag tag = _BatchOperationTag(user_tag, operations, None)
tag.prepare()
cpython.Py_INCREF(tag)
with nogil:
c_call_error = grpc_call_start_batch(
c_call, tag.c_ops, tag.c_nops, <cpython.PyObject *>tag, NULL)
return c_call_error, tag
cdef object _operate_from_integrated_call(
_ChannelState channel_state, _CallState call_state, object operations,
object user_tag):
cdef grpc_call_error c_call_error
cdef _BatchOperationTag tag
with channel_state.condition:
if call_state.due:
c_call_error, tag = _operate(call_state.c_call, operations, user_tag)
if c_call_error == GRPC_CALL_OK:
call_state.due.add(tag)
channel_state.integrated_call_states[tag] = call_state
return True
else:
_raise_call_error_no_metadata(c_call_error)
else:
return False
cdef object _operate_from_segregated_call(
_ChannelState channel_state, _CallState call_state, object operations,
object user_tag):
cdef grpc_call_error c_call_error
cdef _BatchOperationTag tag
with channel_state.condition:
if call_state.due:
c_call_error, tag = _operate(call_state.c_call, operations, user_tag)
if c_call_error == GRPC_CALL_OK:
call_state.due.add(tag)
return True
else:
_raise_call_error_no_metadata(c_call_error)
else:
return False
cdef _cancel(
_ChannelState channel_state, _CallState call_state, grpc_status_code code,
str details):
cdef grpc_call_error c_call_error
with channel_state.condition:
if call_state.due:
c_call_error = grpc_call_cancel_with_status(
call_state.c_call, code, _encode(details), NULL)
_check_and_raise_call_error_no_metadata(c_call_error)
cdef BatchOperationEvent _next_call_event(
_ChannelState channel_state, grpc_completion_queue *c_completion_queue,
on_success):
tag, event = _latent_event(c_completion_queue, None)
with channel_state.condition:
on_success(tag)
channel_state.condition.notify_all()
return event
# TODO(https://github.com/grpc/grpc/issues/14569): This could be a lot simpler.
cdef void _call(
_ChannelState channel_state, _CallState call_state,
grpc_completion_queue *c_completion_queue, on_success, int flags, method,
host, object deadline, CallCredentials credentials,
object operationses_and_user_tags, object metadata) except *:
"""Invokes an RPC.
Args:
channel_state: A _ChannelState with its "open" attribute set to True. RPCs
may not be invoked on a closed channel.
call_state: An empty _CallState to be altered (specifically assigned a
c_call and having its due set populated) if the RPC invocation is
successful.
c_completion_queue: A grpc_completion_queue to be used for the call's
operations.
on_success: A behavior to be called if attempting to start operations for
the call succeeds. If called the behavior will be called while holding the
channel_state condition and passed the tags associated with operations
that were successfully started for the call.
flags: Flags to be passed to gRPC Core as part of call creation.
method: The fully-qualified name of the RPC method being invoked.
host: A "host" string to be passed to gRPC Core as part of call creation.
deadline: A float for the deadline of the RPC, or None if the RPC is to have
no deadline.
credentials: A _CallCredentials for the RPC or None.
operationses_and_user_tags: A sequence of length-two sequences the first
element of which is a sequence of Operations and the second element of
which is an object to be used as a tag. A SendInitialMetadataOperation
must be present in the first element of this value.
metadata: The metadata for this call.
"""
cdef grpc_slice method_slice
cdef grpc_slice host_slice
cdef grpc_slice *host_slice_ptr
cdef grpc_call_credentials *c_call_credentials
cdef grpc_call_error c_call_error
cdef tuple error_and_wrapper_tag
cdef _BatchOperationTag wrapper_tag
with channel_state.condition:
if channel_state.open:
method_slice = _slice_from_bytes(method)
if host is None:
host_slice_ptr = NULL
else:
host_slice = _slice_from_bytes(host)
host_slice_ptr = &host_slice
call_state.c_call = grpc_channel_create_call(
channel_state.c_channel, NULL, flags,
c_completion_queue, method_slice, host_slice_ptr,
_timespec_from_time(deadline), NULL)
grpc_slice_unref(method_slice)
if host_slice_ptr:
grpc_slice_unref(host_slice)
if credentials is not None:
c_call_credentials = credentials.c()
c_call_error = grpc_call_set_credentials(
call_state.c_call, c_call_credentials)
grpc_call_credentials_release(c_call_credentials)
if c_call_error != GRPC_CALL_OK:
grpc_call_unref(call_state.c_call)
call_state.c_call = NULL
_raise_call_error_no_metadata(c_call_error)
started_tags = set()
for operations, user_tag in operationses_and_user_tags:
c_call_error, tag = _operate(call_state.c_call, operations, user_tag)
if c_call_error == GRPC_CALL_OK:
started_tags.add(tag)
else:
grpc_call_cancel(call_state.c_call, NULL)
grpc_call_unref(call_state.c_call)
call_state.c_call = NULL
_raise_call_error(c_call_error, metadata)
else:
call_state.due.update(started_tags)
on_success(started_tags)
else:
raise ValueError('Cannot invoke RPC on closed channel!')
cdef void _process_integrated_call_tag(
_ChannelState state, _BatchOperationTag tag) except *:
cdef _CallState call_state = state.integrated_call_states.pop(tag)
call_state.due.remove(tag)
if not call_state.due:
grpc_call_unref(call_state.c_call)
call_state.c_call = NULL
cdef class IntegratedCall:
def __cinit__(self, _ChannelState channel_state, _CallState call_state):
self._channel_state = channel_state
self._call_state = call_state
def operate(self, operations, tag):
return _operate_from_integrated_call(
self._channel_state, self._call_state, operations, tag)
def cancel(self, code, details):
_cancel(self._channel_state, self._call_state, code, details)
cdef IntegratedCall _integrated_call(
_ChannelState state, int flags, method, host, object deadline,
object metadata, CallCredentials credentials, operationses_and_user_tags):
call_state = _CallState()
def on_success(started_tags):
for started_tag in started_tags:
state.integrated_call_states[started_tag] = call_state
_call(
state, call_state, state.c_call_completion_queue, on_success, flags,
method, host, deadline, credentials, operationses_and_user_tags, metadata)
return IntegratedCall(state, call_state)
cdef object _process_segregated_call_tag(
_ChannelState state, _CallState call_state,
grpc_completion_queue *c_completion_queue, _BatchOperationTag tag):
call_state.due.remove(tag)
if not call_state.due:
grpc_call_unref(call_state.c_call)
call_state.c_call = NULL
state.segregated_call_states.remove(call_state)
_destroy_c_completion_queue(c_completion_queue)
return True
else:
return False
cdef class SegregatedCall:
def __cinit__(self, _ChannelState channel_state, _CallState call_state):
self._channel_state = channel_state
self._call_state = call_state
def operate(self, operations, tag):
return _operate_from_segregated_call(
self._channel_state, self._call_state, operations, tag)
def cancel(self, code, details):
_cancel(self._channel_state, self._call_state, code, details)
def next_event(self):
def on_success(tag):
_process_segregated_call_tag(
self._channel_state, self._call_state, self._c_completion_queue, tag)
return _next_call_event(
self._channel_state, self._c_completion_queue, on_success)
cdef SegregatedCall _segregated_call(
_ChannelState state, int flags, method, host, object deadline,
object metadata, CallCredentials credentials, operationses_and_user_tags):
cdef _CallState call_state = _CallState()
cdef grpc_completion_queue *c_completion_queue = (
grpc_completion_queue_create_for_next(NULL))
cdef SegregatedCall segregated_call
def on_success(started_tags):
state.segregated_call_states.add(call_state)
try:
_call(
state, call_state, c_completion_queue, on_success, flags, method, host,
deadline, credentials, operationses_and_user_tags, metadata)
except:
_destroy_c_completion_queue(c_completion_queue)
raise
segregated_call = SegregatedCall(state, call_state)
segregated_call._c_completion_queue = c_completion_queue
return segregated_call
cdef object _watch_connectivity_state(
_ChannelState state, grpc_connectivity_state last_observed_state,
object deadline):
cdef _ConnectivityTag tag = _ConnectivityTag(object())
with state.condition:
if state.open:
cpython.Py_INCREF(tag)
grpc_channel_watch_connectivity_state(
state.c_channel, last_observed_state, _timespec_from_time(deadline),
state.c_connectivity_completion_queue, <cpython.PyObject *>tag)
state.connectivity_due.add(tag)
else:
raise ValueError('Cannot invoke RPC on closed channel!')
completed_tag, event = _latent_event(
state.c_connectivity_completion_queue, None)
with state.condition:
state.connectivity_due.remove(completed_tag)
state.condition.notify_all()
return event
cdef _close(_ChannelState state, grpc_status_code code, object details):
cdef _CallState call_state
encoded_details = _encode(details)
with state.condition:
if state.open:
state.open = False
for call_state in set(state.integrated_call_states.values()):
grpc_call_cancel_with_status(
call_state.c_call, code, encoded_details, NULL)
for call_state in state.segregated_call_states:
grpc_call_cancel_with_status(
call_state.c_call, code, encoded_details, NULL)
# TODO(https://github.com/grpc/grpc/issues/3064): Cancel connectivity
# watching.
while state.integrated_call_states:
state.condition.wait()
while state.segregated_call_states:
state.condition.wait()
while state.connectivity_due:
state.condition.wait()
_destroy_c_completion_queue(state.c_call_completion_queue)
_destroy_c_completion_queue(state.c_connectivity_completion_queue)
grpc_channel_destroy(state.c_channel)
state.c_channel = NULL
grpc_shutdown()
state.condition.notify_all()
else:
# Another call to close already completed in the past or is currently
# being executed in another thread.
while state.c_channel != NULL:
state.condition.wait()
cdef class Channel:
def __cinit__(self, bytes target, object arguments,
ChannelCredentials channel_credentials=None):
def __cinit__(
self, bytes target, object arguments,
ChannelCredentials channel_credentials):
grpc_init()
self._state = _ChannelState()
self._vtable.copy = &_copy_pointer
self._vtable.destroy = &_destroy_pointer
self._vtable.cmp = &_compare_pointer
cdef _ArgumentsProcessor arguments_processor = _ArgumentsProcessor(
arguments)
cdef grpc_channel_args *c_arguments = arguments_processor.c(&self._vtable)
self.references = []
c_target = target
if channel_credentials is None:
self.c_channel = grpc_insecure_channel_create(c_target, c_arguments, NULL)
self._state.c_channel = grpc_insecure_channel_create(
<char *>target, c_arguments, NULL)
else:
c_channel_credentials = channel_credentials.c()
self.c_channel = grpc_secure_channel_create(
c_channel_credentials, c_target, c_arguments, NULL)
self._state.c_channel = grpc_secure_channel_create(
c_channel_credentials, <char *>target, c_arguments, NULL)
grpc_channel_credentials_release(c_channel_credentials)
arguments_processor.un_c()
self.references.append(target)
self.references.append(arguments)
def create_call(self, Call parent, int flags,
CompletionQueue queue not None,
method, host, object deadline):
if queue.is_shutting_down:
raise ValueError("queue must not be shutting down or shutdown")
cdef grpc_slice method_slice = _slice_from_bytes(method)
cdef grpc_slice host_slice
cdef grpc_slice *host_slice_ptr = NULL
if host is not None:
host_slice = _slice_from_bytes(host)
host_slice_ptr = &host_slice
cdef Call operation_call = Call()
operation_call.references = [self, queue]
cdef grpc_call *parent_call = NULL
if parent is not None:
parent_call = parent.c_call
operation_call.c_call = grpc_channel_create_call(
self.c_channel, parent_call, flags,
queue.c_completion_queue, method_slice, host_slice_ptr,
_timespec_from_time(deadline), NULL)
grpc_slice_unref(method_slice)
if host_slice_ptr:
grpc_slice_unref(host_slice)
return operation_call
self._state.c_call_completion_queue = (
grpc_completion_queue_create_for_next(NULL))
self._state.c_connectivity_completion_queue = (
grpc_completion_queue_create_for_next(NULL))
def target(self):
cdef char *c_target
with self._state.condition:
c_target = grpc_channel_get_target(self._state.c_channel)
target = <bytes>c_target
gpr_free(c_target)
return target
def integrated_call(
self, int flags, method, host, object deadline, object metadata,
CallCredentials credentials, operationses_and_tags):
return _integrated_call(
self._state, flags, method, host, deadline, metadata, credentials,
operationses_and_tags)
def next_call_event(self):
def on_success(tag):
_process_integrated_call_tag(self._state, tag)
return _next_call_event(
self._state, self._state.c_call_completion_queue, on_success)
def segregated_call(
self, int flags, method, host, object deadline, object metadata,
CallCredentials credentials, operationses_and_tags):
return _segregated_call(
self._state, flags, method, host, deadline, metadata, credentials,
operationses_and_tags)
def check_connectivity_state(self, bint try_to_connect):
cdef grpc_connectivity_state result
with nogil:
result = grpc_channel_check_connectivity_state(self.c_channel,
try_to_connect)
return result
with self._state.condition:
return grpc_channel_check_connectivity_state(
self._state.c_channel, try_to_connect)
def watch_connectivity_state(
self, grpc_connectivity_state last_observed_state,
object deadline, CompletionQueue queue not None, tag):
cdef _ConnectivityTag connectivity_tag = _ConnectivityTag(tag)
cpython.Py_INCREF(connectivity_tag)
grpc_channel_watch_connectivity_state(
self.c_channel, last_observed_state, _timespec_from_time(deadline),
queue.c_completion_queue, <cpython.PyObject *>connectivity_tag)
self, grpc_connectivity_state last_observed_state, object deadline):
return _watch_connectivity_state(self._state, last_observed_state, deadline)
def target(self):
cdef char *target = NULL
with nogil:
target = grpc_channel_get_target(self.c_channel)
result = <bytes>target
with nogil:
gpr_free(target)
return result
def __dealloc__(self):
if self.c_channel != NULL:
grpc_channel_destroy(self.c_channel)
grpc_shutdown()
def close(self, code, details):
_close(self._state, code, details)

@ -13,10 +13,16 @@
# limitations under the License.
cdef grpc_event _next(grpc_completion_queue *c_completion_queue, deadline)
cdef _interpret_event(grpc_event c_event)
cdef class CompletionQueue:
cdef grpc_completion_queue *c_completion_queue
cdef bint is_shutting_down
cdef bint is_shutdown
cdef _interpret_event(self, grpc_event event)
cdef _interpret_event(self, grpc_event c_event)

@ -20,6 +20,53 @@ import time
cdef int _INTERRUPT_CHECK_PERIOD_MS = 200
cdef grpc_event _next(grpc_completion_queue *c_completion_queue, deadline):
cdef gpr_timespec c_increment
cdef gpr_timespec c_timeout
cdef gpr_timespec c_deadline
c_increment = gpr_time_from_millis(_INTERRUPT_CHECK_PERIOD_MS, GPR_TIMESPAN)
if deadline is None:
c_deadline = gpr_inf_future(GPR_CLOCK_REALTIME)
else:
c_deadline = _timespec_from_time(deadline)
with nogil:
while True:
c_timeout = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c_increment)
if gpr_time_cmp(c_timeout, c_deadline) > 0:
c_timeout = c_deadline
c_event = grpc_completion_queue_next(c_completion_queue, c_timeout, NULL)
if (c_event.type != GRPC_QUEUE_TIMEOUT or
gpr_time_cmp(c_timeout, c_deadline) == 0):
break
# Handle any signals
with gil:
cpython.PyErr_CheckSignals()
return c_event
cdef _interpret_event(grpc_event c_event):
cdef _Tag tag
if c_event.type == GRPC_QUEUE_TIMEOUT:
# NOTE(nathaniel): For now we coopt ConnectivityEvent here.
return None, ConnectivityEvent(GRPC_QUEUE_TIMEOUT, False, None)
elif c_event.type == GRPC_QUEUE_SHUTDOWN:
# NOTE(nathaniel): For now we coopt ConnectivityEvent here.
return None, ConnectivityEvent(GRPC_QUEUE_SHUTDOWN, False, None)
else:
tag = <_Tag>c_event.tag
# We receive event tags only after they've been inc-ref'd elsewhere in
# the code.
cpython.Py_DECREF(tag)
return tag, tag.event(c_event)
cdef _latent_event(grpc_completion_queue *c_completion_queue, object deadline):
cdef grpc_event c_event = _next(c_completion_queue, deadline)
return _interpret_event(c_event)
cdef class CompletionQueue:
def __cinit__(self, shutdown_cq=False):
@ -36,48 +83,16 @@ cdef class CompletionQueue:
self.is_shutting_down = False
self.is_shutdown = False
cdef _interpret_event(self, grpc_event event):
cdef _Tag tag = None
if event.type == GRPC_QUEUE_TIMEOUT:
# NOTE(nathaniel): For now we coopt ConnectivityEvent here.
return ConnectivityEvent(GRPC_QUEUE_TIMEOUT, False, None)
elif event.type == GRPC_QUEUE_SHUTDOWN:
cdef _interpret_event(self, grpc_event c_event):
unused_tag, event = _interpret_event(c_event)
if event.completion_type == GRPC_QUEUE_SHUTDOWN:
self.is_shutdown = True
# NOTE(nathaniel): For now we coopt ConnectivityEvent here.
return ConnectivityEvent(GRPC_QUEUE_TIMEOUT, True, None)
else:
tag = <_Tag>event.tag
# We receive event tags only after they've been inc-ref'd elsewhere in
# the code.
cpython.Py_DECREF(tag)
return tag.event(event)
return event
# We name this 'poll' to avoid problems with CPython's expectations for
# 'special' methods (like next and __next__).
def poll(self, deadline=None):
# We name this 'poll' to avoid problems with CPython's expectations for
# 'special' methods (like next and __next__).
cdef gpr_timespec c_increment
cdef gpr_timespec c_timeout
cdef gpr_timespec c_deadline
if deadline is None:
c_deadline = gpr_inf_future(GPR_CLOCK_REALTIME)
else:
c_deadline = _timespec_from_time(deadline)
with nogil:
c_increment = gpr_time_from_millis(_INTERRUPT_CHECK_PERIOD_MS, GPR_TIMESPAN)
while True:
c_timeout = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c_increment)
if gpr_time_cmp(c_timeout, c_deadline) > 0:
c_timeout = c_deadline
event = grpc_completion_queue_next(
self.c_completion_queue, c_timeout, NULL)
if event.type != GRPC_QUEUE_TIMEOUT or gpr_time_cmp(c_timeout, c_deadline) == 0:
break;
# Handle any signals
with gil:
cpython.PyErr_CheckSignals()
return self._interpret_event(event)
return self._interpret_event(_next(self.c_completion_queue, deadline))
def shutdown(self):
with nogil:

@ -19,6 +19,7 @@ import unittest
from grpc._cython import cygrpc
from grpc.framework.foundation import logging_pool
from tests.unit.framework.common import test_constants
from tests.unit._cython import test_utilities
_EMPTY_FLAGS = 0
_EMPTY_METADATA = ()
@ -30,6 +31,8 @@ _RECEIVE_MESSAGE_TAG = 'receive_message'
_SERVER_COMPLETE_CALL_TAG = 'server_complete_call'
_SUCCESS_CALL_FRACTION = 1.0 / 8.0
_SUCCESSFUL_CALLS = int(test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
_UNSUCCESSFUL_CALLS = test_constants.RPC_CONCURRENCY - _SUCCESSFUL_CALLS
class _State(object):
@ -150,7 +153,8 @@ class CancelManyCallsTest(unittest.TestCase):
server.register_completion_queue(server_completion_queue)
port = server.add_http2_port(b'[::]:0')
server.start()
channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None)
channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None,
None)
state = _State()
@ -165,31 +169,33 @@ class CancelManyCallsTest(unittest.TestCase):
client_condition = threading.Condition()
client_due = set()
client_completion_queue = cygrpc.CompletionQueue()
client_driver = _QueueDriver(client_condition, client_completion_queue,
client_due)
client_driver.start()
with client_condition:
client_calls = []
for index in range(test_constants.RPC_CONCURRENCY):
client_call = channel.create_call(None, _EMPTY_FLAGS,
client_completion_queue,
b'/twinkies', None, None)
operations = (
cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.SendMessageOperation(b'\x45\x56', _EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
)
tag = 'client_complete_call_{0:04d}_tag'.format(index)
client_call.start_client_batch(operations, tag)
client_call = channel.integrated_call(
_EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA,
None, ((
(
cygrpc.SendInitialMetadataOperation(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.SendMessageOperation(b'\x45\x56',
_EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveInitialMetadataOperation(
_EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
),
tag,
),))
client_due.add(tag)
client_calls.append(client_call)
client_events_future = test_utilities.SimpleFuture(
lambda: tuple(channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS)))
with state.condition:
while True:
if state.parked_handlers < test_constants.THREAD_CONCURRENCY:
@ -201,12 +207,14 @@ class CancelManyCallsTest(unittest.TestCase):
state.condition.notify_all()
break
client_driver.events(
test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
client_events_future.result()
with client_condition:
for client_call in client_calls:
client_call.cancel()
client_call.cancel(cygrpc.StatusCode.cancelled, 'Cancelled!')
for _ in range(_UNSUCCESSFUL_CALLS):
channel.next_call_event()
channel.close(cygrpc.StatusCode.unknown, 'Cancelled on channel close!')
with state.condition:
server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG)

@ -21,25 +21,20 @@ from grpc._cython import cygrpc
from tests.unit.framework.common import test_constants
def _channel_and_completion_queue():
channel = cygrpc.Channel(b'localhost:54321', ())
completion_queue = cygrpc.CompletionQueue()
return channel, completion_queue
def _channel():
return cygrpc.Channel(b'localhost:54321', (), None)
def _connectivity_loop(channel, completion_queue):
def _connectivity_loop(channel):
for _ in range(100):
connectivity = channel.check_connectivity_state(True)
channel.watch_connectivity_state(connectivity,
time.time() + 0.2, completion_queue,
None)
completion_queue.poll()
channel.watch_connectivity_state(connectivity, time.time() + 0.2)
def _create_loop_destroy():
channel, completion_queue = _channel_and_completion_queue()
_connectivity_loop(channel, completion_queue)
completion_queue.shutdown()
channel = _channel()
_connectivity_loop(channel)
channel.close(cygrpc.StatusCode.ok, 'Channel close!')
def _in_parallel(behavior, arguments):
@ -55,12 +50,9 @@ def _in_parallel(behavior, arguments):
class ChannelTest(unittest.TestCase):
def test_single_channel_lonely_connectivity(self):
channel, completion_queue = _channel_and_completion_queue()
_in_parallel(_connectivity_loop, (
channel,
completion_queue,
))
completion_queue.shutdown()
channel = _channel()
_connectivity_loop(channel)
channel.close(cygrpc.StatusCode.ok, 'Channel close!')
def test_multiple_channels_lonely_connectivity(self):
_in_parallel(_create_loop_destroy, ())

@ -100,7 +100,8 @@ class RpcTest(object):
self.server.register_completion_queue(self.server_completion_queue)
port = self.server.add_http2_port(b'[::]:0')
self.server.start()
self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), [])
self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), [],
None)
self._server_shutdown_tag = 'server_shutdown_tag'
self.server_condition = threading.Condition()

@ -19,6 +19,7 @@ import unittest
from grpc._cython import cygrpc
from tests.unit._cython import _common
from tests.unit._cython import test_utilities
class Test(_common.RpcTest, unittest.TestCase):
@ -41,31 +42,27 @@ class Test(_common.RpcTest, unittest.TestCase):
server_request_call_tag,
})
client_call = self.channel.create_call(None, _common.EMPTY_FLAGS,
self.client_completion_queue,
b'/twinkies', None, None)
client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag'
client_complete_rpc_tag = 'client_complete_rpc_tag'
with self.client_condition:
client_receive_initial_metadata_start_batch_result = (
client_call.start_client_batch([
cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS),
], client_receive_initial_metadata_tag))
self.assertEqual(cygrpc.CallError.ok,
client_receive_initial_metadata_start_batch_result)
client_complete_rpc_start_batch_result = client_call.start_client_batch(
client_call = self.channel.integrated_call(
_common.EMPTY_FLAGS, b'/twinkies', None, None,
_common.INVOCATION_METADATA, None, [(
[
cygrpc.SendInitialMetadataOperation(
_common.INVOCATION_METADATA, _common.EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS),
], client_complete_rpc_tag)
self.assertEqual(cygrpc.CallError.ok,
client_complete_rpc_start_batch_result)
self.client_driver.add_due({
cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS),
],
client_receive_initial_metadata_tag,
client_complete_rpc_tag,
})
)])
client_call.operate([
cygrpc.SendInitialMetadataOperation(_common.INVOCATION_METADATA,
_common.EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS),
], client_complete_rpc_tag)
client_events_future = test_utilities.SimpleFuture(
lambda: [
self.channel.next_call_event(),
self.channel.next_call_event(),])
server_request_call_event = self.server_driver.event_with_tag(
server_request_call_tag)
@ -96,20 +93,23 @@ class Test(_common.RpcTest, unittest.TestCase):
server_complete_rpc_event = server_call_driver.event_with_tag(
server_complete_rpc_tag)
client_receive_initial_metadata_event = self.client_driver.event_with_tag(
client_receive_initial_metadata_tag)
client_complete_rpc_event = self.client_driver.event_with_tag(
client_complete_rpc_tag)
client_events = client_events_future.result()
if client_events[0].tag is client_receive_initial_metadata_tag:
client_receive_initial_metadata_event = client_events[0]
client_complete_rpc_event = client_events[1]
else:
client_complete_rpc_event = client_events[0]
client_receive_initial_metadata_event = client_events[1]
return (
_common.OperationResult(server_request_call_start_batch_result,
server_request_call_event.completion_type,
server_request_call_event.success),
_common.OperationResult(
client_receive_initial_metadata_start_batch_result,
cygrpc.CallError.ok,
client_receive_initial_metadata_event.completion_type,
client_receive_initial_metadata_event.success),
_common.OperationResult(client_complete_rpc_start_batch_result,
_common.OperationResult(cygrpc.CallError.ok,
client_complete_rpc_event.completion_type,
client_complete_rpc_event.success),
_common.OperationResult(

@ -19,6 +19,7 @@ import unittest
from grpc._cython import cygrpc
from tests.unit._cython import _common
from tests.unit._cython import test_utilities
class Test(_common.RpcTest, unittest.TestCase):
@ -36,28 +37,31 @@ class Test(_common.RpcTest, unittest.TestCase):
server_request_call_tag,
})
client_call = self.channel.create_call(None, _common.EMPTY_FLAGS,
self.client_completion_queue,
b'/twinkies', None, None)
client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag'
client_complete_rpc_tag = 'client_complete_rpc_tag'
with self.client_condition:
client_receive_initial_metadata_start_batch_result = (
client_call.start_client_batch([
cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS),
], client_receive_initial_metadata_tag))
client_complete_rpc_start_batch_result = client_call.start_client_batch(
[
cygrpc.SendInitialMetadataOperation(
_common.INVOCATION_METADATA, _common.EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS),
], client_complete_rpc_tag)
self.client_driver.add_due({
client_receive_initial_metadata_tag,
client_complete_rpc_tag,
})
client_call = self.channel.integrated_call(
_common.EMPTY_FLAGS, b'/twinkies', None, None,
_common.INVOCATION_METADATA, None, [
(
[
cygrpc.SendInitialMetadataOperation(
_common.INVOCATION_METADATA, _common.EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(
_common.EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(
_common.EMPTY_FLAGS),
],
client_complete_rpc_tag,
),
])
client_call.operate([
cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS),
], client_receive_initial_metadata_tag)
client_events_future = test_utilities.SimpleFuture(
lambda: [
self.channel.next_call_event(),
self.channel.next_call_event(),])
server_request_call_event = self.server_driver.event_with_tag(
server_request_call_tag)
@ -87,20 +91,19 @@ class Test(_common.RpcTest, unittest.TestCase):
server_complete_rpc_event = self.server_driver.event_with_tag(
server_complete_rpc_tag)
client_receive_initial_metadata_event = self.client_driver.event_with_tag(
client_receive_initial_metadata_tag)
client_complete_rpc_event = self.client_driver.event_with_tag(
client_complete_rpc_tag)
client_events = client_events_future.result()
client_receive_initial_metadata_event = client_events[0]
client_complete_rpc_event = client_events[1]
return (
_common.OperationResult(server_request_call_start_batch_result,
server_request_call_event.completion_type,
server_request_call_event.success),
_common.OperationResult(
client_receive_initial_metadata_start_batch_result,
cygrpc.CallError.ok,
client_receive_initial_metadata_event.completion_type,
client_receive_initial_metadata_event.success),
_common.OperationResult(client_complete_rpc_start_batch_result,
_common.OperationResult(cygrpc.CallError.ok,
client_complete_rpc_event.completion_type,
client_complete_rpc_event.success),
_common.OperationResult(

@ -17,6 +17,7 @@ import threading
import unittest
from grpc._cython import cygrpc
from tests.unit._cython import test_utilities
_EMPTY_FLAGS = 0
_EMPTY_METADATA = ()
@ -118,7 +119,8 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
server.register_completion_queue(server_completion_queue)
port = server.add_http2_port(b'[::]:0')
server.start()
channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set())
channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set(),
None)
server_shutdown_tag = 'server_shutdown_tag'
server_driver = _ServerDriver(server_completion_queue,
@ -127,10 +129,6 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
client_condition = threading.Condition()
client_due = set()
client_completion_queue = cygrpc.CompletionQueue()
client_driver = _QueueDriver(client_condition, client_completion_queue,
client_due)
client_driver.start()
server_call_condition = threading.Condition()
server_send_initial_metadata_tag = 'server_send_initial_metadata_tag'
@ -154,25 +152,28 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
server_completion_queue,
server_rpc_tag)
client_call = channel.create_call(None, _EMPTY_FLAGS,
client_completion_queue, b'/twinkies',
None, None)
client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag'
client_complete_rpc_tag = 'client_complete_rpc_tag'
with client_condition:
client_receive_initial_metadata_start_batch_result = (
client_call.start_client_batch([
cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
], client_receive_initial_metadata_tag))
client_due.add(client_receive_initial_metadata_tag)
client_complete_rpc_start_batch_result = (
client_call.start_client_batch([
cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
], client_complete_rpc_tag))
client_due.add(client_complete_rpc_tag)
client_call = channel.segregated_call(
_EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, None, (
(
[
cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
],
client_receive_initial_metadata_tag,
),
(
[
cygrpc.SendInitialMetadataOperation(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
],
client_complete_rpc_tag,
),
))
client_receive_initial_metadata_event_future = test_utilities.SimpleFuture(
client_call.next_event)
server_rpc_event = server_driver.first_event()
@ -208,19 +209,20 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
server_complete_rpc_tag)
server_call_driver.events()
with client_condition:
client_receive_first_message_tag = 'client_receive_first_message_tag'
client_receive_first_message_start_batch_result = (
client_call.start_client_batch([
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
], client_receive_first_message_tag))
client_due.add(client_receive_first_message_tag)
client_receive_first_message_event = client_driver.event_with_tag(
client_receive_first_message_tag)
client_recieve_initial_metadata_event = client_receive_initial_metadata_event_future.result(
)
client_receive_first_message_tag = 'client_receive_first_message_tag'
client_call.operate([
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
], client_receive_first_message_tag)
client_receive_first_message_event = client_call.next_event()
client_call_cancel_result = client_call.cancel()
client_driver.events()
client_call_cancel_result = client_call.cancel(
cygrpc.StatusCode.cancelled, 'Cancelled during test!')
client_complete_rpc_event = client_call.next_event()
channel.close(cygrpc.StatusCode.unknown, 'Channel closed!')
server.shutdown(server_completion_queue, server_shutdown_tag)
server.cancel_all_calls()
server_driver.events()
@ -228,11 +230,6 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
self.assertEqual(cygrpc.CallError.ok, request_call_result)
self.assertEqual(cygrpc.CallError.ok,
server_send_initial_metadata_start_batch_result)
self.assertEqual(cygrpc.CallError.ok,
client_receive_initial_metadata_start_batch_result)
self.assertEqual(cygrpc.CallError.ok,
client_complete_rpc_start_batch_result)
self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result)
self.assertIs(server_rpc_tag, server_rpc_event.tag)
self.assertEqual(cygrpc.CompletionType.operation_complete,
server_rpc_event.completion_type)

@ -51,8 +51,8 @@ class TypeSmokeTest(unittest.TestCase):
del server
def testChannelUpDown(self):
channel = cygrpc.Channel(b'[::]:0', None)
del channel
channel = cygrpc.Channel(b'[::]:0', None, None)
channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!')
def test_metadata_plugin_call_credentials_up_down(self):
cygrpc.MetadataPluginCallCredentials(_metadata_plugin,
@ -121,7 +121,7 @@ class ServerClientMixin(object):
client_credentials)
else:
self.client_channel = cygrpc.Channel('localhost:{}'.format(
self.port).encode(), set())
self.port).encode(), set(), None)
if host_override:
self.host_argument = None # default host
self.expected_host = host_override
@ -131,17 +131,20 @@ class ServerClientMixin(object):
self.expected_host = self.host_argument
def tearDownMixin(self):
self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!')
del self.client_channel
del self.server
del self.client_completion_queue
del self.server_completion_queue
def _perform_operations(self, operations, call, queue, deadline,
description):
"""Perform the list of operations with given call, queue, and deadline.
def _perform_queue_operations(self, operations, call, queue, deadline,
description):
"""Perform the operations with given call, queue, and deadline.
Invocation errors are reported with as an exception with `description` in
the message. Performs the operations asynchronously, returning a future.
"""
Invocation errors are reported with as an exception with `description`
in the message. Performs the operations asynchronously, returning a
future.
"""
def performer():
tag = object()
@ -185,9 +188,6 @@ class ServerClientMixin(object):
self.assertEqual(cygrpc.CallError.ok, request_call_result)
client_call_tag = object()
client_call = self.client_channel.create_call(
None, 0, self.client_completion_queue, METHOD, self.host_argument,
DEADLINE)
client_initial_metadata = (
(
CLIENT_METADATA_ASCII_KEY,
@ -198,18 +198,24 @@ class ServerClientMixin(object):
CLIENT_METADATA_BIN_VALUE,
),
)
client_start_batch_result = client_call.start_client_batch([
cygrpc.SendInitialMetadataOperation(client_initial_metadata,
_EMPTY_FLAGS),
cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
], client_call_tag)
self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
client_event_future = test_utilities.CompletionQueuePollFuture(
self.client_completion_queue, DEADLINE)
client_call = self.client_channel.integrated_call(
0, METHOD, self.host_argument, DEADLINE, client_initial_metadata,
None, [
(
[
cygrpc.SendInitialMetadataOperation(
client_initial_metadata, _EMPTY_FLAGS),
cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
],
client_call_tag,
),
])
client_event_future = test_utilities.SimpleFuture(
self.client_channel.next_call_event)
request_event = self.server_completion_queue.poll(deadline=DEADLINE)
self.assertEqual(cygrpc.CompletionType.operation_complete,
@ -304,66 +310,76 @@ class ServerClientMixin(object):
del client_call
del server_call
def test6522(self):
def test_6522(self):
DEADLINE = time.time() + 5
DEADLINE_TOLERANCE = 0.25
METHOD = b'twinkies'
empty_metadata = ()
# Prologue
server_request_tag = object()
self.server.request_call(self.server_completion_queue,
self.server_completion_queue,
server_request_tag)
client_call = self.client_channel.create_call(
None, 0, self.client_completion_queue, METHOD, self.host_argument,
DEADLINE)
# Prologue
def perform_client_operations(operations, description):
return self._perform_operations(operations, client_call,
self.client_completion_queue,
DEADLINE, description)
client_event_future = perform_client_operations([
cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
], "Client prologue")
client_call = self.client_channel.segregated_call(
0, METHOD, self.host_argument, DEADLINE, None, None, ([(
[
cygrpc.SendInitialMetadataOperation(empty_metadata,
_EMPTY_FLAGS),
cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
],
object(),
), (
[
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
],
object(),
)]))
client_initial_metadata_event_future = test_utilities.SimpleFuture(
client_call.next_event)
request_event = self.server_completion_queue.poll(deadline=DEADLINE)
server_call = request_event.call
def perform_server_operations(operations, description):
return self._perform_operations(operations, server_call,
self.server_completion_queue,
DEADLINE, description)
return self._perform_queue_operations(operations, server_call,
self.server_completion_queue,
DEADLINE, description)
server_event_future = perform_server_operations([
cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
], "Server prologue")
client_event_future.result() # force completion
client_initial_metadata_event_future.result() # force completion
server_event_future.result()
# Messaging
for _ in range(10):
client_event_future = perform_client_operations([
client_call.operate([
cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
], "Client message")
client_message_event_future = test_utilities.SimpleFuture(
client_call.next_event)
server_event_future = perform_server_operations([
cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
], "Server receive")
client_event_future.result() # force completion
client_message_event_future.result() # force completion
server_event_future.result()
# Epilogue
client_event_future = perform_client_operations([
client_call.operate([
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
], "Client epilogue")
# One for ReceiveStatusOnClient, one for SendCloseFromClient.
client_events_future = test_utilities.SimpleFuture(
lambda: {
client_call.next_event(),
client_call.next_event(),})
server_event_future = perform_server_operations([
cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
@ -371,7 +387,7 @@ class ServerClientMixin(object):
empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
], "Server epilogue")
client_event_future.result() # force completion
client_events_future.result() # force completion
server_event_future.result()

@ -81,29 +81,16 @@ class InvalidMetadataTest(unittest.TestCase):
request = b'\x07\x08'
metadata = (('InVaLiD', 'UnaryRequestFutureUnaryResponse'),)
expected_error_details = "metadata was invalid: %s" % metadata
response_future = self._unary_unary.future(request, metadata=metadata)
with self.assertRaises(grpc.RpcError) as exception_context:
response_future.result()
self.assertEqual(exception_context.exception.details(),
expected_error_details)
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.INTERNAL)
self.assertEqual(response_future.details(), expected_error_details)
self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL)
with self.assertRaises(ValueError) as exception_context:
self._unary_unary.future(request, metadata=metadata)
def testUnaryRequestStreamResponse(self):
request = b'\x37\x58'
metadata = (('InVaLiD', 'UnaryRequestStreamResponse'),)
expected_error_details = "metadata was invalid: %s" % metadata
response_iterator = self._unary_stream(request, metadata=metadata)
with self.assertRaises(grpc.RpcError) as exception_context:
next(response_iterator)
self.assertEqual(exception_context.exception.details(),
expected_error_details)
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.INTERNAL)
self.assertEqual(response_iterator.details(), expected_error_details)
self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL)
with self.assertRaises(ValueError) as exception_context:
self._unary_stream(request, metadata=metadata)
self.assertIn(expected_error_details, str(exception_context.exception))
def testStreamRequestBlockingUnaryResponse(self):
request_iterator = (
@ -129,32 +116,18 @@ class InvalidMetadataTest(unittest.TestCase):
b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),)
expected_error_details = "metadata was invalid: %s" % metadata
response_future = self._stream_unary.future(
request_iterator, metadata=metadata)
with self.assertRaises(grpc.RpcError) as exception_context:
response_future.result()
self.assertEqual(exception_context.exception.details(),
expected_error_details)
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.INTERNAL)
self.assertEqual(response_future.details(), expected_error_details)
self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL)
with self.assertRaises(ValueError) as exception_context:
self._stream_unary.future(request_iterator, metadata=metadata)
self.assertIn(expected_error_details, str(exception_context.exception))
def testStreamRequestStreamResponse(self):
request_iterator = (
b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
metadata = (('InVaLiD', 'StreamRequestStreamResponse'),)
expected_error_details = "metadata was invalid: %s" % metadata
response_iterator = self._stream_stream(
request_iterator, metadata=metadata)
with self.assertRaises(grpc.RpcError) as exception_context:
next(response_iterator)
self.assertEqual(exception_context.exception.details(),
expected_error_details)
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.INTERNAL)
self.assertEqual(response_iterator.details(), expected_error_details)
self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL)
with self.assertRaises(ValueError) as exception_context:
self._stream_stream(request_iterator, metadata=metadata)
self.assertIn(expected_error_details, str(exception_context.exception))
if __name__ == '__main__':

Loading…
Cancel
Save