From 4821221e3a430b7276408048d8f3fb4ee4c55fd6 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Mon, 5 Nov 2018 16:21:41 -0800 Subject: [PATCH] Add wait-for-ready semantics * Include unit tests to test default behaviour, disable behaviour, enable behaviour of the wait-for-ready mechanism * Import flags constants from grpc_types.h * Use WaitGroup to wait for TRANSIENT_FAILURE state in unit test --- src/python/grpcio/grpc/__init__.py | 57 +++- src/python/grpcio/grpc/_channel.py | 121 +++++++-- .../grpcio/grpc/_cython/_cygrpc/grpc.pxi | 4 + .../grpc/_cython/_cygrpc/metadata.pyx.pxi | 6 + src/python/grpcio/grpc/_interceptor.py | 126 ++++++--- src/python/grpcio_tests/tests/tests.json | 1 + .../tests/unit/_metadata_flags_test.py | 251 ++++++++++++++++++ .../grpcio_tests/tests/unit/test_common.py | 26 ++ 8 files changed, 517 insertions(+), 75 deletions(-) create mode 100644 src/python/grpcio_tests/tests/unit/_metadata_flags_test.py diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 863696d2360..b3f00de2ab7 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -357,6 +357,7 @@ class ClientCallDetails(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. + wait_for_ready: An optional flag to enable wait for ready mechanism. """ @@ -609,7 +610,12 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): """Affords invoking a unary-unary RPC from client-side.""" @abc.abstractmethod - def __call__(self, request, timeout=None, metadata=None, credentials=None): + def __call__(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): """Synchronously invokes the underlying RPC. Args: @@ -619,6 +625,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. + wait_for_ready: An optional flag to enable wait for ready + mechanism Returns: The response value for the RPC. @@ -631,7 +639,12 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): raise NotImplementedError() @abc.abstractmethod - def with_call(self, request, timeout=None, metadata=None, credentials=None): + def with_call(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): """Synchronously invokes the underlying RPC. Args: @@ -641,6 +654,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. + wait_for_ready: An optional flag to enable wait for ready + mechanism Returns: The response value for the RPC and a Call value for the RPC. @@ -653,7 +668,12 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): raise NotImplementedError() @abc.abstractmethod - def future(self, request, timeout=None, metadata=None, credentials=None): + def future(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): """Asynchronously invokes the underlying RPC. Args: @@ -663,6 +683,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. + wait_for_ready: An optional flag to enable wait for ready + mechanism Returns: An object that is both a Call for the RPC and a Future. @@ -678,7 +700,12 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): """Affords invoking a unary-stream RPC from client-side.""" @abc.abstractmethod - def __call__(self, request, timeout=None, metadata=None, credentials=None): + def __call__(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): """Invokes the underlying RPC. Args: @@ -688,6 +715,8 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: An optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. + wait_for_ready: An optional flag to enable wait for ready + mechanism Returns: An object that is both a Call for the RPC and an iterator of @@ -706,7 +735,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): request_iterator, timeout=None, metadata=None, - credentials=None): + credentials=None, + wait_for_ready=None): """Synchronously invokes the underlying RPC. Args: @@ -717,6 +747,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. + wait_for_ready: An optional flag to enable wait for ready + mechanism Returns: The response value for the RPC. @@ -733,7 +765,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): request_iterator, timeout=None, metadata=None, - credentials=None): + credentials=None, + wait_for_ready=None): """Synchronously invokes the underlying RPC on the client. Args: @@ -744,6 +777,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. + wait_for_ready: An optional flag to enable wait for ready + mechanism Returns: The response value for the RPC and a Call object for the RPC. @@ -760,7 +795,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): request_iterator, timeout=None, metadata=None, - credentials=None): + credentials=None, + wait_for_ready=None): """Asynchronously invokes the underlying RPC on the client. Args: @@ -770,6 +806,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. + wait_for_ready: An optional flag to enable wait for ready + mechanism Returns: An object that is both a Call for the RPC and a Future. @@ -789,7 +827,8 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): request_iterator, timeout=None, metadata=None, - credentials=None): + credentials=None, + wait_for_ready=None): """Invokes the underlying RPC on the client. Args: @@ -799,6 +838,8 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. + wait_for_ready: An optional flag to enable wait for ready + mechanism Returns: An object that is both a Call for the RPC and an iterator of diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 734eac38019..3ff76587484 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -467,10 +467,11 @@ def _end_unary_response_blocking(state, call, with_call, deadline): raise _Rendezvous(state, None, None, deadline) -def _stream_unary_invocation_operationses(metadata): +def _stream_unary_invocation_operationses(metadata, initial_metadata_flags): return ( ( - cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), + cygrpc.SendInitialMetadataOperation(metadata, + initial_metadata_flags), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), ), @@ -478,15 +479,19 @@ def _stream_unary_invocation_operationses(metadata): ) -def _stream_unary_invocation_operationses_and_tags(metadata): +def _stream_unary_invocation_operationses_and_tags(metadata, + initial_metadata_flags): return tuple(( operations, None, - ) for operations in _stream_unary_invocation_operationses(metadata)) + ) + for operations in _stream_unary_invocation_operationses( + metadata, initial_metadata_flags)) class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): + # pylint: disable=too-many-arguments def __init__(self, channel, managed_call, method, request_serializer, response_deserializer): self._channel = channel @@ -495,15 +500,18 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer - def _prepare(self, request, timeout, metadata): + def _prepare(self, request, timeout, metadata, wait_for_ready): deadline, serialized_request, rendezvous = _start_unary_request( request, timeout, self._request_serializer) + initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( + wait_for_ready) if serialized_request is None: return None, None, None, rendezvous else: state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None) operations = ( - cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), + cygrpc.SendInitialMetadataOperation(metadata, + initial_metadata_flags), cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), @@ -512,9 +520,10 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): ) return state, operations, deadline, None - def _blocking(self, request, timeout, metadata, credentials): + def _blocking(self, request, timeout, metadata, credentials, + wait_for_ready): state, operations, deadline, rendezvous = self._prepare( - request, timeout, metadata) + request, timeout, metadata, wait_for_ready) if state is None: raise rendezvous else: @@ -528,17 +537,34 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): _handle_event(event, state, self._response_deserializer) return state, call, - def __call__(self, request, timeout=None, metadata=None, credentials=None): - state, call, = self._blocking(request, timeout, metadata, credentials) + def __call__(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): + state, call, = self._blocking(request, timeout, metadata, credentials, + wait_for_ready) return _end_unary_response_blocking(state, call, False, None) - def with_call(self, request, timeout=None, metadata=None, credentials=None): - state, call, = self._blocking(request, timeout, metadata, credentials) + def with_call(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): + state, call, = self._blocking(request, timeout, metadata, credentials, + wait_for_ready) return _end_unary_response_blocking(state, call, True, None) - def future(self, request, timeout=None, metadata=None, credentials=None): + def future(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): state, operations, deadline, rendezvous = self._prepare( - request, timeout, metadata) + request, timeout, metadata, wait_for_ready) if state is None: raise rendezvous else: @@ -553,6 +579,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): + # pylint: disable=too-many-arguments def __init__(self, channel, managed_call, method, request_serializer, response_deserializer): self._channel = channel @@ -561,16 +588,24 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer - def __call__(self, request, timeout=None, metadata=None, credentials=None): + def __call__(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): deadline, serialized_request, rendezvous = _start_unary_request( request, timeout, self._request_serializer) + initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( + wait_for_ready) if serialized_request is None: raise rendezvous else: state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) operationses = ( ( - cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), + cygrpc.SendInitialMetadataOperation(metadata, + initial_metadata_flags), cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), @@ -589,6 +624,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): + # pylint: disable=too-many-arguments def __init__(self, channel, managed_call, method, request_serializer, response_deserializer): self._channel = channel @@ -597,13 +633,17 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer - def _blocking(self, request_iterator, timeout, metadata, credentials): + def _blocking(self, request_iterator, timeout, metadata, credentials, + wait_for_ready): deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) + initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( + wait_for_ready) call = self._channel.segregated_call( 0, self._method, None, deadline, metadata, None if credentials is None else credentials._credentials, - _stream_unary_invocation_operationses_and_tags(metadata)) + _stream_unary_invocation_operationses_and_tags( + metadata, initial_metadata_flags)) _consume_request_iterator(request_iterator, state, call, self._request_serializer, None) while True: @@ -619,32 +659,38 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): request_iterator, timeout=None, metadata=None, - credentials=None): + credentials=None, + wait_for_ready=None): state, call, = self._blocking(request_iterator, timeout, metadata, - credentials) + credentials, wait_for_ready) return _end_unary_response_blocking(state, call, False, None) def with_call(self, request_iterator, timeout=None, metadata=None, - credentials=None): + credentials=None, + wait_for_ready=None): state, call, = self._blocking(request_iterator, timeout, metadata, - credentials) + credentials, wait_for_ready) return _end_unary_response_blocking(state, call, True, None) def future(self, request_iterator, timeout=None, metadata=None, - credentials=None): + credentials=None, + wait_for_ready=None): deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) event_handler = _event_handler(state, self._response_deserializer) + initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( + wait_for_ready) call = self._managed_call( 0, self._method, None, deadline, metadata, None if credentials is None else credentials._credentials, - _stream_unary_invocation_operationses(metadata), event_handler) + _stream_unary_invocation_operationses( + metadata, initial_metadata_flags), event_handler) _consume_request_iterator(request_iterator, state, call, self._request_serializer, event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -652,6 +698,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): + # pylint: disable=too-many-arguments def __init__(self, channel, managed_call, method, request_serializer, response_deserializer): self._channel = channel @@ -664,12 +711,16 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): request_iterator, timeout=None, metadata=None, - credentials=None): + credentials=None, + wait_for_ready=None): deadline = _deadline(timeout) state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None) + initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( + wait_for_ready) operationses = ( ( - cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), + cygrpc.SendInitialMetadataOperation(metadata, + initial_metadata_flags), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), ), (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), @@ -684,6 +735,24 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): return _Rendezvous(state, call, self._response_deserializer, deadline) +class _InitialMetadataFlags(int): + """Stores immutable initial metadata flags""" + + def __new__(cls, value=_EMPTY_FLAGS): + value &= cygrpc.InitialMetadataFlags.used_mask + return super(_InitialMetadataFlags, cls).__new__(cls, value) + + def with_wait_for_ready(self, wait_for_ready): + if wait_for_ready is not None: + if wait_for_ready: + self = self.__class__(self | cygrpc.InitialMetadataFlags.wait_for_ready | \ + cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set) + elif not wait_for_ready: + self = self.__class__(self & ~cygrpc.InitialMetadataFlags.wait_for_ready | \ + cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set) + return self + + class _ChannelCallState(object): def __init__(self, channel): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi index 47812193194..23428f0b0c0 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi @@ -140,6 +140,10 @@ cdef extern from "grpc/grpc.h": const int GRPC_WRITE_NO_COMPRESS const int GRPC_WRITE_USED_MASK + const int GRPC_INITIAL_METADATA_WAIT_FOR_READY + const int GRPC_INITIAL_METADATA_WAIT_FOR_READY_EXPLICITLY_SET + const int GRPC_INITIAL_METADATA_USED_MASK + const int GRPC_MAX_COMPLETION_QUEUE_PLUCKERS ctypedef struct grpc_completion_queue: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi index c39fef08fa4..53f0c7f0bbe 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi @@ -15,6 +15,12 @@ import collections +class InitialMetadataFlags: + used_mask = GRPC_INITIAL_METADATA_USED_MASK + wait_for_ready = GRPC_INITIAL_METADATA_WAIT_FOR_READY + wait_for_ready_explicitly_set = GRPC_INITIAL_METADATA_WAIT_FOR_READY_EXPLICITLY_SET + + _Metadatum = collections.namedtuple('_Metadatum', ('key', 'value',)) diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py index 1d2d374ad19..43451140265 100644 --- a/src/python/grpcio/grpc/_interceptor.py +++ b/src/python/grpcio/grpc/_interceptor.py @@ -46,7 +46,7 @@ def service_pipeline(interceptors): class _ClientCallDetails( collections.namedtuple( '_ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials')), + ('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')), grpc.ClientCallDetails): pass @@ -72,7 +72,12 @@ def _unwrap_client_call_details(call_details, default_details): except AttributeError: credentials = default_details.credentials - return method, timeout, metadata, credentials + try: + wait_for_ready = call_details.wait_for_ready + except AttributeError: + wait_for_ready = default_details.wait_for_ready + + return method, timeout, metadata, credentials, wait_for_ready class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): @@ -193,28 +198,39 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): self._method = method self._interceptor = interceptor - def __call__(self, request, timeout=None, metadata=None, credentials=None): + def __call__(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): response, ignored_call = self._with_call( request, timeout=timeout, metadata=metadata, - credentials=credentials) + credentials=credentials, + wait_for_ready=wait_for_ready) return response - def _with_call(self, request, timeout=None, metadata=None, - credentials=None): - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials) + def _with_call(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): + client_call_details = _ClientCallDetails( + self._method, timeout, metadata, credentials, wait_for_ready) def continuation(new_details, request): - new_method, new_timeout, new_metadata, new_credentials = ( + new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( _unwrap_client_call_details(new_details, client_call_details)) try: response, call = self._thunk(new_method).with_call( request, timeout=new_timeout, metadata=new_metadata, - credentials=new_credentials) + credentials=new_credentials, + wait_for_ready=new_wait_for_ready) return _UnaryOutcome(response, call) except grpc.RpcError: raise @@ -225,25 +241,37 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): continuation, client_call_details, request) return call.result(), call - def with_call(self, request, timeout=None, metadata=None, credentials=None): + def with_call(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): return self._with_call( request, timeout=timeout, metadata=metadata, - credentials=credentials) + credentials=credentials, + wait_for_ready=wait_for_ready) - def future(self, request, timeout=None, metadata=None, credentials=None): - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials) + def future(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): + client_call_details = _ClientCallDetails( + self._method, timeout, metadata, credentials, wait_for_ready) def continuation(new_details, request): - new_method, new_timeout, new_metadata, new_credentials = ( + new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( _unwrap_client_call_details(new_details, client_call_details)) return self._thunk(new_method).future( request, timeout=new_timeout, metadata=new_metadata, - credentials=new_credentials) + credentials=new_credentials, + wait_for_ready=new_wait_for_ready) try: return self._interceptor.intercept_unary_unary( @@ -259,18 +287,24 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._method = method self._interceptor = interceptor - def __call__(self, request, timeout=None, metadata=None, credentials=None): - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials) + def __call__(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None): + client_call_details = _ClientCallDetails( + self._method, timeout, metadata, credentials, wait_for_ready) def continuation(new_details, request): - new_method, new_timeout, new_metadata, new_credentials = ( + new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( _unwrap_client_call_details(new_details, client_call_details)) return self._thunk(new_method)( request, timeout=new_timeout, metadata=new_metadata, - credentials=new_credentials) + credentials=new_credentials, + wait_for_ready=new_wait_for_ready) try: return self._interceptor.intercept_unary_stream( @@ -290,31 +324,35 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): request_iterator, timeout=None, metadata=None, - credentials=None): + credentials=None, + wait_for_ready=None): response, ignored_call = self._with_call( request_iterator, timeout=timeout, metadata=metadata, - credentials=credentials) + credentials=credentials, + wait_for_ready=wait_for_ready) return response def _with_call(self, request_iterator, timeout=None, metadata=None, - credentials=None): - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials) + credentials=None, + wait_for_ready=None): + client_call_details = _ClientCallDetails( + self._method, timeout, metadata, credentials, wait_for_ready) def continuation(new_details, request_iterator): - new_method, new_timeout, new_metadata, new_credentials = ( + new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( _unwrap_client_call_details(new_details, client_call_details)) try: response, call = self._thunk(new_method).with_call( request_iterator, timeout=new_timeout, metadata=new_metadata, - credentials=new_credentials) + credentials=new_credentials, + wait_for_ready=new_wait_for_ready) return _UnaryOutcome(response, call) except grpc.RpcError: raise @@ -329,29 +367,33 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): request_iterator, timeout=None, metadata=None, - credentials=None): + credentials=None, + wait_for_ready=None): return self._with_call( request_iterator, timeout=timeout, metadata=metadata, - credentials=credentials) + credentials=credentials, + wait_for_ready=wait_for_ready) def future(self, request_iterator, timeout=None, metadata=None, - credentials=None): - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials) + credentials=None, + wait_for_ready=None): + client_call_details = _ClientCallDetails( + self._method, timeout, metadata, credentials, wait_for_ready) def continuation(new_details, request_iterator): - new_method, new_timeout, new_metadata, new_credentials = ( + new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( _unwrap_client_call_details(new_details, client_call_details)) return self._thunk(new_method).future( request_iterator, timeout=new_timeout, metadata=new_metadata, - credentials=new_credentials) + credentials=new_credentials, + wait_for_ready=new_wait_for_ready) try: return self._interceptor.intercept_stream_unary( @@ -371,18 +413,20 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): request_iterator, timeout=None, metadata=None, - credentials=None): - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials) + credentials=None, + wait_for_ready=None): + client_call_details = _ClientCallDetails( + self._method, timeout, metadata, credentials, wait_for_ready) def continuation(new_details, request_iterator): - new_method, new_timeout, new_metadata, new_credentials = ( + new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( _unwrap_client_call_details(new_details, client_call_details)) return self._thunk(new_method)( request_iterator, timeout=new_timeout, metadata=new_metadata, - credentials=new_credentials) + credentials=new_credentials, + wait_for_ready=new_wait_for_ready) try: return self._interceptor.intercept_stream_stream( diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index 5505369867e..072a3d8a90e 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -48,6 +48,7 @@ "unit._invocation_defects_test.InvocationDefectsTest", "unit._logging_test.LoggingTest", "unit._metadata_code_details_test.MetadataCodeDetailsTest", + "unit._metadata_flags_test.MetadataFlagsTest", "unit._metadata_test.MetadataTest", "unit._reconnect_test.ReconnectTest", "unit._resource_exhausted_test.ResourceExhaustedTest", diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py new file mode 100644 index 00000000000..2d352e99d4a --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -0,0 +1,251 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests metadata flags feature by testing wait-for-ready semantics""" + +import time +import weakref +import unittest +import threading +import socket +from six.moves import queue + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + + +def handle_unary_unary(test, request, servicer_context): + return _RESPONSE + + +def handle_unary_stream(test, request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH): + yield _RESPONSE + + +def handle_stream_unary(test, request_iterator, servicer_context): + for _ in request_iterator: + pass + return _RESPONSE + + +def handle_stream_stream(test, request_iterator, servicer_context): + for _ in request_iterator: + yield _RESPONSE + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, test, request_streaming, response_streaming): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + if self.request_streaming and self.response_streaming: + self.stream_stream = lambda req, ctx: handle_stream_stream(test, req, ctx) + elif self.request_streaming: + self.stream_unary = lambda req, ctx: handle_stream_unary(test, req, ctx) + elif self.response_streaming: + self.unary_stream = lambda req, ctx: handle_unary_stream(test, req, ctx) + else: + self.unary_unary = lambda req, ctx: handle_unary_unary(test, req, ctx) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, test): + self._test = test + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(self._test, False, False) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(self._test, False, True) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(self._test, True, False) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(self._test, True, True) + else: + return None + + +def get_free_loopback_tcp_port(): + tcp = socket.socket(socket.AF_INET6) + tcp.bind(('', 0)) + address_tuple = tcp.getsockname() + return tcp, "[::1]:%s" % (address_tuple[1]) + + +def create_dummy_channel(): + """Creating dummy channels is a workaround for retries""" + _, addr = get_free_loopback_tcp_port() + return grpc.insecure_channel(addr) + + +def perform_unary_unary_call(channel, wait_for_ready=None): + channel.unary_unary(_UNARY_UNARY).__call__( + _REQUEST, + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + +def perform_unary_unary_with_call(channel, wait_for_ready=None): + channel.unary_unary(_UNARY_UNARY).with_call( + _REQUEST, + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + +def perform_unary_unary_future(channel, wait_for_ready=None): + channel.unary_unary(_UNARY_UNARY).future( + _REQUEST, + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready).result( + timeout=test_constants.LONG_TIMEOUT) + + +def perform_unary_stream_call(channel, wait_for_ready=None): + response_iterator = channel.unary_stream(_UNARY_STREAM).__call__( + _REQUEST, + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + for _ in response_iterator: + pass + + +def perform_stream_unary_call(channel, wait_for_ready=None): + channel.stream_unary(_STREAM_UNARY).__call__( + iter([_REQUEST] * test_constants.STREAM_LENGTH), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + +def perform_stream_unary_with_call(channel, wait_for_ready=None): + channel.stream_unary(_STREAM_UNARY).with_call( + iter([_REQUEST] * test_constants.STREAM_LENGTH), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + +def perform_stream_unary_future(channel, wait_for_ready=None): + channel.stream_unary(_STREAM_UNARY).future( + iter([_REQUEST] * test_constants.STREAM_LENGTH), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready).result( + timeout=test_constants.LONG_TIMEOUT) + + +def perform_stream_stream_call(channel, wait_for_ready=None): + response_iterator = channel.stream_stream(_STREAM_STREAM).__call__( + iter([_REQUEST] * test_constants.STREAM_LENGTH), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + for _ in response_iterator: + pass + + +_ALL_CALL_CASES = [ + perform_unary_unary_call, perform_unary_unary_with_call, + perform_unary_unary_future, perform_unary_stream_call, + perform_stream_unary_call, perform_stream_unary_with_call, + perform_stream_unary_future, perform_stream_stream_call +] + + +class MetadataFlagsTest(unittest.TestCase): + + def check_connection_does_failfast(self, fn, channel, wait_for_ready=None): + try: + fn(channel, wait_for_ready) + self.fail("The Call should fail") + except BaseException as e: # pylint: disable=broad-except + self.assertIn('StatusCode.UNAVAILABLE', str(e)) + + def test_call_wait_for_ready_default(self): + for perform_call in _ALL_CALL_CASES: + self.check_connection_does_failfast(perform_call, + create_dummy_channel()) + + def test_call_wait_for_ready_disabled(self): + for perform_call in _ALL_CALL_CASES: + self.check_connection_does_failfast( + perform_call, create_dummy_channel(), wait_for_ready=False) + + def test_call_wait_for_ready_enabled(self): + # To test the wait mechanism, Python thread is required to make + # client set up first without handling them case by case. + # Also, Python thread don't pass the unhandled exceptions to + # main thread. So, it need another method to store the + # exceptions and raise them again in main thread. + unhandled_exceptions = queue.Queue() + tcp, addr = get_free_loopback_tcp_port() + wg = test_common.WaitGroup(len(_ALL_CALL_CASES)) + + def wait_for_transient_failure(channel_connectivity): + if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: + wg.done() + + def test_call(perform_call): + try: + channel = grpc.insecure_channel(addr) + channel.subscribe(wait_for_transient_failure) + perform_call(channel, wait_for_ready=True) + except BaseException as e: # pylint: disable=broad-except + # If the call failed, the thread would be destroyed. The channel + # object can be collected before calling the callback, which + # will result in a deadlock. + wg.done() + unhandled_exceptions.put(e, True) + + test_threads = [] + for perform_call in _ALL_CALL_CASES: + test_thread = threading.Thread( + target=test_call, args=(perform_call,)) + test_thread.exception = None + test_thread.start() + test_threads.append(test_thread) + + # Start the server after the connections are waiting + wg.wait() + tcp.close() + server = test_common.test_server() + server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),)) + server.add_insecure_port(addr) + server.start() + + for test_thread in test_threads: + test_thread.join() + + # Stop the server to make test end properly + server.stop(0) + + if not unhandled_exceptions.empty(): + raise unhandled_exceptions.get(True) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/test_common.py b/src/python/grpcio_tests/tests/unit/test_common.py index 61717ae1358..bc3b24862dc 100644 --- a/src/python/grpcio_tests/tests/unit/test_common.py +++ b/src/python/grpcio_tests/tests/unit/test_common.py @@ -14,6 +14,7 @@ """Common code used throughout tests of gRPC.""" import collections +import threading from concurrent import futures import grpc @@ -107,3 +108,28 @@ def test_server(max_workers=10): return grpc.server( futures.ThreadPoolExecutor(max_workers=max_workers), options=(('grpc.so_reuseport', 0),)) + + +class WaitGroup(object): + + def __init__(self, n=0): + self.count = n + self.cv = threading.Condition() + + def add(self, n): + self.cv.acquire() + self.count += n + self.cv.release() + + def done(self): + self.cv.acquire() + self.count -= 1 + if self.count == 0: + self.cv.notify_all() + self.cv.release() + + def wait(self): + self.cv.acquire() + while self.count > 0: + self.cv.wait() + self.cv.release()