From 13db8e517dd390cb1ede0146f5b0182e5f1e4dd5 Mon Sep 17 00:00:00 2001 From: Nathaniel Manista Date: Fri, 4 Sep 2015 19:03:32 +0000 Subject: [PATCH] Plumb protocol objects through RPC Framework core --- src/python/grpcio/grpc/_links/service.py | 4 +- .../grpcio/grpc/framework/core/_ingestion.py | 17 +- .../grpcio/grpc/framework/core/_interfaces.py | 25 +++ .../grpcio/grpc/framework/core/_operation.py | 16 +- .../grpcio/grpc/framework/core/_protocol.py | 176 ++++++++++++++++++ .../grpcio/grpc/framework/core/_reception.py | 14 +- .../grpcio/grpc/framework/crust/_calls.py | 8 +- .../grpcio/grpc/framework/crust/_control.py | 29 ++- .../grpcio/grpc/framework/crust/_service.py | 8 +- .../grpc/framework/interfaces/base/base.py | 17 ++ .../framework/interfaces/base/utilities.py | 13 +- .../framework/interfaces/base/test_cases.py | 16 +- 12 files changed, 319 insertions(+), 24 deletions(-) create mode 100644 src/python/grpcio/grpc/framework/core/_protocol.py diff --git a/src/python/grpcio/grpc/_links/service.py b/src/python/grpcio/grpc/_links/service.py index 34d3b262c98..07772c7de3e 100644 --- a/src/python/grpcio/grpc/_links/service.py +++ b/src/python/grpcio/grpc/_links/service.py @@ -167,10 +167,12 @@ class _Kernel(object): request_deserializer, response_serializer, 1, _Read.READING, None, 1, _HighWrite.OPEN, _LowWrite.OPEN, False, None, None, None, set((_READ, _FINISH,))) + protocol = links.Protocol( + links.Protocol.Kind.SERVICER_CONTEXT, 'TODO: Service Context Object!') ticket = links.Ticket( call, 0, group, method, links.Ticket.Subscription.FULL, service_acceptance.deadline - time.time(), None, event.metadata, None, - None, None, None, None, 'TODO: Service Context Object!') + None, None, None, None, protocol) self._relay.add_value(ticket) def _on_read_event(self, event): diff --git a/src/python/grpcio/grpc/framework/core/_ingestion.py b/src/python/grpcio/grpc/framework/core/_ingestion.py index 9a7959a2ddb..4129a8ce43e 100644 --- a/src/python/grpcio/grpc/framework/core/_ingestion.py +++ b/src/python/grpcio/grpc/framework/core/_ingestion.py @@ -140,7 +140,7 @@ class _IngestionManager(_interfaces.IngestionManager): def __init__( self, lock, pool, subscription, subscription_creator, termination_manager, - transmission_manager, expiration_manager): + transmission_manager, expiration_manager, protocol_manager): """Constructor. Args: @@ -157,12 +157,14 @@ class _IngestionManager(_interfaces.IngestionManager): transmission_manager: The _interfaces.TransmissionManager for the operation. expiration_manager: The _interfaces.ExpirationManager for the operation. + protocol_manager: The _interfaces.ProtocolManager for the operation. """ self._lock = lock self._pool = pool self._termination_manager = termination_manager self._transmission_manager = transmission_manager self._expiration_manager = expiration_manager + self._protocol_manager = protocol_manager if subscription is None: self._subscription_creator = subscription_creator @@ -296,6 +298,8 @@ class _IngestionManager(_interfaces.IngestionManager): self._abort_and_notify( base.Outcome.Kind.REMOTE_FAILURE, code, details) elif outcome.return_value.subscription.kind is base.Subscription.Kind.FULL: + self._protocol_manager.set_protocol_receiver( + outcome.return_value.subscription.protocol_receiver) self._operator_post_create(outcome.return_value.subscription) else: # TODO(nathaniel): Support other subscriptions. @@ -378,7 +382,7 @@ class _IngestionManager(_interfaces.IngestionManager): def invocation_ingestion_manager( subscription, lock, pool, termination_manager, transmission_manager, - expiration_manager): + expiration_manager, protocol_manager): """Creates an IngestionManager appropriate for invocation-side use. Args: @@ -390,18 +394,20 @@ def invocation_ingestion_manager( transmission_manager: The _interfaces.TransmissionManager for the operation. expiration_manager: The _interfaces.ExpirationManager for the operation. + protocol_manager: The _interfaces.ProtocolManager for the operation. Returns: An IngestionManager appropriate for invocation-side use. """ return _IngestionManager( lock, pool, subscription, None, termination_manager, transmission_manager, - expiration_manager) + expiration_manager, protocol_manager) def service_ingestion_manager( servicer, operation_context, output_operator, lock, pool, - termination_manager, transmission_manager, expiration_manager): + termination_manager, transmission_manager, expiration_manager, + protocol_manager): """Creates an IngestionManager appropriate for service-side use. The returned IngestionManager will require its set_group_and_name method to be @@ -420,6 +426,7 @@ def service_ingestion_manager( transmission_manager: The _interfaces.TransmissionManager for the operation. expiration_manager: The _interfaces.ExpirationManager for the operation. + protocol_manager: The _interfaces.ProtocolManager for the operation. Returns: An IngestionManager appropriate for service-side use. @@ -428,4 +435,4 @@ def service_ingestion_manager( servicer, operation_context, output_operator) return _IngestionManager( lock, pool, None, subscription_creator, termination_manager, - transmission_manager, expiration_manager) + transmission_manager, expiration_manager, protocol_manager) diff --git a/src/python/grpcio/grpc/framework/core/_interfaces.py b/src/python/grpcio/grpc/framework/core/_interfaces.py index 7ac440722cb..ffa686b2b79 100644 --- a/src/python/grpcio/grpc/framework/core/_interfaces.py +++ b/src/python/grpcio/grpc/framework/core/_interfaces.py @@ -203,6 +203,31 @@ class ExpirationManager(object): raise NotImplementedError() +class ProtocolManager(object): + """A manager of protocol-specific values passing through an operation.""" + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def set_protocol_receiver(self, protocol_receiver): + """Registers the customer object that will receive protocol objects. + + Args: + protocol_receiver: A base.ProtocolReceiver to which protocol objects for + the operation should be passed. + """ + raise NotImplementedError() + + @abc.abstractmethod + def accept_protocol_context(self, protocol_context): + """Accepts the protocol context object for the operation. + + Args: + protocol_context: An object designated for use as the protocol context + of the operation, with further semantics implementation-determined. + """ + raise NotImplementedError() + + class EmissionManager(base.Operator): """A manager of values emitted by customer code.""" __metaclass__ = abc.ABCMeta diff --git a/src/python/grpcio/grpc/framework/core/_operation.py b/src/python/grpcio/grpc/framework/core/_operation.py index d4eacc5a3fa..020c0c9ed9a 100644 --- a/src/python/grpcio/grpc/framework/core/_operation.py +++ b/src/python/grpcio/grpc/framework/core/_operation.py @@ -36,6 +36,7 @@ from grpc.framework.core import _emission from grpc.framework.core import _expiration from grpc.framework.core import _ingestion from grpc.framework.core import _interfaces +from grpc.framework.core import _protocol from grpc.framework.core import _reception from grpc.framework.core import _termination from grpc.framework.core import _transmission @@ -123,16 +124,19 @@ def invocation_operate( operation_id, ticket_sink, lock, pool, termination_manager) expiration_manager = _expiration.invocation_expiration_manager( timeout, lock, termination_manager, transmission_manager) + protocol_manager = _protocol.invocation_protocol_manager( + subscription, lock, pool, termination_manager, transmission_manager, + expiration_manager) operation_context = _context.OperationContext( lock, termination_manager, transmission_manager, expiration_manager) emission_manager = _emission.EmissionManager( lock, termination_manager, transmission_manager, expiration_manager) ingestion_manager = _ingestion.invocation_ingestion_manager( subscription, lock, pool, termination_manager, transmission_manager, - expiration_manager) + expiration_manager, protocol_manager) reception_manager = _reception.ReceptionManager( termination_manager, transmission_manager, expiration_manager, - ingestion_manager) + protocol_manager, ingestion_manager) termination_manager.set_expiration_manager(expiration_manager) transmission_manager.set_expiration_manager(expiration_manager) @@ -174,16 +178,20 @@ def service_operate( ticket.timeout, servicer_package.default_timeout, servicer_package.maximum_timeout, lock, termination_manager, transmission_manager) + protocol_manager = _protocol.service_protocol_manager( + lock, pool, termination_manager, transmission_manager, + expiration_manager) operation_context = _context.OperationContext( lock, termination_manager, transmission_manager, expiration_manager) emission_manager = _emission.EmissionManager( lock, termination_manager, transmission_manager, expiration_manager) ingestion_manager = _ingestion.service_ingestion_manager( servicer_package.servicer, operation_context, emission_manager, lock, - pool, termination_manager, transmission_manager, expiration_manager) + pool, termination_manager, transmission_manager, expiration_manager, + protocol_manager) reception_manager = _reception.ReceptionManager( termination_manager, transmission_manager, expiration_manager, - ingestion_manager) + protocol_manager, ingestion_manager) termination_manager.set_expiration_manager(expiration_manager) transmission_manager.set_expiration_manager(expiration_manager) diff --git a/src/python/grpcio/grpc/framework/core/_protocol.py b/src/python/grpcio/grpc/framework/core/_protocol.py new file mode 100644 index 00000000000..3177b5e302a --- /dev/null +++ b/src/python/grpcio/grpc/framework/core/_protocol.py @@ -0,0 +1,176 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""State and behavior for passing protocol objects in an operation.""" + +import collections +import enum + +from grpc.framework.core import _constants +from grpc.framework.core import _interfaces +from grpc.framework.core import _utilities +from grpc.framework.foundation import callable_util +from grpc.framework.interfaces.base import base + +_EXCEPTION_LOG_MESSAGE = 'Exception delivering protocol object!' + +_LOCAL_FAILURE_OUTCOME = _utilities.Outcome( + base.Outcome.Kind.LOCAL_FAILURE, None, None) + + +class _Awaited( + collections.namedtuple('_Awaited', ('kind', 'value',))): + + @enum.unique + class Kind(enum.Enum): + NOT_YET_ARRIVED = 'not yet arrived' + ARRIVED = 'arrived' + +_NOT_YET_ARRIVED = _Awaited(_Awaited.Kind.NOT_YET_ARRIVED, None) +_ARRIVED_AND_NONE = _Awaited(_Awaited.Kind.ARRIVED, None) + + +class _Transitory( + collections.namedtuple('_Transitory', ('kind', 'value',))): + + @enum.unique + class Kind(enum.Enum): + NOT_YET_SEEN = 'not yet seen' + PRESENT = 'present' + GONE = 'gone' + +_NOT_YET_SEEN = _Transitory(_Transitory.Kind.NOT_YET_SEEN, None) +_GONE = _Transitory(_Transitory.Kind.GONE, None) + + +class _ProtocolManager(_interfaces.ProtocolManager): + """An implementation of _interfaces.ExpirationManager.""" + + def __init__( + self, protocol_receiver, lock, pool, termination_manager, + transmission_manager, expiration_manager): + """Constructor. + + Args: + protocol_receiver: An _Awaited wrapping of the base.ProtocolReceiver to + which protocol objects should be passed during the operation. May be + of kind _Awaited.Kind.NOT_YET_ARRIVED if the customer's subscription is + not yet known and may be of kind _Awaited.Kind.ARRIVED but with a value + of None if the customer's subscription did not include a + ProtocolReceiver. + lock: The operation-wide lock. + pool: A thread pool. + termination_manager: The _interfaces.TerminationManager for the operation. + transmission_manager: The _interfaces.TransmissionManager for the + operation. + expiration_manager: The _interfaces.ExpirationManager for the operation. + """ + self._lock = lock + self._pool = pool + self._termination_manager = termination_manager + self._transmission_manager = transmission_manager + self._expiration_manager = expiration_manager + + self._protocol_receiver = protocol_receiver + self._context = _NOT_YET_SEEN + + def _abort_and_notify(self, outcome): + if self._termination_manager.outcome is None: + self._termination_manager.abort(outcome) + self._transmission_manager.abort(outcome) + self._expiration_manager.terminate() + + def _deliver(self, behavior, value): + def deliver(): + delivery_outcome = callable_util.call_logging_exceptions( + behavior, _EXCEPTION_LOG_MESSAGE, value) + if delivery_outcome.kind is callable_util.Outcome.Kind.RAISED: + with self._lock: + self._abort_and_notify(_LOCAL_FAILURE_OUTCOME) + self._pool.submit( + callable_util.with_exceptions_logged( + deliver, _constants.INTERNAL_ERROR_LOG_MESSAGE)) + + def set_protocol_receiver(self, protocol_receiver): + """See _interfaces.ProtocolManager.set_protocol_receiver for spec.""" + self._protocol_receiver = _Awaited(_Awaited.Kind.ARRIVED, protocol_receiver) + if (self._context.kind is _Transitory.Kind.PRESENT and + protocol_receiver is not None): + self._deliver(protocol_receiver.context, self._context.value) + self._context = _GONE + + def accept_protocol_context(self, protocol_context): + """See _interfaces.ProtocolManager.accept_protocol_context for spec.""" + if self._protocol_receiver.kind is _Awaited.Kind.ARRIVED: + if self._protocol_receiver.value is not None: + self._deliver(self._protocol_receiver.value.context, protocol_context) + self._context = _GONE + else: + self._context = _Transitory(_Transitory.Kind.PRESENT, protocol_context) + + +def invocation_protocol_manager( + subscription, lock, pool, termination_manager, transmission_manager, + expiration_manager): + """Creates an _interfaces.ProtocolManager for invocation-side use. + + Args: + subscription: The local customer's subscription to the operation. + lock: The operation-wide lock. + pool: A thread pool. + termination_manager: The _interfaces.TerminationManager for the operation. + transmission_manager: The _interfaces.TransmissionManager for the + operation. + expiration_manager: The _interfaces.ExpirationManager for the operation. + """ + if subscription.kind is base.Subscription.Kind.FULL: + awaited_protocol_receiver = _Awaited( + _Awaited.Kind.ARRIVED, subscription.protocol_receiver) + else: + awaited_protocol_receiver = _ARRIVED_AND_NONE + return _ProtocolManager( + awaited_protocol_receiver, lock, pool, termination_manager, + transmission_manager, expiration_manager) + + +def service_protocol_manager( + lock, pool, termination_manager, transmission_manager, expiration_manager): + """Creates an _interfaces.ProtocolManager for service-side use. + + Args: + lock: The operation-wide lock. + pool: A thread pool. + termination_manager: The _interfaces.TerminationManager for the operation. + transmission_manager: The _interfaces.TransmissionManager for the + operation. + expiration_manager: The _interfaces.ExpirationManager for the operation. + """ + return _ProtocolManager( + _NOT_YET_ARRIVED, lock, pool, termination_manager, transmission_manager, + expiration_manager) diff --git a/src/python/grpcio/grpc/framework/core/_reception.py b/src/python/grpcio/grpc/framework/core/_reception.py index d374cf0c8ce..ff81450dee9 100644 --- a/src/python/grpcio/grpc/framework/core/_reception.py +++ b/src/python/grpcio/grpc/framework/core/_reception.py @@ -51,23 +51,31 @@ _RECEPTION_FAILURE_OUTCOME = _utilities.Outcome( base.Outcome.Kind.RECEPTION_FAILURE, None, None) +def _carrying_protocol_context(ticket): + return ticket.protocol is not None and ticket.protocol.kind in ( + links.Protocol.Kind.INVOCATION_CONTEXT, + links.Protocol.Kind.SERVICER_CONTEXT,) + + class ReceptionManager(_interfaces.ReceptionManager): """A ReceptionManager based around a _Receiver passed to it.""" def __init__( self, termination_manager, transmission_manager, expiration_manager, - ingestion_manager): + protocol_manager, ingestion_manager): """Constructor. Args: termination_manager: The operation's _interfaces.TerminationManager. transmission_manager: The operation's _interfaces.TransmissionManager. expiration_manager: The operation's _interfaces.ExpirationManager. + protocol_manager: The operation's _interfaces.ProtocolManager. ingestion_manager: The operation's _interfaces.IngestionManager. """ self._termination_manager = termination_manager self._transmission_manager = transmission_manager self._expiration_manager = expiration_manager + self._protocol_manager = protocol_manager self._ingestion_manager = ingestion_manager self._lowest_unseen_sequence_number = 0 @@ -100,6 +108,10 @@ class ReceptionManager(_interfaces.ReceptionManager): def _process_one(self, ticket): if ticket.sequence_number == 0: self._ingestion_manager.set_group_and_method(ticket.group, ticket.method) + if _carrying_protocol_context(ticket): + self._protocol_manager.accept_protocol_context(ticket.protocol.value) + else: + self._protocol_manager.accept_protocol_context(None) if ticket.timeout is not None: self._expiration_manager.change_timeout(ticket.timeout) if ticket.termination is None: diff --git a/src/python/grpcio/grpc/framework/crust/_calls.py b/src/python/grpcio/grpc/framework/crust/_calls.py index 68db9fab8ef..bff940d7471 100644 --- a/src/python/grpcio/grpc/framework/crust/_calls.py +++ b/src/python/grpcio/grpc/framework/crust/_calls.py @@ -42,10 +42,12 @@ def _invoke( end, group, method, timeout, protocol_options, initial_metadata, payload, complete): rendezvous = _control.Rendezvous(None, None) + subscription = utilities.full_subscription( + rendezvous, _control.protocol_receiver(rendezvous)) operation_context, operator = end.operate( - group, method, utilities.full_subscription(rendezvous), timeout, - protocol_options=protocol_options, initial_metadata=initial_metadata, - payload=payload, completion=_EMPTY_COMPLETION if complete else None) + group, method, subscription, timeout, protocol_options=protocol_options, + initial_metadata=initial_metadata, payload=payload, + completion=_EMPTY_COMPLETION if complete else None) rendezvous.set_operator_and_context(operator, operation_context) outcome = operation_context.add_termination_callback(rendezvous.set_outcome) if outcome is not None: diff --git a/src/python/grpcio/grpc/framework/crust/_control.py b/src/python/grpcio/grpc/framework/crust/_control.py index e02a41d7209..5e9efdf7322 100644 --- a/src/python/grpcio/grpc/framework/crust/_control.py +++ b/src/python/grpcio/grpc/framework/crust/_control.py @@ -182,6 +182,8 @@ class Rendezvous(base.Operator, future.Future, stream.Consumer, face.Call): self._operator = operator self._operation_context = operation_context + self._protocol_context = _NOT_YET_ARRIVED + self._up_initial_metadata = _NOT_YET_ARRIVED self._up_payload = None self._up_allowance = 1 @@ -444,7 +446,13 @@ class Rendezvous(base.Operator, future.Future, stream.Consumer, face.Call): def protocol_context(self): with self._condition: - raise NotImplementedError('TODO: protocol context implementation!') + while True: + if self._protocol_context.kind is _Awaited.Kind.ARRIVED: + return self._protocol_context.value + elif self._termination.abortion_error is not None: + raise self._termination.abortion_error + else: + self._condition.wait() def initial_metadata(self): with self._condition: @@ -518,11 +526,30 @@ class Rendezvous(base.Operator, future.Future, stream.Consumer, face.Call): else: self._down_details = _Transitory(_Transitory.Kind.PRESENT, details) + def set_protocol_context(self, protocol_context): + with self._condition: + self._protocol_context = _Awaited( + _Awaited.Kind.ARRIVED, protocol_context) + self._condition.notify_all() + def set_outcome(self, outcome): with self._condition: return self._set_outcome(outcome) +class _ProtocolReceiver(base.ProtocolReceiver): + + def __init__(self, rendezvous): + self._rendezvous = rendezvous + + def context(self, protocol_context): + self._rendezvous.set_protocol_context(protocol_context) + + +def protocol_receiver(rendezvous): + return _ProtocolReceiver(rendezvous) + + def pool_wrap(behavior, operation_context): """Wraps an operation-related behavior so that it may be called in a pool. diff --git a/src/python/grpcio/grpc/framework/crust/_service.py b/src/python/grpcio/grpc/framework/crust/_service.py index f1855c2f471..9903415c099 100644 --- a/src/python/grpcio/grpc/framework/crust/_service.py +++ b/src/python/grpcio/grpc/framework/crust/_service.py @@ -74,10 +74,12 @@ class _ServicerContext(face.ServicerContext): def _adaptation(pool, in_pool): def adaptation(operator, operation_context): rendezvous = _control.Rendezvous(operator, operation_context) + subscription = utilities.full_subscription( + rendezvous, _control.protocol_receiver(rendezvous)) outcome = operation_context.add_termination_callback(rendezvous.set_outcome) if outcome is None: pool.submit(_control.pool_wrap(in_pool, operation_context), rendezvous) - return utilities.full_subscription(rendezvous) + return subscription else: raise abandonment.Abandoned() return adaptation @@ -154,6 +156,8 @@ def adapt_event_stream_stream(method, pool): def adapt_multi_method(multi_method, pool): def adaptation(group, method, operator, operation_context): rendezvous = _control.Rendezvous(operator, operation_context) + subscription = utilities.full_subscription( + rendezvous, _control.protocol_receiver(rendezvous)) outcome = operation_context.add_termination_callback(rendezvous.set_outcome) if outcome is None: def in_pool(): @@ -163,7 +167,7 @@ def adapt_multi_method(multi_method, pool): request_consumer.consume(request) request_consumer.terminate() pool.submit(_control.pool_wrap(in_pool, operation_context), rendezvous) - return utilities.full_subscription(rendezvous) + return subscription else: raise abandonment.Abandoned() return adaptation diff --git a/src/python/grpcio/grpc/framework/interfaces/base/base.py b/src/python/grpcio/grpc/framework/interfaces/base/base.py index 013e7c66f22..a1e70be5e8d 100644 --- a/src/python/grpcio/grpc/framework/interfaces/base/base.py +++ b/src/python/grpcio/grpc/framework/interfaces/base/base.py @@ -184,6 +184,19 @@ class Operator(object): """ raise NotImplementedError() +class ProtocolReceiver(object): + """A means of receiving protocol values during an operation.""" + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def context(self, protocol_context): + """Accepts the protocol context object for the operation. + + Args: + protocol_context: The protocol context object for the operation. + """ + raise NotImplementedError() + class Subscription(object): """Describes customer code's interest in values from the other side. @@ -199,7 +212,11 @@ class Subscription(object): otherwise. operator: An Operator to be passed values from the other side of the operation. Must be non-None if kind is Kind.FULL. Must be None otherwise. + protocol_receiver: A ProtocolReceiver to be passed protocol objects as they + become available during the operation. Must be non-None if kind is + Kind.FULL. """ + __metaclass__ = abc.ABCMeta @enum.unique class Kind(enum.Enum): diff --git a/src/python/grpcio/grpc/framework/interfaces/base/utilities.py b/src/python/grpcio/grpc/framework/interfaces/base/utilities.py index a9ee1a09816..87a85018f52 100644 --- a/src/python/grpcio/grpc/framework/interfaces/base/utilities.py +++ b/src/python/grpcio/grpc/framework/interfaces/base/utilities.py @@ -45,11 +45,12 @@ class _Subscription( base.Subscription, collections.namedtuple( '_Subscription', - ('kind', 'termination_callback', 'allowance', 'operator',))): + ('kind', 'termination_callback', 'allowance', 'operator', + 'protocol_receiver',))): """A trivial implementation of base.Subscription.""" _NONE_SUBSCRIPTION = _Subscription( - base.Subscription.Kind.NONE, None, None, None) + base.Subscription.Kind.NONE, None, None, None, None) def completion(terminal_metadata, code, message): @@ -66,14 +67,16 @@ def completion(terminal_metadata, code, message): return _Completion(terminal_metadata, code, message) -def full_subscription(operator): +def full_subscription(operator, protocol_receiver): """Creates a "full" base.Subscription for the given base.Operator. Args: operator: A base.Operator to be used in an operation. + protocol_receiver: A base.ProtocolReceiver to be used in an operation. Returns: A base.Subscription of kind base.Subscription.Kind.FULL wrapping the given - base.Operator. + base.Operator and base.ProtocolReceiver. """ - return _Subscription(base.Subscription.Kind.FULL, None, None, operator) + return _Subscription( + base.Subscription.Kind.FULL, None, None, operator, protocol_receiver) diff --git a/src/python/grpcio_test/grpc_test/framework/interfaces/base/test_cases.py b/src/python/grpcio_test/grpc_test/framework/interfaces/base/test_cases.py index 5065a3f38a8..ddda1018c34 100644 --- a/src/python/grpcio_test/grpc_test/framework/interfaces/base/test_cases.py +++ b/src/python/grpcio_test/grpc_test/framework/interfaces/base/test_cases.py @@ -119,6 +119,17 @@ class _Operator(base.Operator): 'Deliberately raised exception from Operator.advance (in a test)!') +class _ProtocolReceiver(base.ProtocolReceiver): + + def __init__(self): + self._condition = threading.Condition() + self._contexts = [] + + def context(self, protocol_context): + with self._condition: + self._contexts.append(protocol_context) + + class _Servicer(base.Servicer): """A base.Servicer with instrumented for testing.""" @@ -144,7 +155,7 @@ class _Servicer(base.Servicer): controller.service_on_termination) if outcome is not None: controller.service_on_termination(outcome) - return utilities.full_subscription(operator) + return utilities.full_subscription(operator, _ProtocolReceiver()) class _OperationTest(unittest.TestCase): @@ -169,7 +180,8 @@ class _OperationTest(unittest.TestCase): test_operator = _Operator( self._controller, self._controller.on_invocation_advance, self._pool, None) - subscription = utilities.full_subscription(test_operator) + subscription = utilities.full_subscription( + test_operator, _ProtocolReceiver()) else: # TODO(nathaniel): support and test other subscription kinds. self.fail('Non-full subscriptions not yet supported!')