diff --git a/src/python/grpcio/grpc/_links/invocation.py b/src/python/grpcio/grpc/_links/invocation.py index 729b987dd12..1676fe79414 100644 --- a/src/python/grpcio/grpc/_links/invocation.py +++ b/src/python/grpcio/grpc/_links/invocation.py @@ -41,6 +41,8 @@ from grpc.framework.foundation import logging_pool from grpc.framework.foundation import relay from grpc.framework.interfaces.links import links +_IDENTITY = lambda x: x + _STOP = _intermediary_low.Event.Kind.STOP _WRITE = _intermediary_low.Event.Kind.WRITE_ACCEPTED _COMPLETE = _intermediary_low.Event.Kind.COMPLETE_ACCEPTED @@ -95,11 +97,12 @@ def _no_longer_due(kind, rpc_state, key, rpc_states): class _Kernel(object): def __init__( - self, channel, host, request_serializers, response_deserializers, - ticket_relay): + self, channel, host, metadata_transformer, request_serializers, + response_deserializers, ticket_relay): self._lock = threading.Lock() self._channel = channel self._host = host + self._metadata_transformer = metadata_transformer self._request_serializers = request_serializers self._response_deserializers = response_deserializers self._relay = ticket_relay @@ -225,20 +228,17 @@ class _Kernel(object): else: return - request_serializer = self._request_serializers.get((group, method)) - response_deserializer = self._response_deserializers.get((group, method)) - if request_serializer is None or response_deserializer is None: - cancellation_ticket = links.Ticket( - operation_id, 0, None, None, None, None, None, None, None, None, None, - None, links.Ticket.Termination.CANCELLATION) - self._relay.add_value(cancellation_ticket) - return + transformed_initial_metadata = self._metadata_transformer(initial_metadata) + request_serializer = self._request_serializers.get( + (group, method), _IDENTITY) + response_deserializer = self._response_deserializers.get( + (group, method), _IDENTITY) call = _intermediary_low.Call( self._channel, self._completion_queue, '/%s/%s' % (group, method), self._host, time.time() + timeout) - if initial_metadata is not None: - for metadata_key, metadata_value in initial_metadata: + if transformed_initial_metadata is not None: + for metadata_key, metadata_value in transformed_initial_metadata: call.add_metadata(metadata_key, metadata_value) call.invoke(self._completion_queue, operation_id, operation_id) if payload is None: @@ -336,10 +336,15 @@ class InvocationLink(links.Link, activated.Activated): class _InvocationLink(InvocationLink): def __init__( - self, channel, host, request_serializers, response_deserializers): + self, channel, host, metadata_transformer, request_serializers, + response_deserializers): self._relay = relay.relay(None) self._kernel = _Kernel( - channel, host, request_serializers, response_deserializers, self._relay) + channel, host, + _IDENTITY if metadata_transformer is None else metadata_transformer, + {} if request_serializers is None else request_serializers, + {} if response_deserializers is None else response_deserializers, + self._relay) def _start(self): self._relay.start() @@ -376,12 +381,17 @@ class _InvocationLink(InvocationLink): self._stop() -def invocation_link(channel, host, request_serializers, response_deserializers): +def invocation_link( + channel, host, metadata_transformer, request_serializers, + response_deserializers): """Creates an InvocationLink. Args: channel: An _intermediary_low.Channel for use by the link. host: The host to specify when invoking RPCs. + metadata_transformer: A callable that takes an invocation-side initial + metadata value and returns another metadata value to send in its place. + May be None. request_serializers: A dict from group-method pair to request object serialization behavior. response_deserializers: A dict from group-method pair to response object @@ -391,4 +401,5 @@ def invocation_link(channel, host, request_serializers, response_deserializers): An InvocationLink. """ return _InvocationLink( - channel, host, request_serializers, response_deserializers) + channel, host, metadata_transformer, request_serializers, + response_deserializers) diff --git a/src/python/grpcio/grpc/_links/service.py b/src/python/grpcio/grpc/_links/service.py index bbfe9bcd55f..94e7cfc716b 100644 --- a/src/python/grpcio/grpc/_links/service.py +++ b/src/python/grpcio/grpc/_links/service.py @@ -40,6 +40,8 @@ from grpc.framework.foundation import logging_pool from grpc.framework.foundation import relay from grpc.framework.interfaces.links import links +_IDENTITY = lambda x: x + _TERMINATION_KIND_TO_CODE = { links.Ticket.Termination.COMPLETION: _intermediary_low.Code.OK, links.Ticket.Termination.CANCELLATION: _intermediary_low.Code.CANCELLED, @@ -154,12 +156,10 @@ class _Kernel(object): except ValueError: logging.info('Illegal path "%s"!', service_acceptance.method) return - request_deserializer = self._request_deserializers.get((group, method)) - response_serializer = self._response_serializers.get((group, method)) - if request_deserializer is None or response_serializer is None: - # TODO(nathaniel): Terminate the RPC with code NOT_FOUND. - call.cancel() - return + request_deserializer = self._request_deserializers.get( + (group, method), _IDENTITY) + response_serializer = self._response_serializers.get( + (group, method), _IDENTITY) call.read(call) self._rpc_states[call] = _RPCState( @@ -433,7 +433,9 @@ class _ServiceLink(ServiceLink): def __init__(self, request_deserializers, response_serializers): self._relay = relay.relay(None) self._kernel = _Kernel( - request_deserializers, response_serializers, self._relay) + {} if request_deserializers is None else request_deserializers, + {} if response_serializers is None else response_serializers, + self._relay) def accept_ticket(self, ticket): self._kernel.add_ticket(ticket) diff --git a/src/python/grpcio/grpc/beta/_stub.py b/src/python/grpcio/grpc/beta/_stub.py index 178f06d21e5..cfbecb852b2 100644 --- a/src/python/grpcio/grpc/beta/_stub.py +++ b/src/python/grpcio/grpc/beta/_stub.py @@ -54,11 +54,12 @@ class _AutoIntermediary(object): def _assemble( - channel, host, request_serializers, response_deserializers, thread_pool, - thread_pool_size): + channel, host, metadata_transformer, request_serializers, + response_deserializers, thread_pool, thread_pool_size): end_link = _core_implementations.invocation_end_link() grpc_link = invocation.invocation_link( - channel, host, request_serializers, response_deserializers) + channel, host, metadata_transformer, request_serializers, + response_deserializers) if thread_pool is None: invocation_pool = logging_pool.pool( _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size) @@ -89,21 +90,22 @@ def _wrap_assembly(stub, end_link, grpc_link, assembly_pool): def generic_stub( - channel, host, request_serializers, response_deserializers, thread_pool, - thread_pool_size): + channel, host, metadata_transformer, request_serializers, + response_deserializers, thread_pool, thread_pool_size): end_link, grpc_link, invocation_pool, assembly_pool = _assemble( - channel, host, request_serializers, response_deserializers, thread_pool, - thread_pool_size) + channel, host, metadata_transformer, request_serializers, + response_deserializers, thread_pool, thread_pool_size) stub = _crust_implementations.generic_stub(end_link, invocation_pool) return _wrap_assembly(stub, end_link, grpc_link, assembly_pool) def dynamic_stub( - channel, host, service, cardinalities, request_serializers, - response_deserializers, thread_pool, thread_pool_size): + channel, host, service, cardinalities, metadata_transformer, + request_serializers, response_deserializers, thread_pool, + thread_pool_size): end_link, grpc_link, invocation_pool, assembly_pool = _assemble( - channel, host, request_serializers, response_deserializers, thread_pool, - thread_pool_size) + channel, host, metadata_transformer, request_serializers, + response_deserializers, thread_pool, thread_pool_size) stub = _crust_implementations.dynamic_stub( end_link, service, cardinalities, invocation_pool) return _wrap_assembly(stub, end_link, grpc_link, assembly_pool) diff --git a/src/python/grpcio/grpc/beta/beta.py b/src/python/grpcio/grpc/beta/beta.py index 640e4eb86b5..b3a161087f3 100644 --- a/src/python/grpcio/grpc/beta/beta.py +++ b/src/python/grpcio/grpc/beta/beta.py @@ -238,6 +238,7 @@ def generic_stub(channel, options=None): effective_options = _EMPTY_STUB_OPTIONS if options is None else options return _stub.generic_stub( channel._intermediary_low_channel, effective_options.host, # pylint: disable=protected-access + effective_options.metadata_transformer, effective_options.request_serializers, effective_options.response_deserializers, effective_options.thread_pool, effective_options.thread_pool_size) @@ -260,7 +261,8 @@ def dynamic_stub(channel, service, cardinalities, options=None): effective_options = StubOptions() if options is None else options return _stub.dynamic_stub( channel._intermediary_low_channel, effective_options.host, service, # pylint: disable=protected-access - cardinalities, effective_options.request_serializers, + cardinalities, effective_options.metadata_transformer, + effective_options.request_serializers, effective_options.response_deserializers, effective_options.thread_pool, effective_options.thread_pool_size) diff --git a/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py b/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py index 9112c341900..f0bd989ea66 100644 --- a/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py +++ b/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py @@ -94,7 +94,7 @@ class _Implementation(test_interfaces.Implementation): port = service_grpc_link.add_port('[::]:0', None) channel = _intermediary_low.Channel('localhost:%d' % port, None) invocation_grpc_link = invocation.invocation_link( - channel, b'localhost', + channel, b'localhost', None, serialization_behaviors.request_serializers, serialization_behaviors.response_deserializers) diff --git a/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py b/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py index 14015365030..28c0619f7c9 100644 --- a/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py +++ b/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py @@ -87,7 +87,7 @@ class _Implementation(test_interfaces.Implementation): port = service_grpc_link.add_port('[::]:0', None) channel = _intermediary_low.Channel('localhost:%d' % port, None) invocation_grpc_link = invocation.invocation_link( - channel, b'localhost', + channel, b'localhost', None, serialization_behaviors.request_serializers, serialization_behaviors.response_deserializers) diff --git a/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py b/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py index 373a2b2a1f7..8e12e8cc223 100644 --- a/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py +++ b/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py @@ -45,7 +45,8 @@ class LonelyInvocationLinkTest(unittest.TestCase): def testUpAndDown(self): channel = _intermediary_low.Channel('nonexistent:54321', None) - invocation_link = invocation.invocation_link(channel, 'nonexistent', {}, {}) + invocation_link = invocation.invocation_link( + channel, 'nonexistent', None, {}, {}) invocation_link.start() invocation_link.stop() @@ -58,8 +59,7 @@ class LonelyInvocationLinkTest(unittest.TestCase): channel = _intermediary_low.Channel('nonexistent:54321', None) invocation_link = invocation.invocation_link( - channel, 'nonexistent', {(test_group, test_method): _NULL_BEHAVIOR}, - {(test_group, test_method): _NULL_BEHAVIOR}) + channel, 'nonexistent', None, {}, {}) invocation_link.join_link(invocation_link_mate) invocation_link.start() diff --git a/src/python/grpcio_test/grpc_test/_links/_transmission_test.py b/src/python/grpcio_test/grpc_test/_links/_transmission_test.py index c114cef6a60..716323cc20d 100644 --- a/src/python/grpcio_test/grpc_test/_links/_transmission_test.py +++ b/src/python/grpcio_test/grpc_test/_links/_transmission_test.py @@ -54,7 +54,7 @@ class TransmissionTest(test_cases.TransmissionTest, unittest.TestCase): service_link.start() channel = _intermediary_low.Channel('localhost:%d' % port, None) invocation_link = invocation.invocation_link( - channel, 'localhost', + channel, 'localhost', None, {self.group_and_method(): self.serialize_request}, {self.group_and_method(): self.deserialize_response}) invocation_link.start() @@ -121,7 +121,7 @@ class RoundTripTest(unittest.TestCase): service_link.start() channel = _intermediary_low.Channel('localhost:%d' % port, None) invocation_link = invocation.invocation_link( - channel, 'localhost', identity_transformation, identity_transformation) + channel, None, None, identity_transformation, identity_transformation) invocation_mate = test_utilities.RecordingLink() invocation_link.join_link(invocation_mate) invocation_link.start() @@ -166,7 +166,7 @@ class RoundTripTest(unittest.TestCase): service_link.start() channel = _intermediary_low.Channel('localhost:%d' % port, None) invocation_link = invocation.invocation_link( - channel, 'localhost', + channel, 'localhost', None, {(test_group, test_method): scenario.serialize_request}, {(test_group, test_method): scenario.deserialize_response}) invocation_mate = test_utilities.RecordingLink()