diff --git a/.pylintrc b/.pylintrc index ba74decb047..fcc8e73cb41 100644 --- a/.pylintrc +++ b/.pylintrc @@ -6,6 +6,8 @@ ignore= src/python/grpcio/grpc/framework/foundation, src/python/grpcio/grpc/framework/interfaces, +extension-pkg-whitelist=grpc._cython.cygrpc + [VARIABLES] # TODO(https://github.com/PyCQA/pylint/issues/1345): How does the inspection @@ -17,7 +19,7 @@ dummy-variables-rgx=^ignored_|^unused_ # NOTE(nathaniel): Not particularly attached to this value; it just seems to # be what works for us at the moment (excepting the dead-code-walking Beta # API). -max-args=6 +max-args=7 [MISCELLANEOUS] diff --git a/doc/python/sphinx/grpc.rst b/doc/python/sphinx/grpc.rst index f534d25c639..0934db7188b 100644 --- a/doc/python/sphinx/grpc.rst +++ b/doc/python/sphinx/grpc.rst @@ -172,3 +172,9 @@ Future Interfaces .. autoexception:: FutureTimeoutError .. autoexception:: FutureCancelledError .. autoclass:: Future + + +Compression +^^^^^^^^^^^ + +.. autoclass:: Compression diff --git a/src/python/grpcio/grpc/BUILD.bazel b/src/python/grpcio/grpc/BUILD.bazel index 27d5d2e4bb2..a2bedae4bea 100644 --- a/src/python/grpcio/grpc/BUILD.bazel +++ b/src/python/grpcio/grpc/BUILD.bazel @@ -12,6 +12,7 @@ py_library( ":channel", ":interceptor", ":server", + ":compression", "//src/python/grpcio/grpc/_cython:cygrpc", "//src/python/grpcio/grpc/experimental", "//src/python/grpcio/grpc/framework", @@ -31,12 +32,18 @@ py_library( srcs = ["_auth.py"], ) +py_library( + name = "compression", + srcs = ["_compression.py"], +) + py_library( name = "channel", srcs = ["_channel.py"], deps = [ ":common", ":grpcio_metadata", + ":compression", ], ) @@ -68,6 +75,7 @@ py_library( srcs = ["_server.py"], deps = [ ":common", + ":compression", ":interceptor", ], ) diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 76314106ca4..6175180e92a 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -21,6 +21,7 @@ import sys import six from grpc._cython import cygrpc as _cygrpc +from grpc import _compression logging.getLogger(__name__).addHandler(logging.NullHandler()) @@ -413,6 +414,8 @@ class ClientCallDetails(six.with_metaclass(abc.ABCMeta)): credentials: An optional CallCredentials for the RPC. wait_for_ready: This is an EXPERIMENTAL argument. An optional flag t enable wait for ready mechanism. + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. """ @@ -669,7 +672,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): """Synchronously invokes the underlying RPC. Args: @@ -681,6 +685,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): credentials: An optional CallCredentials for the RPC. wait_for_ready: This is an EXPERIMENTAL argument. An optional flag to enable wait for ready mechanism + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. Returns: The response value for the RPC. @@ -698,7 +704,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): """Synchronously invokes the underlying RPC. Args: @@ -710,6 +717,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): credentials: An optional CallCredentials for the RPC. wait_for_ready: This is an EXPERIMENTAL argument. An optional flag to enable wait for ready mechanism + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. Returns: The response value for the RPC and a Call value for the RPC. @@ -727,7 +736,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): """Asynchronously invokes the underlying RPC. Args: @@ -739,6 +749,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): credentials: An optional CallCredentials for the RPC. wait_for_ready: This is an EXPERIMENTAL argument. An optional flag to enable wait for ready mechanism + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. Returns: An object that is both a Call for the RPC and a Future. @@ -759,7 +771,8 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): """Invokes the underlying RPC. Args: @@ -771,6 +784,8 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): credentials: An optional CallCredentials for the RPC. wait_for_ready: This is an EXPERIMENTAL argument. An optional flag to enable wait for ready mechanism + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. Returns: An object that is both a Call for the RPC and an iterator of @@ -790,7 +805,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): """Synchronously invokes the underlying RPC. Args: @@ -803,6 +819,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): credentials: An optional CallCredentials for the RPC. wait_for_ready: This is an EXPERIMENTAL argument. An optional flag to enable wait for ready mechanism + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. Returns: The response value for the RPC. @@ -820,7 +838,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): """Synchronously invokes the underlying RPC on the client. Args: @@ -833,6 +852,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): credentials: An optional CallCredentials for the RPC. wait_for_ready: This is an EXPERIMENTAL argument. An optional flag to enable wait for ready mechanism + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. Returns: The response value for the RPC and a Call object for the RPC. @@ -850,7 +871,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): """Asynchronously invokes the underlying RPC on the client. Args: @@ -862,6 +884,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): credentials: An optional CallCredentials for the RPC. wait_for_ready: This is an EXPERIMENTAL argument. An optional flag to enable wait for ready mechanism + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. Returns: An object that is both a Call for the RPC and a Future. @@ -882,7 +906,8 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): """Invokes the underlying RPC on the client. Args: @@ -894,6 +919,8 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): credentials: An optional CallCredentials for the RPC. wait_for_ready: This is an EXPERIMENTAL argument. An optional flag to enable wait for ready mechanism + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This is an EXPERIMENTAL option. Returns: An object that is both a Call for the RPC and an iterator of @@ -1097,6 +1124,17 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)): """ raise NotImplementedError() + def set_compression(self, compression): + """Set the compression algorithm to be used for the entire call. + + This is an EXPERIMENTAL method. + + Args: + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. + """ + raise NotImplementedError() + @abc.abstractmethod def send_initial_metadata(self, initial_metadata): """Sends the initial metadata value to the client. @@ -1184,6 +1222,16 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)): """ raise NotImplementedError() + def disable_next_message_compression(self): + """Disables compression for the next response message. + + This is an EXPERIMENTAL method. + + This method will override any compression configuration set during + server creation or set on the call. + """ + raise NotImplementedError() + ##################### Service-Side Handler Interfaces ######################## @@ -1682,7 +1730,7 @@ def channel_ready_future(channel): return _utilities.channel_ready_future(channel) -def insecure_channel(target, options=None): +def insecure_channel(target, options=None, compression=None): """Creates an insecure Channel to a server. The returned Channel is thread-safe. @@ -1691,15 +1739,18 @@ def insecure_channel(target, options=None): target: The server address options: An optional list of key-value pairs (channel args in gRPC Core runtime) to configure the channel. + compression: An optional value indicating the compression method to be + used over the lifetime of the channel. This is an EXPERIMENTAL option. Returns: A Channel. """ from grpc import _channel # pylint: disable=cyclic-import - return _channel.Channel(target, () if options is None else options, None) + return _channel.Channel(target, () + if options is None else options, None, compression) -def secure_channel(target, credentials, options=None): +def secure_channel(target, credentials, options=None, compression=None): """Creates a secure Channel to a server. The returned Channel is thread-safe. @@ -1709,13 +1760,15 @@ def secure_channel(target, credentials, options=None): credentials: A ChannelCredentials instance. options: An optional list of key-value pairs (channel args in gRPC Core runtime) to configure the channel. + compression: An optional value indicating the compression method to be + used over the lifetime of the channel. This is an EXPERIMENTAL option. Returns: A Channel. """ from grpc import _channel # pylint: disable=cyclic-import return _channel.Channel(target, () if options is None else options, - credentials._credentials) + credentials._credentials, compression) def intercept_channel(channel, *interceptors): @@ -1750,7 +1803,8 @@ def server(thread_pool, handlers=None, interceptors=None, options=None, - maximum_concurrent_rpcs=None): + maximum_concurrent_rpcs=None, + compression=None): """Creates a Server with which RPCs can be serviced. Args: @@ -1768,6 +1822,9 @@ def server(thread_pool, maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server will service before returning RESOURCE_EXHAUSTED status, or None to indicate no limit. + compression: An element of grpc.compression, e.g. + grpc.compression.Gzip. This compression algorithm will be used for the + lifetime of the server unless overridden. This is an EXPERIMENTAL option. Returns: A Server object. @@ -1777,7 +1834,7 @@ def server(thread_pool, if handlers is None else handlers, () if interceptors is None else interceptors, () if options is None else options, - maximum_concurrent_rpcs) + maximum_concurrent_rpcs, compression) @contextlib.contextmanager @@ -1788,6 +1845,16 @@ def _create_servicer_context(rpc_event, state, request_deserializer): context._finalize_state() # pylint: disable=protected-access +class Compression(enum.IntEnum): + """Indicates the compression method to be used for an RPC. + + This enumeration is part of an EXPERIMENTAL API. + """ + NoCompression = _compression.NoCompression + Deflate = _compression.Deflate + Gzip = _compression.Gzip + + ################################### __all__ ################################# __all__ = ( @@ -1805,6 +1872,7 @@ __all__ = ( 'AuthMetadataContext', 'AuthMetadataPluginCallback', 'AuthMetadataPlugin', + 'Compression', 'ClientCallDetails', 'ServerCertificateConfiguration', 'ServerCredentials', diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index ed4c871b684..1272ee802bc 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -19,6 +19,7 @@ import threading import time import grpc +from grpc import _compression from grpc import _common from grpc import _grpcio_metadata from grpc._cython import cygrpc @@ -512,17 +513,19 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() - def _prepare(self, request, timeout, metadata, wait_for_ready): + def _prepare(self, request, timeout, metadata, wait_for_ready, compression): deadline, serialized_request, rendezvous = _start_unary_request( request, timeout, self._request_serializer) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( wait_for_ready) + augmented_metadata = _compression.augment_metadata( + metadata, compression) if serialized_request is None: return None, None, None, rendezvous else: state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None) operations = ( - cygrpc.SendInitialMetadataOperation(metadata, + cygrpc.SendInitialMetadataOperation(augmented_metadata, initial_metadata_flags), cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), @@ -532,18 +535,17 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): ) return state, operations, deadline, None - def _blocking(self, request, timeout, metadata, credentials, - wait_for_ready): + def _blocking(self, request, timeout, metadata, credentials, wait_for_ready, + compression): state, operations, deadline, rendezvous = self._prepare( - request, timeout, metadata, wait_for_ready) + request, timeout, metadata, wait_for_ready, compression) if state is None: raise rendezvous # pylint: disable-msg=raising-bad-type else: - deadline_to_propagate = _determine_deadline(deadline) call = self._channel.segregated_call( cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, - self._method, None, deadline_to_propagate, metadata, None - if credentials is None else credentials._credentials, (( + self._method, None, _determine_deadline(deadline), metadata, + None if credentials is None else credentials._credentials, (( operations, None, ),), self._context) @@ -556,9 +558,10 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): state, call, = self._blocking(request, timeout, metadata, credentials, - wait_for_ready) + wait_for_ready, compression) return _end_unary_response_blocking(state, call, False, None) def with_call(self, @@ -566,9 +569,10 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): state, call, = self._blocking(request, timeout, metadata, credentials, - wait_for_ready) + wait_for_ready, compression) return _end_unary_response_blocking(state, call, True, None) def future(self, @@ -576,9 +580,10 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): state, operations, deadline, rendezvous = self._prepare( - request, timeout, metadata, wait_for_ready) + request, timeout, metadata, wait_for_ready, compression) if state is None: raise rendezvous # pylint: disable-msg=raising-bad-type else: @@ -604,12 +609,14 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() - def __call__(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None): + def __call__( # pylint: disable=too-many-locals + self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None): deadline, serialized_request, rendezvous = _start_unary_request( request, timeout, self._request_serializer) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( @@ -617,10 +624,12 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): if serialized_request is None: raise rendezvous # pylint: disable-msg=raising-bad-type else: + augmented_metadata = _compression.augment_metadata( + metadata, compression) state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) operationses = ( ( - cygrpc.SendInitialMetadataOperation(metadata, + cygrpc.SendInitialMetadataOperation(augmented_metadata, initial_metadata_flags), cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), @@ -629,12 +638,13 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): ), (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), ) - event_handler = _event_handler(state, self._response_deserializer) call = self._managed_call( cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, None, _determine_deadline(deadline), metadata, - None if credentials is None else credentials._credentials, - operationses, event_handler, self._context) + None if credentials is None else + credentials._credentials, operationses, + _event_handler(state, + self._response_deserializer), self._context) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -652,18 +662,19 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): self._context = cygrpc.build_census_context() def _blocking(self, request_iterator, timeout, metadata, credentials, - wait_for_ready): + wait_for_ready, compression): deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( wait_for_ready) - deadline_to_propagate = _determine_deadline(deadline) + augmented_metadata = _compression.augment_metadata( + metadata, compression) call = self._channel.segregated_call( cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, - None, deadline_to_propagate, metadata, None + None, _determine_deadline(deadline), augmented_metadata, None if credentials is None else credentials._credentials, _stream_unary_invocation_operationses_and_tags( - metadata, initial_metadata_flags), self._context) + augmented_metadata, initial_metadata_flags), self._context) _consume_request_iterator(request_iterator, state, call, self._request_serializer, None) while True: @@ -680,9 +691,10 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): state, call, = self._blocking(request_iterator, timeout, metadata, - credentials, wait_for_ready) + credentials, wait_for_ready, compression) return _end_unary_response_blocking(state, call, False, None) def with_call(self, @@ -690,9 +702,10 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): state, call, = self._blocking(request_iterator, timeout, metadata, - credentials, wait_for_ready) + credentials, wait_for_ready, compression) return _end_unary_response_blocking(state, call, True, None) def future(self, @@ -700,15 +713,18 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) event_handler = _event_handler(state, self._response_deserializer) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( wait_for_ready) + augmented_metadata = _compression.augment_metadata( + metadata, compression) call = self._managed_call( cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, - None, deadline, metadata, None + None, deadline, augmented_metadata, None if credentials is None else credentials._credentials, _stream_unary_invocation_operationses( metadata, initial_metadata_flags), event_handler, self._context) @@ -734,24 +750,26 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): deadline = _deadline(timeout) state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( wait_for_ready) + augmented_metadata = _compression.augment_metadata( + metadata, compression) operationses = ( ( - cygrpc.SendInitialMetadataOperation(metadata, + cygrpc.SendInitialMetadataOperation(augmented_metadata, initial_metadata_flags), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), ), (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), ) event_handler = _event_handler(state, self._response_deserializer) - deadline_to_propagate = _determine_deadline(deadline) call = self._managed_call( cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, - None, deadline_to_propagate, metadata, None + None, _determine_deadline(deadline), augmented_metadata, None if credentials is None else credentials._credentials, operationses, event_handler, self._context) _consume_request_iterator(request_iterator, state, call, @@ -982,28 +1000,30 @@ def _unsubscribe(state, callback): break -def _options(options): - return list(options) + [ - ( - cygrpc.ChannelArgKey.primary_user_agent_string, - _USER_AGENT, - ), - ] +def _augment_options(base_options, compression): + compression_option = _compression.create_channel_option(compression) + return tuple(base_options) + compression_option + (( + cygrpc.ChannelArgKey.primary_user_agent_string, + _USER_AGENT, + ),) class Channel(grpc.Channel): """A cygrpc.Channel-backed implementation of grpc.Channel.""" - def __init__(self, target, options, credentials): + def __init__(self, target, options, credentials, compression): """Constructor. Args: target: The target to which to connect. options: Configuration options for the channel. credentials: A cygrpc.ChannelCredentials or None. + compression: An optional value indicating the compression method to be + used over the lifetime of the channel. """ self._channel = cygrpc.Channel( - _common.encode(target), _options(options), credentials) + _common.encode(target), _augment_options(options, compression), + credentials) self._call_state = _ChannelCallState(self._channel) self._connectivity_state = _ChannelConnectivityState(self._channel) cygrpc.fork_register_channel(self) diff --git a/src/python/grpcio/grpc/_compression.py b/src/python/grpcio/grpc/_compression.py new file mode 100644 index 00000000000..45339c3afe2 --- /dev/null +++ b/src/python/grpcio/grpc/_compression.py @@ -0,0 +1,55 @@ +# Copyright 2019 The gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from grpc._cython import cygrpc + +NoCompression = cygrpc.CompressionAlgorithm.none +Deflate = cygrpc.CompressionAlgorithm.deflate +Gzip = cygrpc.CompressionAlgorithm.gzip + +_METADATA_STRING_MAPPING = { + NoCompression: 'identity', + Deflate: 'deflate', + Gzip: 'gzip', +} + + +def _compression_algorithm_to_metadata_value(compression): + return _METADATA_STRING_MAPPING[compression] + + +def compression_algorithm_to_metadata(compression): + return (cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, + _compression_algorithm_to_metadata_value(compression)) + + +def create_channel_option(compression): + return ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM, + int(compression)),) if compression else () + + +def augment_metadata(metadata, compression): + if not metadata and not compression: + return None + base_metadata = tuple(metadata) if metadata else () + compression_metadata = ( + compression_algorithm_to_metadata(compression),) if compression else () + return base_metadata + compression_metadata + + +__all__ = ( + "NoCompression", + "Deflate", + "Gzip", +) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi index 0a35002a9d4..057d0776983 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi @@ -140,7 +140,8 @@ cdef extern from "grpc/grpc.h": const char *GRPC_ARG_SECONDARY_USER_AGENT_STRING const char *GRPC_SSL_TARGET_NAME_OVERRIDE_ARG const char *GRPC_SSL_SESSION_CACHE_ARG - const char *GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM + const char *_GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM \ + "GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM" const char *GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL const char *GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET @@ -618,3 +619,8 @@ cdef extern from "grpc/compression.h": int grpc_compression_options_is_algorithm_enabled( const grpc_compression_options *opts, grpc_compression_algorithm algorithm) nogil + +cdef extern from "grpc/impl/codegen/compression_types.h": + + const char *_GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY \ + "GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY" diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi index 02c904b43fc..308d677695f 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi @@ -108,6 +108,11 @@ class OperationType: receive_status_on_client = GRPC_OP_RECV_STATUS_ON_CLIENT receive_close_on_server = GRPC_OP_RECV_CLOSE_ON_SERVER +GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM= ( + _GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM) + +GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY = ( + _GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY) class CompressionAlgorithm: none = GRPC_COMPRESS_NONE diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py index 6c4e396ac23..4ec2e6bb733 100644 --- a/src/python/grpcio/grpc/_interceptor.py +++ b/src/python/grpcio/grpc/_interceptor.py @@ -44,9 +44,9 @@ def service_pipeline(interceptors): class _ClientCallDetails( - collections.namedtuple( - '_ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')), + collections.namedtuple('_ClientCallDetails', + ('method', 'timeout', 'metadata', 'credentials', + 'wait_for_ready', 'compression')), grpc.ClientCallDetails): pass @@ -77,7 +77,12 @@ def _unwrap_client_call_details(call_details, default_details): except AttributeError: wait_for_ready = default_details.wait_for_ready - return method, timeout, metadata, credentials, wait_for_ready + try: + compression = call_details.compression + except AttributeError: + compression = default_details.compression + + return method, timeout, metadata, credentials, wait_for_ready, compression class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors @@ -206,13 +211,15 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): response, ignored_call = self._with_call( request, timeout=timeout, metadata=metadata, credentials=credentials, - wait_for_ready=wait_for_ready) + wait_for_ready=wait_for_ready, + compression=compression) return response def _with_call(self, @@ -220,20 +227,25 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): - client_call_details = _ClientCallDetails( - self._method, timeout, metadata, credentials, wait_for_ready) + wait_for_ready=None, + compression=None): + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) def continuation(new_details, request): - new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( - _unwrap_client_call_details(new_details, client_call_details)) + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) try: response, call = self._thunk(new_method).with_call( request, timeout=new_timeout, metadata=new_metadata, credentials=new_credentials, - wait_for_ready=new_wait_for_ready) + wait_for_ready=new_wait_for_ready, + compression=new_compression) return _UnaryOutcome(response, call) except grpc.RpcError as rpc_error: return rpc_error @@ -249,32 +261,39 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): return self._with_call( request, timeout=timeout, metadata=metadata, credentials=credentials, - wait_for_ready=wait_for_ready) + wait_for_ready=wait_for_ready, + compression=compression) def future(self, request, timeout=None, metadata=None, credentials=None, - wait_for_ready=None): - client_call_details = _ClientCallDetails( - self._method, timeout, metadata, credentials, wait_for_ready) + wait_for_ready=None, + compression=None): + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) def continuation(new_details, request): - new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( - _unwrap_client_call_details(new_details, client_call_details)) + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) return self._thunk(new_method).future( request, timeout=new_timeout, metadata=new_metadata, credentials=new_credentials, - wait_for_ready=new_wait_for_ready) + wait_for_ready=new_wait_for_ready, + compression=new_compression) try: return self._interceptor.intercept_unary_unary( @@ -295,19 +314,24 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): - client_call_details = _ClientCallDetails( - self._method, timeout, metadata, credentials, wait_for_ready) + wait_for_ready=None, + compression=None): + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) def continuation(new_details, request): - new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( - _unwrap_client_call_details(new_details, client_call_details)) + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) return self._thunk(new_method)( request, timeout=new_timeout, metadata=new_metadata, credentials=new_credentials, - wait_for_ready=new_wait_for_ready) + wait_for_ready=new_wait_for_ready, + compression=new_compression) try: return self._interceptor.intercept_unary_stream( @@ -328,13 +352,15 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): response, ignored_call = self._with_call( request_iterator, timeout=timeout, metadata=metadata, credentials=credentials, - wait_for_ready=wait_for_ready) + wait_for_ready=wait_for_ready, + compression=compression) return response def _with_call(self, @@ -342,20 +368,25 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): - client_call_details = _ClientCallDetails( - self._method, timeout, metadata, credentials, wait_for_ready) + wait_for_ready=None, + compression=None): + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) def continuation(new_details, request_iterator): - new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( - _unwrap_client_call_details(new_details, client_call_details)) + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) try: response, call = self._thunk(new_method).with_call( request_iterator, timeout=new_timeout, metadata=new_metadata, credentials=new_credentials, - wait_for_ready=new_wait_for_ready) + wait_for_ready=new_wait_for_ready, + compression=new_compression) return _UnaryOutcome(response, call) except grpc.RpcError as rpc_error: return rpc_error @@ -371,32 +402,39 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): + wait_for_ready=None, + compression=None): return self._with_call( request_iterator, timeout=timeout, metadata=metadata, credentials=credentials, - wait_for_ready=wait_for_ready) + wait_for_ready=wait_for_ready, + compression=compression) def future(self, request_iterator, timeout=None, metadata=None, credentials=None, - wait_for_ready=None): - client_call_details = _ClientCallDetails( - self._method, timeout, metadata, credentials, wait_for_ready) + wait_for_ready=None, + compression=None): + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) def continuation(new_details, request_iterator): - new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( - _unwrap_client_call_details(new_details, client_call_details)) + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) return self._thunk(new_method).future( request_iterator, timeout=new_timeout, metadata=new_metadata, credentials=new_credentials, - wait_for_ready=new_wait_for_ready) + wait_for_ready=new_wait_for_ready, + compression=new_compression) try: return self._interceptor.intercept_stream_unary( @@ -417,19 +455,24 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): timeout=None, metadata=None, credentials=None, - wait_for_ready=None): - client_call_details = _ClientCallDetails( - self._method, timeout, metadata, credentials, wait_for_ready) + wait_for_ready=None, + compression=None): + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials, + wait_for_ready, compression) def continuation(new_details, request_iterator): - new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( - _unwrap_client_call_details(new_details, client_call_details)) + (new_method, new_timeout, new_metadata, new_credentials, + new_wait_for_ready, + new_compression) = (_unwrap_client_call_details( + new_details, client_call_details)) return self._thunk(new_method)( request_iterator, timeout=new_timeout, metadata=new_metadata, credentials=new_credentials, - wait_for_ready=new_wait_for_ready) + wait_for_ready=new_wait_for_ready, + compression=new_compression) try: return self._interceptor.intercept_stream_stream( diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 90136aef3c2..370c81100af 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -24,6 +24,7 @@ import six import grpc from grpc import _common +from grpc import _compression from grpc import _interceptor from grpc._cython import cygrpc @@ -94,6 +95,7 @@ class _RPCState(object): self.request = None self.client = _OPEN self.initial_metadata_allowed = True + self.compression_algorithm = None self.disable_next_compression = False self.trailing_metadata = None self.code = None @@ -129,13 +131,33 @@ def _send_status_from_server(state, token): return send_status_from_server +def _get_initial_metadata(state, metadata): + with state.condition: + if state.compression_algorithm: + compression_metadata = ( + _compression.compression_algorithm_to_metadata( + state.compression_algorithm),) + if metadata is None: + return compression_metadata + else: + return compression_metadata + tuple(metadata) + else: + return metadata + + +def _get_initial_metadata_operation(state, metadata): + operation = cygrpc.SendInitialMetadataOperation( + _get_initial_metadata(state, metadata), _EMPTY_FLAGS) + return operation + + def _abort(state, call, code, details): if state.client is not _CANCELLED: effective_code = _abortion_code(state, code) effective_details = details if state.details is None else state.details if state.initial_metadata_allowed: operations = ( - cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS), + _get_initial_metadata_operation(state, None), cygrpc.SendStatusFromServerOperation( state.trailing_metadata, effective_code, effective_details, _EMPTY_FLAGS), @@ -259,14 +281,18 @@ class _Context(grpc.ServicerContext): cygrpc.auth_context(self._rpc_event.call)) } + def set_compression(self, compression): + with self._state.condition: + self._state.compression_algorithm = compression + def send_initial_metadata(self, initial_metadata): with self._state.condition: if self._state.client is _CANCELLED: _raise_rpc_error(self._state) else: if self._state.initial_metadata_allowed: - operation = cygrpc.SendInitialMetadataOperation( - initial_metadata, _EMPTY_FLAGS) + operation = _get_initial_metadata_operation( + self._state, initial_metadata) self._rpc_event.call.start_server_batch( (operation,), _send_initial_metadata(self._state)) self._state.initial_metadata_allowed = False @@ -400,10 +426,13 @@ def _call_behavior(rpc_event, with _create_servicer_context(rpc_event, state, request_deserializer) as context: try: + response_or_iterator = None if send_response_callback is not None: - return behavior(argument, context, send_response_callback), True + response_or_iterator = behavior(argument, context, + send_response_callback) else: - return behavior(argument, context), True + response_or_iterator = behavior(argument, context) + return response_or_iterator, True except Exception as exception: # pylint: disable=broad-except with state.condition: if state.aborted: @@ -447,6 +476,18 @@ def _serialize_response(rpc_event, state, response, response_serializer): return serialized_response +def _get_send_message_op_flags_from_state(state): + if state.disable_next_compression: + return cygrpc.WriteFlag.no_compress + else: + return _EMPTY_FLAGS + + +def _reset_per_message_state(state): + with state.condition: + state.disable_next_compression = False + + def _send_response(rpc_event, state, serialized_response): with state.condition: if not _is_rpc_state_active(state): @@ -454,19 +495,22 @@ def _send_response(rpc_event, state, serialized_response): else: if state.initial_metadata_allowed: operations = ( - cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS), - cygrpc.SendMessageOperation(serialized_response, - _EMPTY_FLAGS), + _get_initial_metadata_operation(state, None), + cygrpc.SendMessageOperation( + serialized_response, + _get_send_message_op_flags_from_state(state)), ) state.initial_metadata_allowed = False token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN else: operations = (cygrpc.SendMessageOperation( - serialized_response, _EMPTY_FLAGS),) + serialized_response, + _get_send_message_op_flags_from_state(state)),) token = _SEND_MESSAGE_TOKEN rpc_event.call.start_server_batch(operations, _send_message(state, token)) state.due.add(token) + _reset_per_message_state(state) while True: state.condition.wait() if token not in state.due: @@ -483,16 +527,17 @@ def _status(rpc_event, state, serialized_response): state.trailing_metadata, code, details, _EMPTY_FLAGS), ] if state.initial_metadata_allowed: - operations.append( - cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS)) + operations.append(_get_initial_metadata_operation(state, None)) if serialized_response is not None: operations.append( - cygrpc.SendMessageOperation(serialized_response, - _EMPTY_FLAGS)) + cygrpc.SendMessageOperation( + serialized_response, + _get_send_message_op_flags_from_state(state))) rpc_event.call.start_server_batch( operations, _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN)) state.statused = True + _reset_per_message_state(state) state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN) @@ -639,13 +684,13 @@ def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline): def _reject_rpc(rpc_event, status, details): + rpc_state = _RPCState() operations = ( - cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS), + _get_initial_metadata_operation(rpc_state, None), cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), cygrpc.SendStatusFromServerOperation(None, status, details, _EMPTY_FLAGS), ) - rpc_state = _RPCState() rpc_event.call.start_server_batch(operations, lambda ignored_event: (rpc_state, (),)) return rpc_state @@ -883,13 +928,18 @@ def _validate_generic_rpc_handlers(generic_rpc_handlers): 'not have "service" method!'.format(generic_rpc_handler)) +def _augment_options(base_options, compression): + compression_option = _compression.create_channel_option(compression) + return tuple(base_options) + compression_option + + class _Server(grpc.Server): # pylint: disable=too-many-arguments def __init__(self, thread_pool, generic_handlers, interceptors, options, - maximum_concurrent_rpcs): + maximum_concurrent_rpcs, compression): completion_queue = cygrpc.CompletionQueue() - server = cygrpc.Server(options) + server = cygrpc.Server(_augment_options(options, compression)) server.register_completion_queue(completion_queue) self._state = _ServerState(completion_queue, server, generic_handlers, _interceptor.service_pipeline(interceptors), @@ -920,7 +970,7 @@ class _Server(grpc.Server): def create_server(thread_pool, generic_rpc_handlers, interceptors, options, - maximum_concurrent_rpcs): + maximum_concurrent_rpcs, compression): _validate_generic_rpc_handlers(generic_rpc_handlers) return _Server(thread_pool, generic_rpc_handlers, interceptors, options, - maximum_concurrent_rpcs) + maximum_concurrent_rpcs, compression) diff --git a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py index 5b1dfeacdf5..63a1b1aec95 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py @@ -56,6 +56,9 @@ class ServicerContext(grpc.ServicerContext): def auth_context(self): raise NotImplementedError() + def set_compression(self): + raise NotImplementedError() + def send_initial_metadata(self, initial_metadata): initial_metadata_sent = self._rpc.send_initial_metadata( _common.fuss_with_metadata(initial_metadata)) @@ -63,6 +66,9 @@ class ServicerContext(grpc.ServicerContext): raise ValueError( 'ServicerContext.send_initial_metadata called too late!') + def disable_next_message_compression(self): + raise NotImplementedError() + def set_trailing_metadata(self, trailing_metadata): self._rpc.set_trailing_metadata( _common.fuss_with_metadata(trailing_metadata)) diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py index 7a441feb84e..e9b6333c891 100644 --- a/src/python/grpcio_tests/commands.py +++ b/src/python/grpcio_tests/commands.py @@ -117,6 +117,7 @@ class TestGevent(setuptools.Command): # eventually succeed, but need to dig into performance issues. 'unit._cython._no_messages_server_completion_queue_per_call_test.Test.test_rpcs', 'unit._cython._no_messages_single_server_completion_queue_test.Test.test_rpcs', + 'unit._compression_test', # TODO(https://github.com/grpc/grpc/issues/16890) enable this test 'unit._cython._channel_test.ChannelTest.test_multiple_channels_lonely_connectivity', # I have no idea why this doesn't work in gevent, but it shouldn't even be diff --git a/src/python/grpcio_tests/tests/unit/BUILD.bazel b/src/python/grpcio_tests/tests/unit/BUILD.bazel index 54b3c9b6f6a..04f91e63a18 100644 --- a/src/python/grpcio_tests/tests/unit/BUILD.bazel +++ b/src/python/grpcio_tests/tests/unit/BUILD.bazel @@ -33,6 +33,11 @@ GRPCIO_TESTS_UNIT = [ "_session_cache_test.py", ] +py_library( + name = "_tcp_proxy", + srcs = ["_tcp_proxy.py"], +) + py_library( name = "resources", srcs = ["resources.py"], @@ -80,6 +85,7 @@ py_library( ":_exit_scenarios", ":_server_shutdown_scenarios", ":_from_grpc_import_star", + ":_tcp_proxy", "//src/python/grpcio_tests/tests/unit/framework/common", "//src/python/grpcio_tests/tests/testing", requirement('six'), diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py index 0dc6a8718c3..127dab336bf 100644 --- a/src/python/grpcio_tests/tests/unit/_api_test.py +++ b/src/python/grpcio_tests/tests/unit/_api_test.py @@ -31,6 +31,7 @@ class AllTest(unittest.TestCase): 'FutureCancelledError', 'Future', 'ChannelConnectivity', + 'Compression', 'StatusCode', 'Status', 'RpcError', diff --git a/src/python/grpcio_tests/tests/unit/_compression_test.py b/src/python/grpcio_tests/tests/unit/_compression_test.py index 87884a19dc0..66f9f0ae40f 100644 --- a/src/python/grpcio_tests/tests/unit/_compression_test.py +++ b/src/python/grpcio_tests/tests/unit/_compression_test.py @@ -13,37 +13,130 @@ # limitations under the License. """Tests server and client side compression.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import unittest +import contextlib +from concurrent import futures +import functools +import itertools import logging +import os + import grpc from grpc import _grpcio_metadata from tests.unit import test_common from tests.unit.framework.common import test_constants +from tests.unit import _tcp_proxy _UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' _STREAM_STREAM = '/test/StreamStream' +# Cut down on test time. +_STREAM_LENGTH = test_constants.STREAM_LENGTH // 8 + +_HOST = 'localhost' + +_REQUEST = b'\x00' * 100 +_COMPRESSION_RATIO_THRESHOLD = 0.1 +_COMPRESSION_METHODS = ( + None, + # Disabled for test tractability. + # grpc.Compression.NoCompression, + grpc.Compression.Deflate, + grpc.Compression.Gzip, +) +_COMPRESSION_NAMES = { + None: 'Uncompressed', + grpc.Compression.NoCompression: 'NoCompression', + grpc.Compression.Deflate: 'DeflateCompression', + grpc.Compression.Gzip: 'GzipCompression', +} + +_TEST_OPTIONS = { + 'client_streaming': (True, False), + 'server_streaming': (True, False), + 'channel_compression': _COMPRESSION_METHODS, + 'multicallable_compression': _COMPRESSION_METHODS, + 'server_compression': _COMPRESSION_METHODS, + 'server_call_compression': _COMPRESSION_METHODS, +} + + +def _make_handle_unary_unary(pre_response_callback): + + def _handle_unary(request, servicer_context): + if pre_response_callback: + pre_response_callback(request, servicer_context) + return request + + return _handle_unary + + +def _make_handle_unary_stream(pre_response_callback): + + def _handle_unary_stream(request, servicer_context): + if pre_response_callback: + pre_response_callback(request, servicer_context) + for _ in range(_STREAM_LENGTH): + yield request + + return _handle_unary_stream + + +def _make_handle_stream_unary(pre_response_callback): + + def _handle_stream_unary(request_iterator, servicer_context): + if pre_response_callback: + pre_response_callback(request_iterator, servicer_context) + response = None + for request in request_iterator: + if not response: + response = request + return response -def handle_unary(request, servicer_context): - servicer_context.send_initial_metadata([('grpc-internal-encoding-request', - 'gzip')]) - return request + return _handle_stream_unary -def handle_stream(request_iterator, servicer_context): - # TODO(issue:#6891) We should be able to remove this loop, - # and replace with return; yield - servicer_context.send_initial_metadata([('grpc-internal-encoding-request', - 'gzip')]) - for request in request_iterator: - yield request +def _make_handle_stream_stream(pre_response_callback): + + def _handle_stream(request_iterator, servicer_context): + # TODO(issue:#6891) We should be able to remove this loop, + # and replace with return; yield + for request in request_iterator: + if pre_response_callback: + pre_response_callback(request, servicer_context) + yield request + + return _handle_stream + + +def set_call_compression(compression_method, request_or_iterator, + servicer_context): + del request_or_iterator + servicer_context.set_compression(compression_method) + + +def disable_next_compression(request, servicer_context): + del request + servicer_context.disable_next_message_compression() + + +def disable_first_compression(request, servicer_context): + if int(request.decode('ascii')) == 0: + servicer_context.disable_next_message_compression() class _MethodHandler(grpc.RpcMethodHandler): - def __init__(self, request_streaming, response_streaming): + def __init__(self, request_streaming, response_streaming, + pre_response_callback): self.request_streaming = request_streaming self.response_streaming = response_streaming self.request_deserializer = None @@ -52,75 +145,239 @@ class _MethodHandler(grpc.RpcMethodHandler): self.unary_stream = None self.stream_unary = None self.stream_stream = None + if self.request_streaming and self.response_streaming: - self.stream_stream = handle_stream + self.stream_stream = _make_handle_stream_stream( + pre_response_callback) elif not self.request_streaming and not self.response_streaming: - self.unary_unary = handle_unary + self.unary_unary = _make_handle_unary_unary(pre_response_callback) + elif not self.request_streaming and self.response_streaming: + self.unary_stream = _make_handle_unary_stream(pre_response_callback) + else: + self.stream_unary = _make_handle_stream_unary(pre_response_callback) class _GenericHandler(grpc.GenericRpcHandler): + def __init__(self, pre_response_callback): + self._pre_response_callback = pre_response_callback + def service(self, handler_call_details): if handler_call_details.method == _UNARY_UNARY: - return _MethodHandler(False, False) + return _MethodHandler(False, False, self._pre_response_callback) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(False, True, self._pre_response_callback) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(True, False, self._pre_response_callback) elif handler_call_details.method == _STREAM_STREAM: - return _MethodHandler(True, True) + return _MethodHandler(True, True, self._pre_response_callback) else: return None +@contextlib.contextmanager +def _instrumented_client_server_pair(channel_kwargs, server_kwargs, + server_handler): + server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs) + server.add_generic_rpc_handlers((server_handler,)) + server_port = server.add_insecure_port('{}:0'.format(_HOST)) + server.start() + with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy: + proxy_port = proxy.get_port() + with grpc.insecure_channel('{}:{}'.format(_HOST, proxy_port), + **channel_kwargs) as client_channel: + try: + yield client_channel, proxy, server + finally: + server.stop(None) + + +def _get_byte_counts(channel_kwargs, multicallable_kwargs, client_function, + server_kwargs, server_handler, message): + with _instrumented_client_server_pair(channel_kwargs, server_kwargs, + server_handler) as pipeline: + client_channel, proxy, server = pipeline + client_function(client_channel, multicallable_kwargs, message) + return proxy.get_byte_count() + + +def _get_compression_ratios(client_function, first_channel_kwargs, + first_multicallable_kwargs, first_server_kwargs, + first_server_handler, second_channel_kwargs, + second_multicallable_kwargs, second_server_kwargs, + second_server_handler, message): + try: + # This test requires the byte length of each connection to be deterministic. As + # it turns out, flow control puts bytes on the wire in a nondeterministic + # manner. We disable it here in order to measure compression ratios + # deterministically. + os.environ['GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL'] = 'true' + first_bytes_sent, first_bytes_received = _get_byte_counts( + first_channel_kwargs, first_multicallable_kwargs, client_function, + first_server_kwargs, first_server_handler, message) + second_bytes_sent, second_bytes_received = _get_byte_counts( + second_channel_kwargs, second_multicallable_kwargs, client_function, + second_server_kwargs, second_server_handler, message) + return (( + second_bytes_sent - first_bytes_sent) / float(first_bytes_sent), + (second_bytes_received - first_bytes_received) / + float(first_bytes_received)) + finally: + del os.environ['GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL'] + + +def _unary_unary_client(channel, multicallable_kwargs, message): + multi_callable = channel.unary_unary(_UNARY_UNARY) + response = multi_callable(message, **multicallable_kwargs) + if response != message: + raise RuntimeError("Request '{}' != Response '{}'".format( + message, response)) + + +def _unary_stream_client(channel, multicallable_kwargs, message): + multi_callable = channel.unary_stream(_UNARY_STREAM) + response_iterator = multi_callable(message, **multicallable_kwargs) + for response in response_iterator: + if response != message: + raise RuntimeError("Request '{}' != Response '{}'".format( + message, response)) + + +def _stream_unary_client(channel, multicallable_kwargs, message): + multi_callable = channel.stream_unary(_STREAM_UNARY) + requests = (_REQUEST for _ in range(_STREAM_LENGTH)) + response = multi_callable(requests, **multicallable_kwargs) + if response != message: + raise RuntimeError("Request '{}' != Response '{}'".format( + message, response)) + + +def _stream_stream_client(channel, multicallable_kwargs, message): + multi_callable = channel.stream_stream(_STREAM_STREAM) + request_prefix = str(0).encode('ascii') * 100 + requests = ( + request_prefix + str(i).encode('ascii') for i in range(_STREAM_LENGTH)) + response_iterator = multi_callable(requests, **multicallable_kwargs) + for i, response in enumerate(response_iterator): + if int(response.decode('ascii')) != i: + raise RuntimeError("Request '{}' != Response '{}'".format( + i, response)) + + class CompressionTest(unittest.TestCase): - def setUp(self): - self._server = test_common.test_server() - self._server.add_generic_rpc_handlers((_GenericHandler(),)) - self._port = self._server.add_insecure_port('[::]:0') - self._server.start() - - def tearDown(self): - self._server.stop(None) - - def testUnary(self): - request = b'\x00' * 100 - - # Client -> server compressed through default client channel compression - # settings. Server -> client compressed via server-side metadata setting. - # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer - # literal with proper use of the public API. - compressed_channel = grpc.insecure_channel( - 'localhost:%d' % self._port, - options=[('grpc.default_compression_algorithm', 1)]) - multi_callable = compressed_channel.unary_unary(_UNARY_UNARY) - response = multi_callable(request) - self.assertEqual(request, response) - - # Client -> server compressed through client metadata setting. Server -> - # client compressed via server-side metadata setting. - # TODO(https://github.com/grpc/grpc/issues/4078): replace the "0" integer - # literal with proper use of the public API. - uncompressed_channel = grpc.insecure_channel( - 'localhost:%d' % self._port, - options=[('grpc.default_compression_algorithm', 0)]) - multi_callable = compressed_channel.unary_unary(_UNARY_UNARY) - response = multi_callable( - request, metadata=[('grpc-internal-encoding-request', 'gzip')]) - self.assertEqual(request, response) - compressed_channel.close() - - def testStreaming(self): - request = b'\x00' * 100 - - # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer - # literal with proper use of the public API. - compressed_channel = grpc.insecure_channel( - 'localhost:%d' % self._port, - options=[('grpc.default_compression_algorithm', 1)]) - multi_callable = compressed_channel.stream_stream(_STREAM_STREAM) - call = multi_callable(iter([request] * test_constants.STREAM_LENGTH)) - for response in call: - self.assertEqual(request, response) - compressed_channel.close() + def assertCompressed(self, compression_ratio): + self.assertLess( + compression_ratio, + -1.0 * _COMPRESSION_RATIO_THRESHOLD, + msg='Actual compression ratio: {}'.format(compression_ratio)) + + def assertNotCompressed(self, compression_ratio): + self.assertGreaterEqual( + compression_ratio, + -1.0 * _COMPRESSION_RATIO_THRESHOLD, + msg='Actual compession ratio: {}'.format(compression_ratio)) + + def assertConfigurationCompressed( + self, client_streaming, server_streaming, channel_compression, + multicallable_compression, server_compression, + server_call_compression): + client_side_compressed = channel_compression or multicallable_compression + server_side_compressed = server_compression or server_call_compression + channel_kwargs = { + 'compression': channel_compression, + } if channel_compression else {} + multicallable_kwargs = { + 'compression': multicallable_compression, + } if multicallable_compression else {} + + client_function = None + if not client_streaming and not server_streaming: + client_function = _unary_unary_client + elif not client_streaming and server_streaming: + client_function = _unary_stream_client + elif client_streaming and not server_streaming: + client_function = _stream_unary_client + else: + client_function = _stream_stream_client + + server_kwargs = { + 'compression': server_compression, + } if server_compression else {} + server_handler = _GenericHandler( + functools.partial(set_call_compression, grpc.Compression.Gzip) + ) if server_call_compression else _GenericHandler(None) + sent_ratio, received_ratio = _get_compression_ratios( + client_function, {}, {}, {}, _GenericHandler(None), channel_kwargs, + multicallable_kwargs, server_kwargs, server_handler, _REQUEST) + + if client_side_compressed: + self.assertCompressed(sent_ratio) + else: + self.assertNotCompressed(sent_ratio) + + if server_side_compressed: + self.assertCompressed(received_ratio) + else: + self.assertNotCompressed(received_ratio) + + def testDisableNextCompressionStreaming(self): + server_kwargs = { + 'compression': grpc.Compression.Deflate, + } + _, received_ratio = _get_compression_ratios( + _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {}, + server_kwargs, _GenericHandler(disable_next_compression), _REQUEST) + self.assertNotCompressed(received_ratio) + + def testDisableNextCompressionStreamingResets(self): + server_kwargs = { + 'compression': grpc.Compression.Deflate, + } + _, received_ratio = _get_compression_ratios( + _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {}, + server_kwargs, _GenericHandler(disable_first_compression), _REQUEST) + self.assertCompressed(received_ratio) + + +def _get_compression_str(name, value): + return '{}{}'.format(name, _COMPRESSION_NAMES[value]) + + +def _get_compression_test_name(client_streaming, server_streaming, + channel_compression, multicallable_compression, + server_compression, server_call_compression): + client_arity = 'Stream' if client_streaming else 'Unary' + server_arity = 'Stream' if server_streaming else 'Unary' + arity = '{}{}'.format(client_arity, server_arity) + channel_compression_str = _get_compression_str('Channel', + channel_compression) + multicallable_compression_str = _get_compression_str( + 'Multicallable', multicallable_compression) + server_compression_str = _get_compression_str('Server', server_compression) + server_call_compression_str = _get_compression_str('ServerCall', + server_call_compression) + return 'test{}{}{}{}{}'.format( + arity, channel_compression_str, multicallable_compression_str, + server_compression_str, server_call_compression_str) + + +def _test_options(): + for test_parameters in itertools.product(*_TEST_OPTIONS.values()): + yield dict(zip(_TEST_OPTIONS.keys(), test_parameters)) + + +for options in _test_options(): + + def test_compression(**kwargs): + + def _test_compression(self): + self.assertConfigurationCompressed(**kwargs) + + return _test_compression + setattr(CompressionTest, _get_compression_test_name(**options), + test_compression(**options)) if __name__ == '__main__': logging.basicConfig() diff --git a/src/python/grpcio_tests/tests/unit/_tcp_proxy.py b/src/python/grpcio_tests/tests/unit/_tcp_proxy.py new file mode 100644 index 00000000000..5ad0bf8f028 --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_tcp_proxy.py @@ -0,0 +1,164 @@ +# Copyright 2019 the gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Proxies a TCP connection between a single client-server pair. + +This proxy is not suitable for production, but should work well for cases in +which a test needs to spy on the bytes put on the wire between a server and +a client. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +import select +import socket +import threading + +_TCP_PROXY_BUFFER_SIZE = 1024 +_TCP_PROXY_TIMEOUT = datetime.timedelta(milliseconds=500) + + +def _create_socket_ipv6(bind_address): + listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + listen_socket.bind((bind_address, 0, 0, 0)) + return listen_socket + + +def _create_socket_ipv4(bind_address): + listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_socket.bind((bind_address, 0)) + return listen_socket + + +def _init_listen_socket(bind_address): + listen_socket = None + if socket.has_ipv6: + try: + listen_socket = _create_socket_ipv6(bind_address) + except socket.error: + listen_socket = _create_socket_ipv4(bind_address) + else: + listen_socket = _create_socket_ipv4(bind_address) + listen_socket.listen(1) + return listen_socket, listen_socket.getsockname()[1] + + +def _init_proxy_socket(gateway_address, gateway_port): + proxy_socket = socket.create_connection((gateway_address, gateway_port)) + return proxy_socket + + +class TcpProxy(object): + """Proxies a TCP connection between one client and one server.""" + + def __init__(self, bind_address, gateway_address, gateway_port): + self._bind_address = bind_address + self._gateway_address = gateway_address + self._gateway_port = gateway_port + + self._byte_count_lock = threading.RLock() + self._sent_byte_count = 0 + self._received_byte_count = 0 + + self._stop_event = threading.Event() + + self._port = None + self._listen_socket = None + self._proxy_socket = None + + # The following three attributes are owned by the serving thread. + self._northbound_data = b"" + self._southbound_data = b"" + self._client_sockets = [] + + self._thread = threading.Thread(target=self._run_proxy) + + def start(self): + self._listen_socket, self._port = _init_listen_socket( + self._bind_address) + self._proxy_socket = _init_proxy_socket(self._gateway_address, + self._gateway_port) + self._thread.start() + + def get_port(self): + return self._port + + def _handle_reads(self, sockets_to_read): + for socket_to_read in sockets_to_read: + if socket_to_read is self._listen_socket: + client_socket, client_address = socket_to_read.accept() + self._client_sockets.append(client_socket) + elif socket_to_read is self._proxy_socket: + data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE) + with self._byte_count_lock: + self._received_byte_count += len(data) + self._northbound_data += data + elif socket_to_read in self._client_sockets: + data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE) + if data: + with self._byte_count_lock: + self._sent_byte_count += len(data) + self._southbound_data += data + else: + self._client_sockets.remove(socket_to_read) + else: + raise RuntimeError('Unidentified socket appeared in read set.') + + def _handle_writes(self, sockets_to_write): + for socket_to_write in sockets_to_write: + if socket_to_write is self._proxy_socket: + if self._southbound_data: + self._proxy_socket.sendall(self._southbound_data) + self._southbound_data = b"" + elif socket_to_write in self._client_sockets: + if self._northbound_data: + socket_to_write.sendall(self._northbound_data) + self._northbound_data = b"" + + def _run_proxy(self): + while not self._stop_event.is_set(): + expected_reads = (self._listen_socket, self._proxy_socket) + tuple( + self._client_sockets) + expected_writes = expected_reads + sockets_to_read, sockets_to_write, _ = select.select( + expected_reads, expected_writes, (), + _TCP_PROXY_TIMEOUT.total_seconds()) + self._handle_reads(sockets_to_read) + self._handle_writes(sockets_to_write) + for client_socket in self._client_sockets: + client_socket.close() + + def stop(self): + self._stop_event.set() + self._thread.join() + self._listen_socket.close() + self._proxy_socket.close() + + def get_byte_count(self): + with self._byte_count_lock: + return self._sent_byte_count, self._received_byte_count + + def reset_byte_count(self): + with self._byte_count_lock: + self._byte_count = 0 + self._received_byte_count = 0 + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop()