Merge pull request #6752 from nathanielmanistaatgoogle/ga

Python GA channel and server
pull/6749/merge
Jan Tattermusch 9 years ago
commit e5d37dd50e
  1. 57
      src/python/grpcio/grpc/__init__.py
  2. 852
      src/python/grpcio/grpc/_channel.py
  3. 99
      src/python/grpcio/grpc/_common.py
  4. 734
      src/python/grpcio/grpc/_server.py
  5. 147
      src/python/grpcio/grpc/_utilities.py
  6. 3
      src/python/grpcio/tests/tests.json
  7. 161
      src/python/grpcio/tests/unit/_channel_connectivity_test.py
  8. 103
      src/python/grpcio/tests/unit/_channel_ready_future_test.py
  9. 775
      src/python/grpcio/tests/unit/_rpc_test.py

@ -787,3 +787,60 @@ class Server(six.with_metaclass(abc.ABCMeta)):
very early in the grace period).
"""
raise NotImplementedError()
################################# Functions ################################
def channel_ready_future(channel):
"""Creates a Future tracking when a Channel is ready.
Cancelling the returned Future does not tell the given Channel to abandon
attempts it may have been making to connect; cancelling merely deactivates the
returned Future's subscription to the given Channel's connectivity.
Args:
channel: A Channel.
Returns:
A Future that matures when the given Channel has connectivity
ChannelConnectivity.READY.
"""
from grpc import _utilities
return _utilities.channel_ready_future(channel)
def insecure_channel(target, options=None):
"""Creates an insecure Channel to a server.
Args:
target: The target to which to connect.
options: A sequence of string-value pairs according to which to configure
the created channel.
Returns:
A Channel to the target through which RPCs may be conducted.
"""
from grpc import _channel
return _channel.Channel(target, None, options)
def server(generic_rpc_handlers, thread_pool, options=None):
"""Creates a Server with which RPCs can be serviced.
The GenericRpcHandlers passed to this function needn't be the only
GenericRpcHandlers that will be used to serve RPCs; others may be added later
by calling add_generic_rpc_handlers any time before the returned server is
started.
Args:
generic_rpc_handlers: Some number of GenericRpcHandlers that will be used
to service RPCs after the returned Server is started.
thread_pool: A futures.ThreadPoolExecutor to be used by the returned Server
to service RPCs.
Returns:
A Server with which RPCs can be serviced.
"""
from grpc import _server
return _server.Server(generic_rpc_handlers, thread_pool)

@ -0,0 +1,852 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Invocation-side implementation of gRPC Python."""
import sys
import threading
import time
import grpc
from grpc import _common
from grpc import _grpcio_metadata
from grpc.framework.foundation import callable_util
from grpc._cython import cygrpc
_USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)
_EMPTY_FLAGS = 0
_INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
_EMPTY_METADATA = cygrpc.Metadata(())
_UNARY_UNARY_INITIAL_DUE = (
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.send_message,
cygrpc.OperationType.send_close_from_client,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.receive_status_on_client,
)
_UNARY_STREAM_INITIAL_DUE = (
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.send_message,
cygrpc.OperationType.send_close_from_client,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_status_on_client,
)
_STREAM_UNARY_INITIAL_DUE = (
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.receive_status_on_client,
)
_STREAM_STREAM_INITIAL_DUE = (
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_status_on_client,
)
_CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = (
'Exception calling channel subscription callback!')
def _deadline(timeout):
if timeout is None:
return None, _INFINITE_FUTURE
else:
deadline = time.time() + timeout
return deadline, cygrpc.Timespec(deadline)
def _unknown_code_details(unknown_cygrpc_code, details):
return b'Server sent unknown code {} and details "{}"'.format(
unknown_cygrpc_code, details)
def _wait_once_until(condition, until):
if until is None:
condition.wait()
else:
remaining = until - time.time()
if remaining < 0:
raise grpc.FutureTimeoutError()
else:
condition.wait(timeout=remaining)
class _RPCState(object):
def __init__(self, due, initial_metadata, trailing_metadata, code, details):
self.condition = threading.Condition()
# The cygrpc.OperationType objects representing events due from the RPC's
# completion queue.
self.due = set(due)
self.initial_metadata = initial_metadata
self.response = None
self.trailing_metadata = trailing_metadata
self.code = code
self.details = details
# The semantics of grpc.Future.cancel and grpc.Future.cancelled are
# slightly wonky, so they have to be tracked separately from the rest of the
# result of the RPC. This field tracks whether cancellation was requested
# prior to termination of the RPC.
self.cancelled = False
self.callbacks = []
def _abort(state, code, details):
if state.code is None:
state.code = code
state.details = details
if state.initial_metadata is None:
state.initial_metadata = _EMPTY_METADATA
state.trailing_metadata = _EMPTY_METADATA
def _handle_event(event, state, response_deserializer):
callbacks = []
for batch_operation in event.batch_operations:
operation_type = batch_operation.type
state.due.remove(operation_type)
if operation_type is cygrpc.OperationType.receive_initial_metadata:
state.initial_metadata = batch_operation.received_metadata
elif operation_type is cygrpc.OperationType.receive_message:
serialized_response = batch_operation.received_message.bytes()
if serialized_response is not None:
response = _common.deserialize(
serialized_response, response_deserializer)
if response is None:
details = b'Exception deserializing response!'
_abort(state, grpc.StatusCode.INTERNAL, details)
else:
state.response = response
elif operation_type is cygrpc.OperationType.receive_status_on_client:
state.trailing_metadata = batch_operation.received_metadata
if state.code is None:
code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE.get(
batch_operation.received_status_code)
if code is None:
state.code = grpc.StatusCode.UNKNOWN
state.details = _unknown_code_details(
batch_operation.received_status_code,
batch_operation.received_status_details)
else:
state.code = code
state.details = batch_operation.received_status_details
callbacks.extend(state.callbacks)
state.callbacks = None
return callbacks
def _event_handler(state, call, response_deserializer):
def handle_event(event):
with state.condition:
callbacks = _handle_event(event, state, response_deserializer)
state.condition.notify_all()
done = not state.due
for callback in callbacks:
callback()
return call if done else None
return handle_event
def _consume_request_iterator(
request_iterator, state, call, request_serializer):
event_handler = _event_handler(state, call, None)
def consume_request_iterator():
for request in request_iterator:
serialized_request = _common.serialize(request, request_serializer)
with state.condition:
if state.code is None and not state.cancelled:
if serialized_request is None:
call.cancel()
details = b'Exception serializing request!'
_abort(state, grpc.StatusCode.INTERNAL, details)
return
else:
operations = (
cygrpc.operation_send_message(
serialized_request, _EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), event_handler)
state.due.add(cygrpc.OperationType.send_message)
while True:
state.condition.wait()
if state.code is None:
if cygrpc.OperationType.send_message not in state.due:
break
else:
return
else:
return
with state.condition:
if state.code is None:
operations = (
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), event_handler)
state.due.add(cygrpc.OperationType.send_close_from_client)
thread = threading.Thread(target=consume_request_iterator)
thread.start()
class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
def __init__(self, state, call, response_deserializer, deadline):
super(_Rendezvous, self).__init__()
self._state = state
self._call = call
self._response_deserializer = response_deserializer
self._deadline = deadline
def cancel(self):
with self._state.condition:
if self._state.code is None:
self._call.cancel()
self._state.cancelled = True
_abort(self._state, grpc.StatusCode.CANCELLED, b'Cancelled!')
self._state.condition.notify_all()
return False
def cancelled(self):
with self._state.condition:
return self._state.cancelled
def running(self):
with self._state.condition:
return self._state.code is None
def done(self):
with self._state.condition:
return self._state.code is not None
def result(self, timeout=None):
until = None if timeout is None else time.time() + timeout
with self._state.condition:
while True:
if self._state.code is None:
_wait_once_until(self._state.condition, until)
elif self._state.code is grpc.StatusCode.OK:
return self._state.response
elif self._state.cancelled:
raise grpc.FutureCancelledError()
else:
raise self
def exception(self, timeout=None):
until = None if timeout is None else time.time() + timeout
with self._state.condition:
while True:
if self._state.code is None:
_wait_once_until(self._state.condition, until)
elif self._state.code is grpc.StatusCode.OK:
return None
elif self._state.cancelled:
raise grpc.FutureCancelledError()
else:
return self
def traceback(self, timeout=None):
until = None if timeout is None else time.time() + timeout
with self._state.condition:
while True:
if self._state.code is None:
_wait_once_until(self._state.condition, until)
elif self._state.code is grpc.StatusCode.OK:
return None
elif self._state.cancelled:
raise grpc.FutureCancelledError()
else:
try:
raise self
except grpc.RpcError:
return sys.exc_info()[2]
def add_done_callback(self, fn):
with self._state.condition:
if self._state.code is None:
self._state.callbacks.append(lambda: fn(self))
return
fn(self)
def _next(self):
with self._state.condition:
if self._state.code is None:
event_handler = _event_handler(
self._state, self._call, self._response_deserializer)
self._call.start_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
event_handler)
self._state.due.add(cygrpc.OperationType.receive_message)
elif self._state.code is grpc.StatusCode.OK:
raise StopIteration()
else:
raise self
while True:
self._state.condition.wait()
if self._state.response is not None:
response = self._state.response
self._state.response = None
return response
elif cygrpc.OperationType.receive_message not in self._state.due:
if self._state.code is grpc.StatusCode.OK:
raise StopIteration()
elif self._state.code is not None:
raise self
def __iter__(self):
return self
def __next__(self):
return self._next()
def next(self):
return self._next()
def is_active(self):
with self._state.condition:
return self._state.code is None
def time_remaining(self):
if self._deadline is None:
return None
else:
return max(self._deadline - time.time(), 0)
def add_cancellation_callback(self, callback):
with self._state.condition:
if self._state.callbacks is None:
return False
else:
self._state.callbacks.append(lambda unused_future: callback())
return True
def initial_metadata(self):
with self._state.condition:
while self._state.initial_metadata is None:
self._state.condition.wait()
return self._state.initial_metadata
def trailing_metadata(self):
with self._state.condition:
while self._state.trailing_metadata is None:
self._state.condition.wait()
return self._state.trailing_metadata
def code(self):
with self._state.condition:
while self._state.code is None:
self._state.condition.wait()
return self._state.code
def details(self):
with self._state.condition:
while self._state.details is None:
self._state.condition.wait()
return self._state.details
def _repr(self):
with self._state.condition:
if self._state.code is None:
return '<_Rendezvous object of in-flight RPC>'
else:
return '<_Rendezvous of RPC that terminated with ({}, {})>'.format(
self._state.code, self._state.details)
def __repr__(self):
return self._repr()
def __str__(self):
return self._repr()
def __del__(self):
with self._state.condition:
if self._state.code is None:
self._call.cancel()
self._state.cancelled = True
self._state.code = grpc.StatusCode.CANCELLED
self._state.condition.notify_all()
def _start_unary_request(request, timeout, request_serializer):
deadline, deadline_timespec = _deadline(timeout)
serialized_request = _common.serialize(request, request_serializer)
if serialized_request is None:
state = _RPCState(
(), _EMPTY_METADATA, _EMPTY_METADATA, grpc.StatusCode.INTERNAL,
b'Exception serializing request!')
rendezvous = _Rendezvous(state, None, None, deadline)
return deadline, deadline_timespec, None, rendezvous
else:
return deadline, deadline_timespec, serialized_request, None
def _end_unary_response_blocking(state, with_call, deadline):
if state.code is grpc.StatusCode.OK:
if with_call:
rendezvous = _Rendezvous(state, None, None, deadline)
return state.response, rendezvous
else:
return state.response
else:
raise _Rendezvous(state, None, None, deadline)
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
def __init__(
self, channel, create_managed_call, method, request_serializer,
response_deserializer):
self._channel = channel
self._create_managed_call = create_managed_call
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def _prepare(self, request, timeout, metadata):
deadline, deadline_timespec, serialized_request, rendezvous = (
_start_unary_request(request, timeout, self._request_serializer))
if serialized_request is None:
return None, None, None, None, rendezvous
else:
state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None)
operations = (
cygrpc.operation_send_initial_metadata(
_common.metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
return state, operations, deadline, deadline_timespec, None
def __call__(
self, request, timeout=None, metadata=None, credentials=None,
with_call=False):
state, operations, deadline, deadline_timespec, rendezvous = self._prepare(
request, timeout, metadata)
if rendezvous:
raise rendezvous
else:
completion_queue = cygrpc.CompletionQueue()
call = self._channel.create_call(
None, 0, completion_queue, self._method, None, deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
call.start_batch(cygrpc.Operations(operations), None)
_handle_event(completion_queue.poll(), state, self._response_deserializer)
return _end_unary_response_blocking(state, with_call, deadline)
def future(self, request, timeout=None, metadata=None, credentials=None):
state, operations, deadline, deadline_timespec, rendezvous = self._prepare(
request, timeout, metadata)
if rendezvous:
return rendezvous
else:
call = self._create_managed_call(
None, 0, self._method, None, deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_batch(cygrpc.Operations(operations), event_handler)
return _Rendezvous(state, call, self._response_deserializer, deadline)
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
def __init__(
self, channel, create_managed_call, method, request_serializer,
response_deserializer):
self._channel = channel
self._create_managed_call = create_managed_call
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __call__(self, request, timeout=None, metadata=None, credentials=None):
deadline, deadline_timespec, serialized_request, rendezvous = (
_start_unary_request(request, timeout, self._request_serializer))
if serialized_request is None:
raise rendezvous
else:
state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None)
call = self._create_managed_call(
None, 0, self._method, None, deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(
_common.metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), event_handler)
return _Rendezvous(state, call, self._response_deserializer, deadline)
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
def __init__(
self, channel, create_managed_call, method, request_serializer,
response_deserializer):
self._channel = channel
self._create_managed_call = create_managed_call
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __call__(
self, request_iterator, timeout=None, metadata=None, credentials=None,
with_call=False):
deadline, deadline_timespec = _deadline(timeout)
state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
completion_queue = cygrpc.CompletionQueue()
call = self._channel.create_call(
None, 0, completion_queue, self._method, None, deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
with state.condition:
call.start_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
None)
operations = (
cygrpc.operation_send_initial_metadata(
_common.metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), None)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
while True:
event = completion_queue.poll()
with state.condition:
_handle_event(event, state, self._response_deserializer)
state.condition.notify_all()
if not state.due:
break
return _end_unary_response_blocking(state, with_call, deadline)
def future(
self, request_iterator, timeout=None, metadata=None, credentials=None):
deadline, deadline_timespec = _deadline(timeout)
state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
call = self._create_managed_call(
None, 0, self._method, None, deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(
_common.metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), event_handler)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
return _Rendezvous(state, call, self._response_deserializer, deadline)
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
def __init__(
self, channel, create_managed_call, method, request_serializer,
response_deserializer):
self._channel = channel
self._create_managed_call = create_managed_call
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __call__(
self, request_iterator, timeout=None, metadata=None, credentials=None):
deadline, deadline_timespec = _deadline(timeout)
state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None)
call = self._create_managed_call(
None, 0, self._method, None, deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(
_common.metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), event_handler)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
return _Rendezvous(state, call, self._response_deserializer, deadline)
class _ChannelCallState(object):
def __init__(self, channel):
self.lock = threading.Lock()
self.channel = channel
self.completion_queue = cygrpc.CompletionQueue()
self.managed_calls = None
def _call_spin(state):
while True:
event = state.completion_queue.poll()
completed_call = event.tag(event)
if completed_call is not None:
with state.lock:
state.managed_calls.remove(completed_call)
if not state.managed_calls:
state.managed_calls = None
return
def _create_channel_managed_call(state):
def create_channel_managed_call(parent, flags, method, host, deadline):
"""Creates a managed cygrpc.Call.
Callers of this function must conduct at least one operation on the returned
call. The tags associated with operations conducted on the returned call
must be no-argument callables that return None to indicate that this channel
should continue polling for events associated with the call and return the
call itself to indicate that no more events associated with the call will be
generated.
Args:
parent: A cygrpc.Call to be used as the parent of the created call.
flags: An integer bitfield of call flags.
method: The RPC method.
host: A host string for the created call.
deadline: A cygrpc.Timespec to be the deadline of the created call.
Returns:
A cygrpc.Call with which to conduct an RPC.
"""
with state.lock:
call = state.channel.create_call(
parent, flags, state.completion_queue, method, host, deadline)
if state.managed_calls is None:
state.managed_calls = set((call,))
spin_thread = threading.Thread(target=_call_spin, args=(state,))
spin_thread.start()
else:
state.managed_calls.add(call)
return call
return create_channel_managed_call
class _ChannelConnectivityState(object):
def __init__(self, channel):
self.lock = threading.Lock()
self.channel = channel
self.polling = False
self.connectivity = None
self.try_to_connect = False
self.callbacks_and_connectivities = []
self.delivering = False
def _deliveries(state):
callbacks_needing_update = []
for callback_and_connectivity in state.callbacks_and_connectivities:
callback, callback_connectivity, = callback_and_connectivity
if callback_connectivity is not state.connectivity:
callbacks_needing_update.append(callback)
callback_and_connectivity[1] = state.connectivity
return callbacks_needing_update
def _deliver(state, initial_connectivity, initial_callbacks):
connectivity = initial_connectivity
callbacks = initial_callbacks
while True:
for callback in callbacks:
callable_util.call_logging_exceptions(
callback, _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE,
connectivity)
with state.lock:
callbacks = _deliveries(state)
if callbacks:
connectivity = state.connectivity
else:
state.delivering = False
return
def _spawn_delivery(state, callbacks):
delivering_thread = threading.Thread(
target=_deliver, args=(state, state.connectivity, callbacks,))
delivering_thread.start()
state.delivering = True
# NOTE(https://github.com/grpc/grpc/issues/3064): We'd rather not poll.
def _poll_connectivity(state, channel, initial_try_to_connect):
try_to_connect = initial_try_to_connect
connectivity = channel.check_connectivity_state(try_to_connect)
with state.lock:
state.connectivity = (
_common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
connectivity])
callbacks = tuple(
callback for callback, unused_but_known_to_be_none_connectivity
in state.callbacks_and_connectivities)
for callback_and_connectivity in state.callbacks_and_connectivities:
callback_and_connectivity[1] = state.connectivity
if callbacks:
_spawn_delivery(state, callbacks)
completion_queue = cygrpc.CompletionQueue()
while True:
channel.watch_connectivity_state(
connectivity, cygrpc.Timespec(time.time() + 0.2),
completion_queue, None)
event = completion_queue.poll()
with state.lock:
if not state.callbacks_and_connectivities and not state.try_to_connect:
state.polling = False
state.connectivity = None
break
try_to_connect = state.try_to_connect
state.try_to_connect = False
if event.success or try_to_connect:
connectivity = channel.check_connectivity_state(try_to_connect)
with state.lock:
state.connectivity = (
_common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
connectivity])
if not state.delivering:
callbacks = _deliveries(state)
if callbacks:
_spawn_delivery(state, callbacks)
def _subscribe(state, callback, try_to_connect):
with state.lock:
if not state.callbacks_and_connectivities and not state.polling:
polling_thread = threading.Thread(
target=_poll_connectivity,
args=(state, state.channel, bool(try_to_connect)))
polling_thread.start()
state.polling = True
state.callbacks_and_connectivities.append([callback, None])
elif not state.delivering and state.connectivity is not None:
_spawn_delivery(state, (callback,))
state.try_to_connect |= bool(try_to_connect)
state.callbacks_and_connectivities.append(
[callback, state.connectivity])
else:
state.try_to_connect |= bool(try_to_connect)
state.callbacks_and_connectivities.append([callback, None])
def _unsubscribe(state, callback):
with state.lock:
for index, (subscribed_callback, unused_connectivity) in enumerate(
state.callbacks_and_connectivities):
if callback == subscribed_callback:
state.callbacks_and_connectivities.pop(index)
break
def _moot(state):
with state.lock:
del state.callbacks_and_connectivities[:]
def _options(options):
if options is None:
pairs = ((cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),)
else:
pairs = list(options) + [
(cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)]
return cygrpc.ChannelArgs(
cygrpc.ChannelArg(arg_name, arg_value) for arg_name, arg_value in pairs)
class Channel(grpc.Channel):
def __init__(self, target, options, credentials):
self._channel = cygrpc.Channel(target, _options(options), credentials)
self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel)
def subscribe(self, callback, try_to_connect=None):
_subscribe(self._connectivity_state, callback, try_to_connect)
def unsubscribe(self, callback):
_unsubscribe(self._connectivity_state, callback)
def unary_unary(
self, method, request_serializer=None, response_deserializer=None):
return _UnaryUnaryMultiCallable(
self._channel, _create_channel_managed_call(self._call_state), method,
request_serializer, response_deserializer)
def unary_stream(
self, method, request_serializer=None, response_deserializer=None):
return _UnaryStreamMultiCallable(
self._channel, _create_channel_managed_call(self._call_state), method,
request_serializer, response_deserializer)
def stream_unary(
self, method, request_serializer=None, response_deserializer=None):
return _StreamUnaryMultiCallable(
self._channel, _create_channel_managed_call(self._call_state), method,
request_serializer, response_deserializer)
def stream_stream(
self, method, request_serializer=None, response_deserializer=None):
return _StreamStreamMultiCallable(
self._channel, _create_channel_managed_call(self._call_state), method,
request_serializer, response_deserializer)
def __del__(self):
_moot(self._connectivity_state)

@ -0,0 +1,99 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Shared implementation."""
import logging
import six
import grpc
from grpc._cython import cygrpc
_EMPTY_METADATA = cygrpc.Metadata(())
CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = {
cygrpc.ConnectivityState.idle: grpc.ChannelConnectivity.IDLE,
cygrpc.ConnectivityState.connecting: grpc.ChannelConnectivity.CONNECTING,
cygrpc.ConnectivityState.ready: grpc.ChannelConnectivity.READY,
cygrpc.ConnectivityState.transient_failure:
grpc.ChannelConnectivity.TRANSIENT_FAILURE,
cygrpc.ConnectivityState.fatal_failure:
grpc.ChannelConnectivity.FATAL_FAILURE,
}
CYGRPC_STATUS_CODE_TO_STATUS_CODE = {
cygrpc.StatusCode.ok: grpc.StatusCode.OK,
cygrpc.StatusCode.cancelled: grpc.StatusCode.CANCELLED,
cygrpc.StatusCode.unknown: grpc.StatusCode.UNKNOWN,
cygrpc.StatusCode.invalid_argument: grpc.StatusCode.INVALID_ARGUMENT,
cygrpc.StatusCode.deadline_exceeded: grpc.StatusCode.DEADLINE_EXCEEDED,
cygrpc.StatusCode.not_found: grpc.StatusCode.NOT_FOUND,
cygrpc.StatusCode.already_exists: grpc.StatusCode.ALREADY_EXISTS,
cygrpc.StatusCode.permission_denied: grpc.StatusCode.PERMISSION_DENIED,
cygrpc.StatusCode.unauthenticated: grpc.StatusCode.UNAUTHENTICATED,
cygrpc.StatusCode.resource_exhausted: grpc.StatusCode.RESOURCE_EXHAUSTED,
cygrpc.StatusCode.failed_precondition: grpc.StatusCode.FAILED_PRECONDITION,
cygrpc.StatusCode.aborted: grpc.StatusCode.ABORTED,
cygrpc.StatusCode.out_of_range: grpc.StatusCode.OUT_OF_RANGE,
cygrpc.StatusCode.unimplemented: grpc.StatusCode.UNIMPLEMENTED,
cygrpc.StatusCode.internal: grpc.StatusCode.INTERNAL,
cygrpc.StatusCode.unavailable: grpc.StatusCode.UNAVAILABLE,
cygrpc.StatusCode.data_loss: grpc.StatusCode.DATA_LOSS,
}
STATUS_CODE_TO_CYGRPC_STATUS_CODE = {
grpc_code: cygrpc_code
for cygrpc_code, grpc_code in six.iteritems(
CYGRPC_STATUS_CODE_TO_STATUS_CODE)
}
def metadata(application_metadata):
return _EMPTY_METADATA if application_metadata is None else cygrpc.Metadata(
cygrpc.Metadatum(key, value) for key, value in application_metadata)
def _transform(message, transformer, exception_message):
if transformer is None:
return message
else:
try:
return transformer(message)
except Exception: # pylint: disable=broad-except
logging.exception(exception_message)
return None
def serialize(message, serializer):
return _transform(message, serializer, 'Exception serializing message!')
def deserialize(serialized_message, deserializer):
return _transform(serialized_message, deserializer,
'Exception deserializing message!')

@ -0,0 +1,734 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Service-side implementation of gRPC Python."""
import collections
import enum
import logging
import threading
import time
import grpc
from grpc import _common
from grpc._cython import cygrpc
from grpc.framework.foundation import callable_util
_SHUTDOWN_TAG = 'shutdown'
_REQUEST_CALL_TAG = 'request_call'
_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server'
_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata'
_RECEIVE_MESSAGE_TOKEN = 'receive_message'
_SEND_MESSAGE_TOKEN = 'send_message'
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
'send_initial_metadata * send_message')
_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server'
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
'send_initial_metadata * send_status_from_server')
_OPEN = 'open'
_CLOSED = 'closed'
_CANCELLED = 'cancelled'
_EMPTY_FLAGS = 0
_EMPTY_METADATA = cygrpc.Metadata(())
def _serialized_request(request_event):
return request_event.batch_operations[0].received_message.bytes()
def _code(state):
if state.code is None:
return cygrpc.StatusCode.ok
else:
code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(state.code)
return cygrpc.StatusCode.unknown if code is None else code
def _details(state):
return b'' if state.details is None else state.details
class _HandlerCallDetails(
collections.namedtuple(
'_HandlerCallDetails', ('method', 'invocation_metadata',)),
grpc.HandlerCallDetails):
pass
class _RPCState(object):
def __init__(self):
self.condition = threading.Condition()
self.due = set()
self.request = None
self.client = _OPEN
self.initial_metadata_allowed = True
self.disable_next_compression = False
self.trailing_metadata = None
self.code = None
self.details = None
self.statused = False
self.rpc_errors = []
self.callbacks = []
def _raise_rpc_error(state):
rpc_error = grpc.RpcError()
state.rpc_errors.append(rpc_error)
raise rpc_error
def _possibly_finish_call(state, token):
state.due.remove(token)
if (state.client is _CANCELLED or state.statused) and not state.due:
callbacks = state.callbacks
state.callbacks = None
return state, callbacks
else:
return None, ()
def _send_status_from_server(state, token):
def send_status_from_server(unused_send_status_from_server_event):
with state.condition:
return _possibly_finish_call(state, token)
return send_status_from_server
def _abort(state, call, code, details):
if state.client is not _CANCELLED:
if state.initial_metadata_allowed:
operations = (
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
_common.metadata(state.trailing_metadata), code, details,
_EMPTY_FLAGS),
)
token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
else:
operations = (
cygrpc.operation_send_status_from_server(
_common.metadata(state.trailing_metadata), code, details,
_EMPTY_FLAGS),
)
token = _SEND_STATUS_FROM_SERVER_TOKEN
call.start_batch(
cygrpc.Operations(operations),
_send_status_from_server(state, token))
state.statused = True
state.due.add(token)
def _receive_close_on_server(state):
def receive_close_on_server(receive_close_on_server_event):
with state.condition:
if receive_close_on_server_event.batch_operations[0].received_cancelled:
state.client = _CANCELLED
elif state.client is _OPEN:
state.client = _CLOSED
state.condition.notify_all()
return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)
return receive_close_on_server
def _receive_message(state, call, request_deserializer):
def receive_message(receive_message_event):
serialized_request = _serialized_request(receive_message_event)
if serialized_request is None:
with state.condition:
if state.client is _OPEN:
state.client = _CLOSED
state.condition.notify_all()
return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
else:
request = _common.deserialize(serialized_request, request_deserializer)
with state.condition:
if request is None:
_abort(
state, call, cygrpc.StatusCode.internal,
b'Exception deserializing request!')
else:
state.request = request
state.condition.notify_all()
return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
return receive_message
def _send_initial_metadata(state):
def send_initial_metadata(unused_send_initial_metadata_event):
with state.condition:
return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)
return send_initial_metadata
def _send_message(state, token):
def send_message(unused_send_message_event):
with state.condition:
state.condition.notify_all()
return _possibly_finish_call(state, token)
return send_message
class _Context(grpc.ServicerContext):
def __init__(self, rpc_event, state, request_deserializer):
self._rpc_event = rpc_event
self._state = state
self._request_deserializer = request_deserializer
def is_active(self):
with self._state.condition:
return self._state.client is not _CANCELLED and not self._state.statused
def time_remaining(self):
return max(self._rpc_event.request_call_details.deadline - time.time(), 0)
def cancel(self):
self._rpc_event.operation_call.cancel()
def add_callback(self, callback):
with self._state.condition:
if self._state.callbacks is None:
return False
else:
self._state.callbacks.append(callback)
return True
def disable_next_message_compression(self):
with self._state.condition:
self._state.disable_next_compression = True
def invocation_metadata(self):
return self._rpc_event.request_metadata
def peer(self):
return self._rpc_event.operation_call.peer()
def send_initial_metadata(self, initial_metadata):
with self._state.condition:
if self._state.client is _CANCELLED:
_raise_rpc_error(self._state)
else:
if self._state.initial_metadata_allowed:
operation = cygrpc.operation_send_initial_metadata(
cygrpc.Metadata(initial_metadata), _EMPTY_FLAGS)
self._rpc_event.operation_call.start_batch(
cygrpc.Operations((operation,)),
_send_initial_metadata(self._state))
self._state.initial_metadata_allowed = False
self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
else:
raise ValueError('Initial metadata no longer allowed!')
def set_trailing_metadata(self, trailing_metadata):
with self._state.condition:
self._state.trailing_metadata = trailing_metadata
def set_code(self, code):
with self._state.condition:
self._state.code = code
def set_details(self, details):
with self._state.condition:
self._state.details = details
class _RequestIterator(object):
def __init__(self, state, call, request_deserializer):
self._state = state
self._call = call
self._request_deserializer = request_deserializer
def _raise_or_start_receive_message(self):
if self._state.client is _CANCELLED:
_raise_rpc_error(self._state)
elif self._state.client is _CLOSED or self._state.statused:
raise StopIteration()
else:
self._call.start_batch(
cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_receive_message(self._state, self._call, self._request_deserializer))
self._state.due.add(_RECEIVE_MESSAGE_TOKEN)
def _look_for_request(self):
if self._state.client is _CANCELLED:
_raise_rpc_error(self._state)
elif (self._state.request is None and
_RECEIVE_MESSAGE_TOKEN not in self._state.due):
raise StopIteration()
else:
request = self._state.request
self._state.request = None
return request
def _next(self):
with self._state.condition:
self._raise_or_start_receive_message()
while True:
self._state.condition.wait()
request = self._look_for_request()
if request is not None:
return request
def __iter__(self):
return self
def __next__(self):
return self._next()
def next(self):
return self._next()
def _unary_request(rpc_event, state, request_deserializer):
def unary_request():
with state.condition:
if state.client is _CANCELLED or state.statused:
return None
else:
start_batch_result = rpc_event.operation_call.start_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_receive_message(
state, rpc_event.operation_call, request_deserializer))
state.due.add(_RECEIVE_MESSAGE_TOKEN)
while True:
state.condition.wait()
if state.request is None:
if state.client is _CLOSED:
details = b'"{}" requires exactly one request message.'.format(
rpc_event.request_call_details.method)
# TODO(5992#issuecomment-220761992): really, what status code?
_abort(
state, rpc_event.operation_call,
cygrpc.StatusCode.unavailable, details)
return None
elif state.client is _CANCELLED:
return None
else:
request = state.request
state.request = None
return request
return unary_request
def _call_behavior(rpc_event, state, behavior, argument, request_deserializer):
context = _Context(rpc_event, state, request_deserializer)
try:
return behavior(argument, context)
except Exception as e: # pylint: disable=broad-except
with state.condition:
if e not in state.rpc_errors:
details = b'Exception calling application: {}'.format(e)
logging.exception(details)
_abort(
state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details)
return None
def _take_response_from_response_iterator(rpc_event, state, response_iterator):
try:
return next(response_iterator), True
except StopIteration:
return None, True
except Exception as e: # pylint: disable=broad-except
with state.condition:
if e not in state.rpc_errors:
details = b'Exception iterating responses: {}'.format(e)
logging.exception(details)
_abort(
state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details)
return None, False
def _serialize_response(rpc_event, state, response, response_serializer):
serialized_response = _common.serialize(response, response_serializer)
if serialized_response is None:
with state.condition:
_abort(
state, rpc_event.operation_call, cygrpc.StatusCode.internal,
b'Failed to serialize response!')
return None
else:
return serialized_response
def _send_response(rpc_event, state, serialized_response):
with state.condition:
if state.client is _CANCELLED or state.statused:
return False
else:
if state.initial_metadata_allowed:
operations = (
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS),
)
state.initial_metadata_allowed = False
token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
else:
operations = (
cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS),
)
token = _SEND_MESSAGE_TOKEN
rpc_event.operation_call.start_batch(
cygrpc.Operations(operations), _send_message(state, token))
state.due.add(token)
while True:
state.condition.wait()
if token not in state.due:
return state.client is not _CANCELLED and not state.statused
def _status(rpc_event, state, serialized_response):
with state.condition:
if state.client is not _CANCELLED:
trailing_metadata = _common.metadata(state.trailing_metadata)
code = _code(state)
details = _details(state)
operations = [
cygrpc.operation_send_status_from_server(
trailing_metadata, code, details, _EMPTY_FLAGS),
]
if state.initial_metadata_allowed:
operations.append(
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS))
if serialized_response is not None:
operations.append(cygrpc.operation_send_message(
serialized_response, _EMPTY_FLAGS))
rpc_event.operation_call.start_batch(
cygrpc.Operations(operations),
_send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
state.statused = True
state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
def _unary_response_in_pool(
rpc_event, state, behavior, argument_thunk, request_deserializer,
response_serializer):
argument = argument_thunk()
if argument is not None:
response = _call_behavior(
rpc_event, state, behavior, argument, request_deserializer)
if response is not None:
serialized_response = _serialize_response(
rpc_event, state, response, response_serializer)
if serialized_response is not None:
_status(rpc_event, state, serialized_response)
return
def _stream_response_in_pool(
rpc_event, state, behavior, argument_thunk, request_deserializer,
response_serializer):
argument = argument_thunk()
if argument is not None:
response_iterator = _call_behavior(
rpc_event, state, behavior, argument, request_deserializer)
if response_iterator is not None:
while True:
response, proceed = _take_response_from_response_iterator(
rpc_event, state, response_iterator)
if proceed:
if response is None:
_status(rpc_event, state, None)
break
else:
serialized_response = _serialize_response(
rpc_event, state, response, response_serializer)
if serialized_response is not None:
proceed = _send_response(rpc_event, state, serialized_response)
if not proceed:
break
else:
break
else:
break
def _handle_unary_unary(rpc_event, state, method_handler, thread_pool):
unary_request = _unary_request(
rpc_event, state, method_handler.request_deserializer)
thread_pool.submit(
_unary_response_in_pool, rpc_event, state, method_handler.unary_unary,
unary_request, method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_unary_stream(rpc_event, state, method_handler, thread_pool):
unary_request = _unary_request(
rpc_event, state, method_handler.request_deserializer)
thread_pool.submit(
_stream_response_in_pool, rpc_event, state, method_handler.unary_stream,
unary_request, method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_stream_unary(rpc_event, state, method_handler, thread_pool):
request_iterator = _RequestIterator(
state, rpc_event.operation_call, method_handler.request_deserializer)
thread_pool.submit(
_unary_response_in_pool, rpc_event, state, method_handler.stream_unary,
lambda: request_iterator, method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_stream_stream(rpc_event, state, method_handler, thread_pool):
request_iterator = _RequestIterator(
state, rpc_event.operation_call, method_handler.request_deserializer)
thread_pool.submit(
_stream_response_in_pool, rpc_event, state, method_handler.stream_stream,
lambda: request_iterator, method_handler.request_deserializer,
method_handler.response_serializer)
def _find_method_handler(rpc_event, generic_handlers):
for generic_handler in generic_handlers:
method_handler = generic_handler.service(
_HandlerCallDetails(
rpc_event.request_call_details.method, rpc_event.request_metadata))
if method_handler is not None:
return method_handler
else:
return None
def _handle_unrecognized_method(rpc_event):
operations = (
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
_EMPTY_METADATA, cygrpc.StatusCode.unimplemented,
b'Method not found!', _EMPTY_FLAGS),
)
rpc_state = _RPCState()
rpc_event.operation_call.start_batch(
operations, lambda ignored_event: (rpc_state, (),))
return rpc_state
def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
state = _RPCState()
with state.condition:
rpc_event.operation_call.start_batch(
cygrpc.Operations(
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
_receive_close_on_server(state))
state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
if method_handler.request_streaming:
if method_handler.response_streaming:
_handle_stream_stream(rpc_event, state, method_handler, thread_pool)
else:
_handle_stream_unary(rpc_event, state, method_handler, thread_pool)
else:
if method_handler.response_streaming:
_handle_unary_stream(rpc_event, state, method_handler, thread_pool)
else:
_handle_unary_unary(rpc_event, state, method_handler, thread_pool)
return state
def _handle_call(rpc_event, generic_handlers, thread_pool):
if rpc_event.request_call_details.method is not None:
method_handler = _find_method_handler(rpc_event, generic_handlers)
if method_handler is None:
return _handle_unrecognized_method(rpc_event)
else:
return _handle_with_method_handler(rpc_event, method_handler, thread_pool)
else:
return None
@enum.unique
class _ServerStage(enum.Enum):
STOPPED = 'stopped'
STARTED = 'started'
GRACE = 'grace'
class _ServerState(object):
def __init__(self, completion_queue, server, generic_handlers, thread_pool):
self.lock = threading.Lock()
self.completion_queue = completion_queue
self.server = server
self.generic_handlers = list(generic_handlers)
self.thread_pool = thread_pool
self.stage = _ServerStage.STOPPED
self.shutdown_events = None
# TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
self.rpc_states = set()
self.due = set()
def _add_generic_handlers(state, generic_handlers):
with state.lock:
state.generic_handlers.extend(generic_handlers)
def _add_insecure_port(state, address):
with state.lock:
return state.server.add_http2_port(address)
def _add_secure_port(state, address, server_credentials):
with state.lock:
return state.server.add_http2_port(address, server_credentials._credentials)
def _request_call(state):
state.server.request_call(
state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG)
state.due.add(_REQUEST_CALL_TAG)
# TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
def _stop_serving(state):
if not state.rpc_states and not state.due:
for shutdown_event in state.shutdown_events:
shutdown_event.set()
state.stage = _ServerStage.STOPPED
return True
else:
return False
def _serve(state):
while True:
event = state.completion_queue.poll()
if event.tag is _SHUTDOWN_TAG:
with state.lock:
state.due.remove(_SHUTDOWN_TAG)
if _stop_serving(state):
return
elif event.tag is _REQUEST_CALL_TAG:
with state.lock:
state.due.remove(_REQUEST_CALL_TAG)
rpc_state = _handle_call(
event, state.generic_handlers, state.thread_pool)
if rpc_state is not None:
state.rpc_states.add(rpc_state)
if state.stage is _ServerStage.STARTED:
_request_call(state)
elif _stop_serving(state):
return
else:
rpc_state, callbacks = event.tag(event)
for callback in callbacks:
callable_util.call_logging_exceptions(
callback, 'Exception calling callback!')
if rpc_state is not None:
with state.lock:
state.rpc_states.remove(rpc_state)
if _stop_serving(state):
return
def _start(state):
with state.lock:
if state.stage is not _ServerStage.STOPPED:
raise ValueError('Cannot start already-started server!')
state.server.start()
state.stage = _ServerStage.STARTED
_request_call(state)
thread = threading.Thread(target=_serve, args=(state,))
thread.start()
def _stop(state, grace):
with state.lock:
if state.stage is _ServerStage.STOPPED:
shutdown_event = threading.Event()
shutdown_event.set()
return shutdown_event
else:
if state.stage is _ServerStage.STARTED:
state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
state.stage = _ServerStage.GRACE
state.shutdown_events = []
state.due.add(_SHUTDOWN_TAG)
shutdown_event = threading.Event()
state.shutdown_events.append(shutdown_event)
if grace is None:
state.server.cancel_all_calls()
# TODO(https://github.com/grpc/grpc/issues/6597): delete this loop.
for rpc_state in state.rpc_states:
with rpc_state.condition:
rpc_state.client = _CANCELLED
rpc_state.condition.notify_all()
else:
def cancel_all_calls_after_grace():
shutdown_event.wait(timeout=grace)
with state.lock:
state.server.cancel_all_calls()
# TODO(https://github.com/grpc/grpc/issues/6597): delete this loop.
for rpc_state in state.rpc_states:
with rpc_state.condition:
rpc_state.client = _CANCELLED
rpc_state.condition.notify_all()
thread = threading.Thread(target=cancel_all_calls_after_grace)
thread.start()
return shutdown_event
shutdown_event.wait()
return shutdown_event
class Server(grpc.Server):
def __init__(self, generic_handlers, thread_pool):
completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server()
server.register_completion_queue(completion_queue)
self._state = _ServerState(
completion_queue, server, generic_handlers, thread_pool)
def add_generic_rpc_handlers(self, generic_rpc_handlers):
_add_generic_handlers(self._state, generic_rpc_handlers)
def add_insecure_port(self, address):
return _add_insecure_port(self._state, address)
def add_secure_port(self, address, server_credentials):
return _add_secure_port(self._state, address, server_credentials)
def start(self):
_start(self._state)
def stop(self, grace):
return _stop(self._state, grace)
def __del__(self):
_stop(self._state, None)

@ -0,0 +1,147 @@
# Copyright 2015, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Internal utilities for gRPC Python."""
import threading
import time
import grpc
from grpc.framework.foundation import callable_util
_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = (
'Exception calling connectivity future "done" callback!')
class _ChannelReadyFuture(grpc.Future):
def __init__(self, channel):
self._condition = threading.Condition()
self._channel = channel
self._matured = False
self._cancelled = False
self._done_callbacks = []
def _block(self, timeout):
until = None if timeout is None else time.time() + timeout
with self._condition:
while True:
if self._cancelled:
raise grpc.FutureCancelledError()
elif self._matured:
return
else:
if until is None:
self._condition.wait()
else:
remaining = until - time.time()
if remaining < 0:
raise grpc.FutureTimeoutError()
else:
self._condition.wait(timeout=remaining)
def _update(self, connectivity):
with self._condition:
if (not self._cancelled and
connectivity is grpc.ChannelConnectivity.READY):
self._matured = True
self._channel.unsubscribe(self._update)
self._condition.notify_all()
done_callbacks = tuple(self._done_callbacks)
self._done_callbacks = None
else:
return
for done_callback in done_callbacks:
callable_util.call_logging_exceptions(
done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self)
def cancel(self):
with self._condition:
if not self._matured:
self._cancelled = True
self._channel.unsubscribe(self._update)
self._condition.notify_all()
done_callbacks = tuple(self._done_callbacks)
self._done_callbacks = None
else:
return False
for done_callback in done_callbacks:
callable_util.call_logging_exceptions(
done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self)
def cancelled(self):
with self._condition:
return self._cancelled
def running(self):
with self._condition:
return not self._cancelled and not self._matured
def done(self):
with self._condition:
return self._cancelled or self._matured
def result(self, timeout=None):
self._block(timeout)
return None
def exception(self, timeout=None):
self._block(timeout)
return None
def traceback(self, timeout=None):
self._block(timeout)
return None
def add_done_callback(self, fn):
with self._condition:
if not self._cancelled and not self._matured:
self._done_callbacks.append(fn)
return
fn(self)
def start(self):
with self._condition:
self._channel.subscribe(self._update, try_to_connect=True)
def __del__(self):
with self._condition:
if not self._cancelled and not self._matured:
self._channel.unsubscribe(self._update)
def channel_ready_future(channel):
ready_future = _ChannelReadyFuture(channel)
ready_future.start()
return ready_future

@ -6,6 +6,8 @@
"_beta_features_test.BetaFeaturesTest",
"_beta_features_test.ContextManagementAndLifecycleTest",
"_cancel_many_calls_test.CancelManyCallsTest",
"_channel_connectivity_test.ChannelConnectivityTest",
"_channel_ready_future_test.ChannelReadyFutureTest",
"_channel_test.ChannelTest",
"_connectivity_channel_test.ChannelConnectivityTest",
"_core_over_links_base_interface_test.AsyncEasyTest",
@ -43,6 +45,7 @@
"_low_test.HangingServerShutdown",
"_low_test.InsecureServerInsecureClient",
"_not_found_test.NotFoundTest",
"_rpc_test.RPCTest",
"_sanity_test.Sanity",
"_secure_interop_test.SecureInteropTest",
"_transmission_test.RoundTripTest",

@ -0,0 +1,161 @@
# Copyright 2015, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests of grpc._channel.Channel connectivity."""
import threading
import time
import unittest
from concurrent import futures
import grpc
from grpc import _channel
from grpc import _server
from tests.unit.framework.common import test_constants
def _ready_in_connectivities(connectivities):
return grpc.ChannelConnectivity.READY in connectivities
def _last_connectivity_is_not_ready(connectivities):
return connectivities[-1] is not grpc.ChannelConnectivity.READY
class _Callback(object):
def __init__(self):
self._condition = threading.Condition()
self._connectivities = []
def update(self, connectivity):
with self._condition:
self._connectivities.append(connectivity)
self._condition.notify()
def connectivities(self):
with self._condition:
return tuple(self._connectivities)
def block_until_connectivities_satisfy(self, predicate):
with self._condition:
while True:
connectivities = tuple(self._connectivities)
if predicate(connectivities):
return connectivities
else:
self._condition.wait()
class ChannelConnectivityTest(unittest.TestCase):
def test_lonely_channel_connectivity(self):
callback = _Callback()
channel = _channel.Channel('localhost:12345', None, None)
channel.subscribe(callback.update, try_to_connect=False)
first_connectivities = callback.block_until_connectivities_satisfy(bool)
channel.subscribe(callback.update, try_to_connect=True)
second_connectivities = callback.block_until_connectivities_satisfy(
lambda connectivities: 2 <= len(connectivities))
# Wait for a connection that will never happen.
time.sleep(test_constants.SHORT_TIMEOUT)
third_connectivities = callback.connectivities()
channel.unsubscribe(callback.update)
fourth_connectivities = callback.connectivities()
channel.unsubscribe(callback.update)
fifth_connectivities = callback.connectivities()
self.assertSequenceEqual(
(grpc.ChannelConnectivity.IDLE,), first_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.READY, second_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.READY, third_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.READY, fourth_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.READY, fifth_connectivities)
def test_immediately_connectable_channel_connectivity(self):
server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0))
port = server.add_insecure_port('[::]:0')
server.start()
first_callback = _Callback()
second_callback = _Callback()
channel = _channel.Channel('localhost:{}'.format(port), None, None)
channel.subscribe(first_callback.update, try_to_connect=False)
first_connectivities = first_callback.block_until_connectivities_satisfy(
bool)
# Wait for a connection that will never happen because try_to_connect=True
# has not yet been passed.
time.sleep(test_constants.SHORT_TIMEOUT)
second_connectivities = first_callback.connectivities()
channel.subscribe(second_callback.update, try_to_connect=True)
third_connectivities = first_callback.block_until_connectivities_satisfy(
lambda connectivities: 2 <= len(connectivities))
fourth_connectivities = second_callback.block_until_connectivities_satisfy(
bool)
# Wait for a connection that will happen (or may already have happened).
first_callback.block_until_connectivities_satisfy(_ready_in_connectivities)
second_callback.block_until_connectivities_satisfy(_ready_in_connectivities)
del channel
self.assertSequenceEqual(
(grpc.ChannelConnectivity.IDLE,), first_connectivities)
self.assertSequenceEqual(
(grpc.ChannelConnectivity.IDLE,), second_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.TRANSIENT_FAILURE, third_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.FATAL_FAILURE, third_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.TRANSIENT_FAILURE,
fourth_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.FATAL_FAILURE, fourth_connectivities)
def test_reachable_then_unreachable_channel_connectivity(self):
server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0))
port = server.add_insecure_port('[::]:0')
server.start()
callback = _Callback()
channel = _channel.Channel('localhost:{}'.format(port), None, None)
channel.subscribe(callback.update, try_to_connect=True)
callback.block_until_connectivities_satisfy(_ready_in_connectivities)
# Now take down the server and confirm that channel readiness is repudiated.
server.stop(None)
callback.block_until_connectivities_satisfy(_last_connectivity_is_not_ready)
channel.unsubscribe(callback.update)
if __name__ == '__main__':
unittest.main(verbosity=2)

@ -0,0 +1,103 @@
# Copyright 2015, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests of grpc.channel_ready_future."""
import threading
import unittest
from concurrent import futures
import grpc
from grpc import _channel
from grpc import _server
from tests.unit.framework.common import test_constants
class _Callback(object):
def __init__(self):
self._condition = threading.Condition()
self._value = None
def accept_value(self, value):
with self._condition:
self._value = value
self._condition.notify_all()
def block_until_called(self):
with self._condition:
while self._value is None:
self._condition.wait()
return self._value
class ChannelReadyFutureTest(unittest.TestCase):
def test_lonely_channel_connectivity(self):
channel = grpc.insecure_channel('localhost:12345')
callback = _Callback()
ready_future = grpc.channel_ready_future(channel)
ready_future.add_done_callback(callback.accept_value)
with self.assertRaises(grpc.FutureTimeoutError):
ready_future.result(test_constants.SHORT_TIMEOUT)
self.assertFalse(ready_future.cancelled())
self.assertFalse(ready_future.done())
self.assertTrue(ready_future.running())
ready_future.cancel()
value_passed_to_callback = callback.block_until_called()
self.assertIs(ready_future, value_passed_to_callback)
self.assertTrue(ready_future.cancelled())
self.assertTrue(ready_future.done())
self.assertFalse(ready_future.running())
def test_immediately_connectable_channel_connectivity(self):
server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0))
port = server.add_insecure_port('[::]:0')
server.start()
channel = grpc.insecure_channel('localhost:{}'.format(port))
callback = _Callback()
ready_future = grpc.channel_ready_future(channel)
ready_future.add_done_callback(callback.accept_value)
self.assertIsNone(ready_future.result(test_constants.SHORT_TIMEOUT))
value_passed_to_callback = callback.block_until_called()
self.assertIs(ready_future, value_passed_to_callback)
self.assertFalse(ready_future.cancelled())
self.assertTrue(ready_future.done())
self.assertFalse(ready_future.running())
# Cancellation after maturity has no effect.
ready_future.cancel()
self.assertFalse(ready_future.cancelled())
self.assertTrue(ready_future.done())
self.assertFalse(ready_future.running())
if __name__ == '__main__':
unittest.main(verbosity=2)

@ -0,0 +1,775 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test of gRPC Python's application-layer API."""
import itertools
import threading
import unittest
from concurrent import futures
import grpc
from grpc.framework.foundation import logging_pool
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) / 2:]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) / 3]
_UNARY_UNARY = b'/test/UnaryUnary'
_UNARY_STREAM = b'/test/UnaryStream'
_STREAM_UNARY = b'/test/StreamUnary'
_STREAM_STREAM = b'/test/StreamStream'
class _Callback(object):
def __init__(self):
self._condition = threading.Condition()
self._value = None
self._called = False
def __call__(self, value):
with self._condition:
self._value = value
self._called = True
self._condition.notify_all()
def value(self):
with self._condition:
while not self._called:
self._condition.wait()
return self._value
class _Handler(object):
def __init__(self, control):
self._control = control
def handle_unary_unary(self, request, servicer_context):
self._control.control()
if servicer_context is not None:
servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
return request
def handle_unary_stream(self, request, servicer_context):
for _ in range(test_constants.STREAM_LENGTH):
self._control.control()
yield request
self._control.control()
if servicer_context is not None:
servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
def handle_stream_unary(self, request_iterator, servicer_context):
if servicer_context is not None:
servicer_context.invocation_metadata()
self._control.control()
response_elements = []
for request in request_iterator:
self._control.control()
response_elements.append(request)
self._control.control()
if servicer_context is not None:
servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
return b''.join(response_elements)
def handle_stream_stream(self, request_iterator, servicer_context):
self._control.control()
if servicer_context is not None:
servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
for request in request_iterator:
self._control.control()
yield request
self._control.control()
class _MethodHandler(grpc.RpcMethodHandler):
def __init__(
self, request_streaming, response_streaming, request_deserializer,
response_serializer, unary_unary, unary_stream, stream_unary,
stream_stream):
self.request_streaming = request_streaming
self.response_streaming = response_streaming
self.request_deserializer = request_deserializer
self.response_serializer = response_serializer
self.unary_unary = unary_unary
self.unary_stream = unary_stream
self.stream_unary = stream_unary
self.stream_stream = stream_stream
class _GenericHandler(grpc.GenericRpcHandler):
def __init__(self, handler):
self._handler = handler
def service(self, handler_call_details):
if handler_call_details.method == _UNARY_UNARY:
return _MethodHandler(
False, False, None, None, self._handler.handle_unary_unary, None,
None, None)
elif handler_call_details.method == _UNARY_STREAM:
return _MethodHandler(
False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None,
self._handler.handle_unary_stream, None, None)
elif handler_call_details.method == _STREAM_UNARY:
return _MethodHandler(
True, False, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, None,
self._handler.handle_stream_unary, None)
elif handler_call_details.method == _STREAM_STREAM:
return _MethodHandler(
True, True, None, None, None, None, None,
self._handler.handle_stream_stream)
else:
return None
def _unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY)
def _unary_stream_multi_callable(channel):
return channel.unary_stream(
_UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE)
def _stream_unary_multi_callable(channel):
return channel.stream_unary(
_STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE)
def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM)
class RPCTest(unittest.TestCase):
def setUp(self):
self._control = test_control.PauseFailControl()
self._handler = _Handler(self._control)
self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
self._server = grpc.server((), self._server_pool)
port = self._server.add_insecure_port(b'[::]:0')
self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
self._server.start()
self._channel = grpc.insecure_channel(b'localhost:%d' % port)
# TODO(nathaniel): Why is this necessary, and only in some development
# environments?
def tearDown(self):
del self._channel
del self._server
del self._server_pool
def testUnrecognizedMethod(self):
request = b'abc'
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(b'NoSuchMethod')(request)
self.assertEqual(
grpc.StatusCode.UNIMPLEMENTED, exception_context.exception.code())
def testSuccessfulUnaryRequestBlockingUnaryResponse(self):
request = b'\x07\x08'
expected_response = self._handler.handle_unary_unary(request, None)
multi_callable = _unary_unary_multi_callable(self._channel)
response = multi_callable(
request, metadata=(
(b'test', b'SuccessfulUnaryRequestBlockingUnaryResponse'),))
self.assertEqual(expected_response, response)
def testSuccessfulUnaryRequestBlockingUnaryResponseWithCall(self):
request = b'\x07\x08'
expected_response = self._handler.handle_unary_unary(request, None)
multi_callable = _unary_unary_multi_callable(self._channel)
response, call = multi_callable(
request, metadata=(
(b'test', b'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),),
with_call=True)
self.assertEqual(expected_response, response)
self.assertIs(grpc.StatusCode.OK, call.code())
def testSuccessfulUnaryRequestFutureUnaryResponse(self):
request = b'\x07\x08'
expected_response = self._handler.handle_unary_unary(request, None)
multi_callable = _unary_unary_multi_callable(self._channel)
response_future = multi_callable.future(
request, metadata=(
(b'test', b'SuccessfulUnaryRequestFutureUnaryResponse'),))
response = response_future.result()
self.assertEqual(expected_response, response)
def testSuccessfulUnaryRequestStreamResponse(self):
request = b'\x37\x58'
expected_responses = tuple(self._handler.handle_unary_stream(request, None))
multi_callable = _unary_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request,
metadata=((b'test', b'SuccessfulUnaryRequestStreamResponse'),))
responses = tuple(response_iterator)
self.assertSequenceEqual(expected_responses, responses)
def testSuccessfulStreamRequestBlockingUnaryResponse(self):
requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
expected_response = self._handler.handle_stream_unary(iter(requests), None)
request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel)
response = multi_callable(
request_iterator,
metadata=((b'test', b'SuccessfulStreamRequestBlockingUnaryResponse'),))
self.assertEqual(expected_response, response)
def testSuccessfulStreamRequestBlockingUnaryResponseWithCall(self):
requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
expected_response = self._handler.handle_stream_unary(iter(requests), None)
request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel)
response, call = multi_callable(
request_iterator,
metadata=(
(b'test', b'SuccessfulStreamRequestBlockingUnaryResponseWithCall'),
), with_call=True)
self.assertEqual(expected_response, response)
self.assertIs(grpc.StatusCode.OK, call.code())
def testSuccessfulStreamRequestFutureUnaryResponse(self):
requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
expected_response = self._handler.handle_stream_unary(iter(requests), None)
request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel)
response_future = multi_callable.future(
request_iterator,
metadata=(
(b'test', b'SuccessfulStreamRequestFutureUnaryResponse'),))
response = response_future.result()
self.assertEqual(expected_response, response)
def testSuccessfulStreamRequestStreamResponse(self):
requests = tuple(b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH))
expected_responses = tuple(
self._handler.handle_stream_stream(iter(requests), None))
request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request_iterator,
metadata=((b'test', b'SuccessfulStreamRequestStreamResponse'),))
responses = tuple(response_iterator)
self.assertSequenceEqual(expected_responses, responses)
def testSequentialInvocations(self):
first_request = b'\x07\x08'
second_request = b'\x0809'
expected_first_response = self._handler.handle_unary_unary(
first_request, None)
expected_second_response = self._handler.handle_unary_unary(
second_request, None)
multi_callable = _unary_unary_multi_callable(self._channel)
first_response = multi_callable(
first_request, metadata=((b'test', b'SequentialInvocations'),))
second_response = multi_callable(
second_request, metadata=((b'test', b'SequentialInvocations'),))
self.assertEqual(expected_first_response, first_response)
self.assertEqual(expected_second_response, second_response)
def testConcurrentBlockingInvocations(self):
pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
expected_response = self._handler.handle_stream_unary(iter(requests), None)
expected_responses = [expected_response] * test_constants.THREAD_CONCURRENCY
response_futures = [None] * test_constants.THREAD_CONCURRENCY
multi_callable = _stream_unary_multi_callable(self._channel)
for index in range(test_constants.THREAD_CONCURRENCY):
request_iterator = iter(requests)
response_future = pool.submit(
multi_callable, request_iterator,
metadata=((b'test', b'ConcurrentBlockingInvocations'),))
response_futures[index] = response_future
responses = tuple(
response_future.result() for response_future in response_futures)
pool.shutdown(wait=True)
self.assertSequenceEqual(expected_responses, responses)
def testConcurrentFutureInvocations(self):
requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
expected_response = self._handler.handle_stream_unary(iter(requests), None)
expected_responses = [expected_response] * test_constants.THREAD_CONCURRENCY
response_futures = [None] * test_constants.THREAD_CONCURRENCY
multi_callable = _stream_unary_multi_callable(self._channel)
for index in range(test_constants.THREAD_CONCURRENCY):
request_iterator = iter(requests)
response_future = multi_callable.future(
request_iterator,
metadata=((b'test', b'ConcurrentFutureInvocations'),))
response_futures[index] = response_future
responses = tuple(
response_future.result() for response_future in response_futures)
self.assertSequenceEqual(expected_responses, responses)
def testWaitingForSomeButNotAllConcurrentFutureInvocations(self):
pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
request = b'\x67\x68'
expected_response = self._handler.handle_unary_unary(request, None)
response_futures = [None] * test_constants.THREAD_CONCURRENCY
lock = threading.Lock()
test_is_running_cell = [True]
def wrap_future(future):
def wrap():
try:
return future.result()
except grpc.RpcError:
with lock:
if test_is_running_cell[0]:
raise
return None
return wrap
multi_callable = _unary_unary_multi_callable(self._channel)
for index in range(test_constants.THREAD_CONCURRENCY):
inner_response_future = multi_callable.future(
request,
metadata=(
(b'test',
b'WaitingForSomeButNotAllConcurrentFutureInvocations'),))
outer_response_future = pool.submit(wrap_future(inner_response_future))
response_futures[index] = outer_response_future
some_completed_response_futures_iterator = itertools.islice(
futures.as_completed(response_futures),
test_constants.THREAD_CONCURRENCY // 2)
for response_future in some_completed_response_futures_iterator:
self.assertEqual(expected_response, response_future.result())
with lock:
test_is_running_cell[0] = False
def testConsumingOneStreamResponseUnaryRequest(self):
request = b'\x57\x38'
multi_callable = _unary_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request,
metadata=(
(b'test', b'ConsumingOneStreamResponseUnaryRequest'),))
next(response_iterator)
def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self):
request = b'\x57\x38'
multi_callable = _unary_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request,
metadata=(
(b'test', b'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),))
for _ in range(test_constants.STREAM_LENGTH // 2):
next(response_iterator)
def testConsumingSomeButNotAllStreamResponsesStreamRequest(self):
requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request_iterator,
metadata=(
(b'test', b'ConsumingSomeButNotAllStreamResponsesStreamRequest'),))
for _ in range(test_constants.STREAM_LENGTH // 2):
next(response_iterator)
def testConsumingTooManyStreamResponsesStreamRequest(self):
requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request_iterator,
metadata=(
(b'test', b'ConsumingTooManyStreamResponsesStreamRequest'),))
for _ in range(test_constants.STREAM_LENGTH):
next(response_iterator)
for _ in range(test_constants.STREAM_LENGTH):
with self.assertRaises(StopIteration):
next(response_iterator)
self.assertIsNotNone(response_iterator.initial_metadata())
self.assertIs(grpc.StatusCode.OK, response_iterator.code())
self.assertIsNotNone(response_iterator.details())
self.assertIsNotNone(response_iterator.trailing_metadata())
def testCancelledUnaryRequestUnaryResponse(self):
request = b'\x07\x17'
multi_callable = _unary_unary_multi_callable(self._channel)
with self._control.pause():
response_future = multi_callable.future(
request,
metadata=((b'test', b'CancelledUnaryRequestUnaryResponse'),))
response_future.cancel()
self.assertTrue(response_future.cancelled())
with self.assertRaises(grpc.FutureCancelledError):
response_future.result()
self.assertIs(grpc.StatusCode.CANCELLED, response_future.code())
def testCancelledUnaryRequestStreamResponse(self):
request = b'\x07\x19'
multi_callable = _unary_stream_multi_callable(self._channel)
with self._control.pause():
response_iterator = multi_callable(
request,
metadata=((b'test', b'CancelledUnaryRequestStreamResponse'),))
self._control.block_until_paused()
response_iterator.cancel()
with self.assertRaises(grpc.RpcError) as exception_context:
next(response_iterator)
self.assertIs(grpc.StatusCode.CANCELLED, exception_context.exception.code())
self.assertIsNotNone(response_iterator.initial_metadata())
self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
self.assertIsNotNone(response_iterator.details())
self.assertIsNotNone(response_iterator.trailing_metadata())
def testCancelledStreamRequestUnaryResponse(self):
requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel)
with self._control.pause():
response_future = multi_callable.future(
request_iterator,
metadata=((b'test', b'CancelledStreamRequestUnaryResponse'),))
self._control.block_until_paused()
response_future.cancel()
self.assertTrue(response_future.cancelled())
with self.assertRaises(grpc.FutureCancelledError):
response_future.result()
self.assertIsNotNone(response_future.initial_metadata())
self.assertIs(grpc.StatusCode.CANCELLED, response_future.code())
self.assertIsNotNone(response_future.details())
self.assertIsNotNone(response_future.trailing_metadata())
def testCancelledStreamRequestStreamResponse(self):
requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel)
with self._control.pause():
response_iterator = multi_callable(
request_iterator,
metadata=((b'test', b'CancelledStreamRequestStreamResponse'),))
response_iterator.cancel()
with self.assertRaises(grpc.RpcError):
next(response_iterator)
self.assertIsNotNone(response_iterator.initial_metadata())
self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
self.assertIsNotNone(response_iterator.details())
self.assertIsNotNone(response_iterator.trailing_metadata())
def testExpiredUnaryRequestBlockingUnaryResponse(self):
request = b'\x07\x17'
multi_callable = _unary_unary_multi_callable(self._channel)
with self._control.pause():
with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable(
request, timeout=test_constants.SHORT_TIMEOUT,
metadata=((b'test', b'ExpiredUnaryRequestBlockingUnaryResponse'),),
with_call=True)
self.assertIsNotNone(exception_context.exception.initial_metadata())
self.assertIs(
grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
self.assertIsNotNone(exception_context.exception.details())
self.assertIsNotNone(exception_context.exception.trailing_metadata())
def testExpiredUnaryRequestFutureUnaryResponse(self):
request = b'\x07\x17'
callback = _Callback()
multi_callable = _unary_unary_multi_callable(self._channel)
with self._control.pause():
response_future = multi_callable.future(
request, timeout=test_constants.SHORT_TIMEOUT,
metadata=((b'test', b'ExpiredUnaryRequestFutureUnaryResponse'),))
response_future.add_done_callback(callback)
value_passed_to_callback = callback.value()
self.assertIs(response_future, value_passed_to_callback)
self.assertIsNotNone(response_future.initial_metadata())
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
self.assertIsNotNone(response_future.details())
self.assertIsNotNone(response_future.trailing_metadata())
with self.assertRaises(grpc.RpcError) as exception_context:
response_future.result()
self.assertIs(
grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
self.assertIsInstance(response_future.exception(), grpc.RpcError)
self.assertIs(
grpc.StatusCode.DEADLINE_EXCEEDED, response_future.exception().code())
def testExpiredUnaryRequestStreamResponse(self):
request = b'\x07\x19'
multi_callable = _unary_stream_multi_callable(self._channel)
with self._control.pause():
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request, timeout=test_constants.SHORT_TIMEOUT,
metadata=((b'test', b'ExpiredUnaryRequestStreamResponse'),))
next(response_iterator)
self.assertIs(
grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code())
def testExpiredStreamRequestBlockingUnaryResponse(self):
requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel)
with self._control.pause():
with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable(
request_iterator, timeout=test_constants.SHORT_TIMEOUT,
metadata=((b'test', b'ExpiredStreamRequestBlockingUnaryResponse'),))
self.assertIsNotNone(exception_context.exception.initial_metadata())
self.assertIs(
grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
self.assertIsNotNone(exception_context.exception.details())
self.assertIsNotNone(exception_context.exception.trailing_metadata())
def testExpiredStreamRequestFutureUnaryResponse(self):
requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
callback = _Callback()
multi_callable = _stream_unary_multi_callable(self._channel)
with self._control.pause():
response_future = multi_callable.future(
request_iterator, timeout=test_constants.SHORT_TIMEOUT,
metadata=((b'test', b'ExpiredStreamRequestFutureUnaryResponse'),))
response_future.add_done_callback(callback)
value_passed_to_callback = callback.value()
with self.assertRaises(grpc.RpcError) as exception_context:
response_future.result()
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
self.assertIs(
grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
self.assertIsInstance(response_future.exception(), grpc.RpcError)
self.assertIs(response_future, value_passed_to_callback)
self.assertIsNotNone(response_future.initial_metadata())
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
self.assertIsNotNone(response_future.details())
self.assertIsNotNone(response_future.trailing_metadata())
def testExpiredStreamRequestStreamResponse(self):
requests = tuple(b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel)
with self._control.pause():
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request_iterator, timeout=test_constants.SHORT_TIMEOUT,
metadata=((b'test', b'ExpiredStreamRequestStreamResponse'),))
next(response_iterator)
self.assertIs(
grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code())
def testFailedUnaryRequestBlockingUnaryResponse(self):
request = b'\x37\x17'
multi_callable = _unary_unary_multi_callable(self._channel)
with self._control.fail():
with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable(
request,
metadata=((b'test', b'FailedUnaryRequestBlockingUnaryResponse'),),
with_call=True)
self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
def testFailedUnaryRequestFutureUnaryResponse(self):
request = b'\x37\x17'
callback = _Callback()
multi_callable = _unary_unary_multi_callable(self._channel)
with self._control.fail():
response_future = multi_callable.future(
request,
metadata=((b'test', b'FailedUnaryRequestFutureUnaryResponse'),))
response_future.add_done_callback(callback)
value_passed_to_callback = callback.value()
with self.assertRaises(grpc.RpcError) as exception_context:
response_future.result()
self.assertIs(
grpc.StatusCode.UNKNOWN, exception_context.exception.code())
self.assertIsInstance(response_future.exception(), grpc.RpcError)
self.assertIs(grpc.StatusCode.UNKNOWN, response_future.exception().code())
self.assertIs(response_future, value_passed_to_callback)
def testFailedUnaryRequestStreamResponse(self):
request = b'\x37\x17'
multi_callable = _unary_stream_multi_callable(self._channel)
with self.assertRaises(grpc.RpcError) as exception_context:
with self._control.fail():
response_iterator = multi_callable(
request,
metadata=((b'test', b'FailedUnaryRequestStreamResponse'),))
next(response_iterator)
self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
def testFailedStreamRequestBlockingUnaryResponse(self):
requests = tuple(b'\x47\x58' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel)
with self._control.fail():
with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable(
request_iterator,
metadata=((b'test', b'FailedStreamRequestBlockingUnaryResponse'),))
self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
def testFailedStreamRequestFutureUnaryResponse(self):
requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
callback = _Callback()
multi_callable = _stream_unary_multi_callable(self._channel)
with self._control.fail():
response_future = multi_callable.future(
request_iterator,
metadata=((b'test', b'FailedStreamRequestFutureUnaryResponse'),))
response_future.add_done_callback(callback)
value_passed_to_callback = callback.value()
with self.assertRaises(grpc.RpcError) as exception_context:
response_future.result()
self.assertIs(grpc.StatusCode.UNKNOWN, response_future.code())
self.assertIs(
grpc.StatusCode.UNKNOWN, exception_context.exception.code())
self.assertIsInstance(response_future.exception(), grpc.RpcError)
self.assertIs(response_future, value_passed_to_callback)
def testFailedStreamRequestStreamResponse(self):
requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel)
with self._control.fail():
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request_iterator,
metadata=((b'test', b'FailedStreamRequestStreamResponse'),))
tuple(response_iterator)
self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())
def testIgnoredUnaryRequestFutureUnaryResponse(self):
request = b'\x37\x17'
multi_callable = _unary_unary_multi_callable(self._channel)
multi_callable.future(
request,
metadata=((b'test', b'IgnoredUnaryRequestFutureUnaryResponse'),))
def testIgnoredUnaryRequestStreamResponse(self):
request = b'\x37\x17'
multi_callable = _unary_stream_multi_callable(self._channel)
multi_callable(
request,
metadata=((b'test', b'IgnoredUnaryRequestStreamResponse'),))
def testIgnoredStreamRequestFutureUnaryResponse(self):
requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel)
multi_callable.future(
request_iterator,
metadata=((b'test', b'IgnoredStreamRequestFutureUnaryResponse'),))
def testIgnoredStreamRequestStreamResponse(self):
requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel)
multi_callable(
request_iterator,
metadata=((b'test', b'IgnoredStreamRequestStreamResponse'),))
if __name__ == '__main__':
unittest.main(verbosity=2)
Loading…
Cancel
Save