Revert "Merge pull request #18727 from grpc/revert_compression"

This reverts commit 8054a731d1, reversing
changes made to c3d3cf8053.
pull/18732/head
Richard Belleville 6 years ago
parent 8054a731d1
commit f900eec41d
  1. 4
      .pylintrc
  2. 6
      doc/python/sphinx/grpc.rst
  3. 8
      src/python/grpcio/grpc/BUILD.bazel
  4. 96
      src/python/grpcio/grpc/__init__.py
  5. 114
      src/python/grpcio/grpc/_channel.py
  6. 55
      src/python/grpcio/grpc/_compression.py
  7. 8
      src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
  8. 5
      src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
  9. 139
      src/python/grpcio/grpc/_interceptor.py
  10. 88
      src/python/grpcio/grpc/_server.py
  11. 6
      src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py
  12. 1
      src/python/grpcio_tests/commands.py
  13. 6
      src/python/grpcio_tests/tests/unit/BUILD.bazel
  14. 1
      src/python/grpcio_tests/tests/unit/_api_test.py
  15. 387
      src/python/grpcio_tests/tests/unit/_compression_test.py
  16. 164
      src/python/grpcio_tests/tests/unit/_tcp_proxy.py

@ -6,6 +6,8 @@ ignore=
src/python/grpcio/grpc/framework/foundation, src/python/grpcio/grpc/framework/foundation,
src/python/grpcio/grpc/framework/interfaces, src/python/grpcio/grpc/framework/interfaces,
extension-pkg-whitelist=grpc._cython.cygrpc
[VARIABLES] [VARIABLES]
# TODO(https://github.com/PyCQA/pylint/issues/1345): How does the inspection # TODO(https://github.com/PyCQA/pylint/issues/1345): How does the inspection
@ -17,7 +19,7 @@ dummy-variables-rgx=^ignored_|^unused_
# NOTE(nathaniel): Not particularly attached to this value; it just seems to # NOTE(nathaniel): Not particularly attached to this value; it just seems to
# be what works for us at the moment (excepting the dead-code-walking Beta # be what works for us at the moment (excepting the dead-code-walking Beta
# API). # API).
max-args=6 max-args=7
[MISCELLANEOUS] [MISCELLANEOUS]

@ -172,3 +172,9 @@ Future Interfaces
.. autoexception:: FutureTimeoutError .. autoexception:: FutureTimeoutError
.. autoexception:: FutureCancelledError .. autoexception:: FutureCancelledError
.. autoclass:: Future .. autoclass:: Future
Compression
^^^^^^^^^^^
.. autoclass:: Compression

@ -12,6 +12,7 @@ py_library(
":channel", ":channel",
":interceptor", ":interceptor",
":server", ":server",
":compression",
"//src/python/grpcio/grpc/_cython:cygrpc", "//src/python/grpcio/grpc/_cython:cygrpc",
"//src/python/grpcio/grpc/experimental", "//src/python/grpcio/grpc/experimental",
"//src/python/grpcio/grpc/framework", "//src/python/grpcio/grpc/framework",
@ -31,12 +32,18 @@ py_library(
srcs = ["_auth.py"], srcs = ["_auth.py"],
) )
py_library(
name = "compression",
srcs = ["_compression.py"],
)
py_library( py_library(
name = "channel", name = "channel",
srcs = ["_channel.py"], srcs = ["_channel.py"],
deps = [ deps = [
":common", ":common",
":grpcio_metadata", ":grpcio_metadata",
":compression",
], ],
) )
@ -68,6 +75,7 @@ py_library(
srcs = ["_server.py"], srcs = ["_server.py"],
deps = [ deps = [
":common", ":common",
":compression",
":interceptor", ":interceptor",
], ],
) )

@ -21,6 +21,7 @@ import sys
import six import six
from grpc._cython import cygrpc as _cygrpc from grpc._cython import cygrpc as _cygrpc
from grpc import _compression
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -413,6 +414,8 @@ class ClientCallDetails(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC. credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional flag t wait_for_ready: This is an EXPERIMENTAL argument. An optional flag t
enable wait for ready mechanism. enable wait for ready mechanism.
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
""" """
@ -669,7 +672,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
"""Synchronously invokes the underlying RPC. """Synchronously invokes the underlying RPC.
Args: Args:
@ -681,6 +685,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC. credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns: Returns:
The response value for the RPC. The response value for the RPC.
@ -698,7 +704,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
"""Synchronously invokes the underlying RPC. """Synchronously invokes the underlying RPC.
Args: Args:
@ -710,6 +717,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC. credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns: Returns:
The response value for the RPC and a Call value for the RPC. The response value for the RPC and a Call value for the RPC.
@ -727,7 +736,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
"""Asynchronously invokes the underlying RPC. """Asynchronously invokes the underlying RPC.
Args: Args:
@ -739,6 +749,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC. credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns: Returns:
An object that is both a Call for the RPC and a Future. An object that is both a Call for the RPC and a Future.
@ -759,7 +771,8 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
"""Invokes the underlying RPC. """Invokes the underlying RPC.
Args: Args:
@ -771,6 +784,8 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC. credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns: Returns:
An object that is both a Call for the RPC and an iterator of An object that is both a Call for the RPC and an iterator of
@ -790,7 +805,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
"""Synchronously invokes the underlying RPC. """Synchronously invokes the underlying RPC.
Args: Args:
@ -803,6 +819,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC. credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns: Returns:
The response value for the RPC. The response value for the RPC.
@ -820,7 +838,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
"""Synchronously invokes the underlying RPC on the client. """Synchronously invokes the underlying RPC on the client.
Args: Args:
@ -833,6 +852,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC. credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns: Returns:
The response value for the RPC and a Call object for the RPC. The response value for the RPC and a Call object for the RPC.
@ -850,7 +871,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
"""Asynchronously invokes the underlying RPC on the client. """Asynchronously invokes the underlying RPC on the client.
Args: Args:
@ -862,6 +884,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC. credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns: Returns:
An object that is both a Call for the RPC and a Future. An object that is both a Call for the RPC and a Future.
@ -882,7 +906,8 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
"""Invokes the underlying RPC on the client. """Invokes the underlying RPC on the client.
Args: Args:
@ -894,6 +919,8 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC. credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns: Returns:
An object that is both a Call for the RPC and an iterator of An object that is both a Call for the RPC and an iterator of
@ -1097,6 +1124,17 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
""" """
raise NotImplementedError() raise NotImplementedError()
def set_compression(self, compression):
"""Set the compression algorithm to be used for the entire call.
This is an EXPERIMENTAL method.
Args:
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip.
"""
raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def send_initial_metadata(self, initial_metadata): def send_initial_metadata(self, initial_metadata):
"""Sends the initial metadata value to the client. """Sends the initial metadata value to the client.
@ -1184,6 +1222,16 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
""" """
raise NotImplementedError() raise NotImplementedError()
def disable_next_message_compression(self):
"""Disables compression for the next response message.
This is an EXPERIMENTAL method.
This method will override any compression configuration set during
server creation or set on the call.
"""
raise NotImplementedError()
##################### Service-Side Handler Interfaces ######################## ##################### Service-Side Handler Interfaces ########################
@ -1682,7 +1730,7 @@ def channel_ready_future(channel):
return _utilities.channel_ready_future(channel) return _utilities.channel_ready_future(channel)
def insecure_channel(target, options=None): def insecure_channel(target, options=None, compression=None):
"""Creates an insecure Channel to a server. """Creates an insecure Channel to a server.
The returned Channel is thread-safe. The returned Channel is thread-safe.
@ -1691,15 +1739,18 @@ def insecure_channel(target, options=None):
target: The server address target: The server address
options: An optional list of key-value pairs (channel args options: An optional list of key-value pairs (channel args
in gRPC Core runtime) to configure the channel. in gRPC Core runtime) to configure the channel.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel. This is an EXPERIMENTAL option.
Returns: Returns:
A Channel. A Channel.
""" """
from grpc import _channel # pylint: disable=cyclic-import from grpc import _channel # pylint: disable=cyclic-import
return _channel.Channel(target, () if options is None else options, None) return _channel.Channel(target, ()
if options is None else options, None, compression)
def secure_channel(target, credentials, options=None): def secure_channel(target, credentials, options=None, compression=None):
"""Creates a secure Channel to a server. """Creates a secure Channel to a server.
The returned Channel is thread-safe. The returned Channel is thread-safe.
@ -1709,13 +1760,15 @@ def secure_channel(target, credentials, options=None):
credentials: A ChannelCredentials instance. credentials: A ChannelCredentials instance.
options: An optional list of key-value pairs (channel args options: An optional list of key-value pairs (channel args
in gRPC Core runtime) to configure the channel. in gRPC Core runtime) to configure the channel.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel. This is an EXPERIMENTAL option.
Returns: Returns:
A Channel. A Channel.
""" """
from grpc import _channel # pylint: disable=cyclic-import from grpc import _channel # pylint: disable=cyclic-import
return _channel.Channel(target, () if options is None else options, return _channel.Channel(target, () if options is None else options,
credentials._credentials) credentials._credentials, compression)
def intercept_channel(channel, *interceptors): def intercept_channel(channel, *interceptors):
@ -1750,7 +1803,8 @@ def server(thread_pool,
handlers=None, handlers=None,
interceptors=None, interceptors=None,
options=None, options=None,
maximum_concurrent_rpcs=None): maximum_concurrent_rpcs=None,
compression=None):
"""Creates a Server with which RPCs can be serviced. """Creates a Server with which RPCs can be serviced.
Args: Args:
@ -1768,6 +1822,9 @@ def server(thread_pool,
maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server
will service before returning RESOURCE_EXHAUSTED status, or None to will service before returning RESOURCE_EXHAUSTED status, or None to
indicate no limit. indicate no limit.
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This compression algorithm will be used for the
lifetime of the server unless overridden. This is an EXPERIMENTAL option.
Returns: Returns:
A Server object. A Server object.
@ -1777,7 +1834,7 @@ def server(thread_pool,
if handlers is None else handlers, () if handlers is None else handlers, ()
if interceptors is None else interceptors, () if interceptors is None else interceptors, ()
if options is None else options, if options is None else options,
maximum_concurrent_rpcs) maximum_concurrent_rpcs, compression)
@contextlib.contextmanager @contextlib.contextmanager
@ -1788,6 +1845,16 @@ def _create_servicer_context(rpc_event, state, request_deserializer):
context._finalize_state() # pylint: disable=protected-access context._finalize_state() # pylint: disable=protected-access
class Compression(enum.IntEnum):
"""Indicates the compression method to be used for an RPC.
This enumeration is part of an EXPERIMENTAL API.
"""
NoCompression = _compression.NoCompression
Deflate = _compression.Deflate
Gzip = _compression.Gzip
################################### __all__ ################################# ################################### __all__ #################################
__all__ = ( __all__ = (
@ -1805,6 +1872,7 @@ __all__ = (
'AuthMetadataContext', 'AuthMetadataContext',
'AuthMetadataPluginCallback', 'AuthMetadataPluginCallback',
'AuthMetadataPlugin', 'AuthMetadataPlugin',
'Compression',
'ClientCallDetails', 'ClientCallDetails',
'ServerCertificateConfiguration', 'ServerCertificateConfiguration',
'ServerCredentials', 'ServerCredentials',

@ -19,6 +19,7 @@ import threading
import time import time
import grpc import grpc
from grpc import _compression
from grpc import _common from grpc import _common
from grpc import _grpcio_metadata from grpc import _grpcio_metadata
from grpc._cython import cygrpc from grpc._cython import cygrpc
@ -512,17 +513,19 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context() self._context = cygrpc.build_census_context()
def _prepare(self, request, timeout, metadata, wait_for_ready): def _prepare(self, request, timeout, metadata, wait_for_ready, compression):
deadline, serialized_request, rendezvous = _start_unary_request( deadline, serialized_request, rendezvous = _start_unary_request(
request, timeout, self._request_serializer) request, timeout, self._request_serializer)
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
wait_for_ready) wait_for_ready)
augmented_metadata = _compression.augment_metadata(
metadata, compression)
if serialized_request is None: if serialized_request is None:
return None, None, None, rendezvous return None, None, None, rendezvous
else: else:
state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None) state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None)
operations = ( operations = (
cygrpc.SendInitialMetadataOperation(metadata, cygrpc.SendInitialMetadataOperation(augmented_metadata,
initial_metadata_flags), initial_metadata_flags),
cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
@ -532,18 +535,17 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
) )
return state, operations, deadline, None return state, operations, deadline, None
def _blocking(self, request, timeout, metadata, credentials, def _blocking(self, request, timeout, metadata, credentials, wait_for_ready,
wait_for_ready): compression):
state, operations, deadline, rendezvous = self._prepare( state, operations, deadline, rendezvous = self._prepare(
request, timeout, metadata, wait_for_ready) request, timeout, metadata, wait_for_ready, compression)
if state is None: if state is None:
raise rendezvous # pylint: disable-msg=raising-bad-type raise rendezvous # pylint: disable-msg=raising-bad-type
else: else:
deadline_to_propagate = _determine_deadline(deadline)
call = self._channel.segregated_call( call = self._channel.segregated_call(
cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
self._method, None, deadline_to_propagate, metadata, None self._method, None, _determine_deadline(deadline), metadata,
if credentials is None else credentials._credentials, (( None if credentials is None else credentials._credentials, ((
operations, operations,
None, None,
),), self._context) ),), self._context)
@ -556,9 +558,10 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
state, call, = self._blocking(request, timeout, metadata, credentials, state, call, = self._blocking(request, timeout, metadata, credentials,
wait_for_ready) wait_for_ready, compression)
return _end_unary_response_blocking(state, call, False, None) return _end_unary_response_blocking(state, call, False, None)
def with_call(self, def with_call(self,
@ -566,9 +569,10 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
state, call, = self._blocking(request, timeout, metadata, credentials, state, call, = self._blocking(request, timeout, metadata, credentials,
wait_for_ready) wait_for_ready, compression)
return _end_unary_response_blocking(state, call, True, None) return _end_unary_response_blocking(state, call, True, None)
def future(self, def future(self,
@ -576,9 +580,10 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
state, operations, deadline, rendezvous = self._prepare( state, operations, deadline, rendezvous = self._prepare(
request, timeout, metadata, wait_for_ready) request, timeout, metadata, wait_for_ready, compression)
if state is None: if state is None:
raise rendezvous # pylint: disable-msg=raising-bad-type raise rendezvous # pylint: disable-msg=raising-bad-type
else: else:
@ -604,12 +609,14 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context() self._context = cygrpc.build_census_context()
def __call__(self, def __call__( # pylint: disable=too-many-locals
request, self,
timeout=None, request,
metadata=None, timeout=None,
credentials=None, metadata=None,
wait_for_ready=None): credentials=None,
wait_for_ready=None,
compression=None):
deadline, serialized_request, rendezvous = _start_unary_request( deadline, serialized_request, rendezvous = _start_unary_request(
request, timeout, self._request_serializer) request, timeout, self._request_serializer)
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
@ -617,10 +624,12 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
if serialized_request is None: if serialized_request is None:
raise rendezvous # pylint: disable-msg=raising-bad-type raise rendezvous # pylint: disable-msg=raising-bad-type
else: else:
augmented_metadata = _compression.augment_metadata(
metadata, compression)
state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None)
operationses = ( operationses = (
( (
cygrpc.SendInitialMetadataOperation(metadata, cygrpc.SendInitialMetadataOperation(augmented_metadata,
initial_metadata_flags), initial_metadata_flags),
cygrpc.SendMessageOperation(serialized_request, cygrpc.SendMessageOperation(serialized_request,
_EMPTY_FLAGS), _EMPTY_FLAGS),
@ -629,12 +638,13 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
), ),
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
) )
event_handler = _event_handler(state, self._response_deserializer)
call = self._managed_call( call = self._managed_call(
cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
self._method, None, _determine_deadline(deadline), metadata, self._method, None, _determine_deadline(deadline), metadata,
None if credentials is None else credentials._credentials, None if credentials is None else
operationses, event_handler, self._context) credentials._credentials, operationses,
_event_handler(state,
self._response_deserializer), self._context)
return _Rendezvous(state, call, self._response_deserializer, return _Rendezvous(state, call, self._response_deserializer,
deadline) deadline)
@ -652,18 +662,19 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
self._context = cygrpc.build_census_context() self._context = cygrpc.build_census_context()
def _blocking(self, request_iterator, timeout, metadata, credentials, def _blocking(self, request_iterator, timeout, metadata, credentials,
wait_for_ready): wait_for_ready, compression):
deadline = _deadline(timeout) deadline = _deadline(timeout)
state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
wait_for_ready) wait_for_ready)
deadline_to_propagate = _determine_deadline(deadline) augmented_metadata = _compression.augment_metadata(
metadata, compression)
call = self._channel.segregated_call( call = self._channel.segregated_call(
cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method,
None, deadline_to_propagate, metadata, None None, _determine_deadline(deadline), augmented_metadata, None
if credentials is None else credentials._credentials, if credentials is None else credentials._credentials,
_stream_unary_invocation_operationses_and_tags( _stream_unary_invocation_operationses_and_tags(
metadata, initial_metadata_flags), self._context) augmented_metadata, initial_metadata_flags), self._context)
_consume_request_iterator(request_iterator, state, call, _consume_request_iterator(request_iterator, state, call,
self._request_serializer, None) self._request_serializer, None)
while True: while True:
@ -680,9 +691,10 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
state, call, = self._blocking(request_iterator, timeout, metadata, state, call, = self._blocking(request_iterator, timeout, metadata,
credentials, wait_for_ready) credentials, wait_for_ready, compression)
return _end_unary_response_blocking(state, call, False, None) return _end_unary_response_blocking(state, call, False, None)
def with_call(self, def with_call(self,
@ -690,9 +702,10 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
state, call, = self._blocking(request_iterator, timeout, metadata, state, call, = self._blocking(request_iterator, timeout, metadata,
credentials, wait_for_ready) credentials, wait_for_ready, compression)
return _end_unary_response_blocking(state, call, True, None) return _end_unary_response_blocking(state, call, True, None)
def future(self, def future(self,
@ -700,15 +713,18 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
deadline = _deadline(timeout) deadline = _deadline(timeout)
state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
event_handler = _event_handler(state, self._response_deserializer) event_handler = _event_handler(state, self._response_deserializer)
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
wait_for_ready) wait_for_ready)
augmented_metadata = _compression.augment_metadata(
metadata, compression)
call = self._managed_call( call = self._managed_call(
cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method,
None, deadline, metadata, None None, deadline, augmented_metadata, None
if credentials is None else credentials._credentials, if credentials is None else credentials._credentials,
_stream_unary_invocation_operationses( _stream_unary_invocation_operationses(
metadata, initial_metadata_flags), event_handler, self._context) metadata, initial_metadata_flags), event_handler, self._context)
@ -734,24 +750,26 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
deadline = _deadline(timeout) deadline = _deadline(timeout)
state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None) state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None)
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
wait_for_ready) wait_for_ready)
augmented_metadata = _compression.augment_metadata(
metadata, compression)
operationses = ( operationses = (
( (
cygrpc.SendInitialMetadataOperation(metadata, cygrpc.SendInitialMetadataOperation(augmented_metadata,
initial_metadata_flags), initial_metadata_flags),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
), ),
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
) )
event_handler = _event_handler(state, self._response_deserializer) event_handler = _event_handler(state, self._response_deserializer)
deadline_to_propagate = _determine_deadline(deadline)
call = self._managed_call( call = self._managed_call(
cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method,
None, deadline_to_propagate, metadata, None None, _determine_deadline(deadline), augmented_metadata, None
if credentials is None else credentials._credentials, operationses, if credentials is None else credentials._credentials, operationses,
event_handler, self._context) event_handler, self._context)
_consume_request_iterator(request_iterator, state, call, _consume_request_iterator(request_iterator, state, call,
@ -982,28 +1000,30 @@ def _unsubscribe(state, callback):
break break
def _options(options): def _augment_options(base_options, compression):
return list(options) + [ compression_option = _compression.create_channel_option(compression)
( return tuple(base_options) + compression_option + ((
cygrpc.ChannelArgKey.primary_user_agent_string, cygrpc.ChannelArgKey.primary_user_agent_string,
_USER_AGENT, _USER_AGENT,
), ),)
]
class Channel(grpc.Channel): class Channel(grpc.Channel):
"""A cygrpc.Channel-backed implementation of grpc.Channel.""" """A cygrpc.Channel-backed implementation of grpc.Channel."""
def __init__(self, target, options, credentials): def __init__(self, target, options, credentials, compression):
"""Constructor. """Constructor.
Args: Args:
target: The target to which to connect. target: The target to which to connect.
options: Configuration options for the channel. options: Configuration options for the channel.
credentials: A cygrpc.ChannelCredentials or None. credentials: A cygrpc.ChannelCredentials or None.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel.
""" """
self._channel = cygrpc.Channel( self._channel = cygrpc.Channel(
_common.encode(target), _options(options), credentials) _common.encode(target), _augment_options(options, compression),
credentials)
self._call_state = _ChannelCallState(self._channel) self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel) self._connectivity_state = _ChannelConnectivityState(self._channel)
cygrpc.fork_register_channel(self) cygrpc.fork_register_channel(self)

@ -0,0 +1,55 @@
# Copyright 2019 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from grpc._cython import cygrpc
NoCompression = cygrpc.CompressionAlgorithm.none
Deflate = cygrpc.CompressionAlgorithm.deflate
Gzip = cygrpc.CompressionAlgorithm.gzip
_METADATA_STRING_MAPPING = {
NoCompression: 'identity',
Deflate: 'deflate',
Gzip: 'gzip',
}
def _compression_algorithm_to_metadata_value(compression):
return _METADATA_STRING_MAPPING[compression]
def compression_algorithm_to_metadata(compression):
return (cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY,
_compression_algorithm_to_metadata_value(compression))
def create_channel_option(compression):
return ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM,
int(compression)),) if compression else ()
def augment_metadata(metadata, compression):
if not metadata and not compression:
return None
base_metadata = tuple(metadata) if metadata else ()
compression_metadata = (
compression_algorithm_to_metadata(compression),) if compression else ()
return base_metadata + compression_metadata
__all__ = (
"NoCompression",
"Deflate",
"Gzip",
)

@ -140,7 +140,8 @@ cdef extern from "grpc/grpc.h":
const char *GRPC_ARG_SECONDARY_USER_AGENT_STRING const char *GRPC_ARG_SECONDARY_USER_AGENT_STRING
const char *GRPC_SSL_TARGET_NAME_OVERRIDE_ARG const char *GRPC_SSL_TARGET_NAME_OVERRIDE_ARG
const char *GRPC_SSL_SESSION_CACHE_ARG const char *GRPC_SSL_SESSION_CACHE_ARG
const char *GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM const char *_GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM \
"GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM"
const char *GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL const char *GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL
const char *GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET const char *GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET
@ -618,3 +619,8 @@ cdef extern from "grpc/compression.h":
int grpc_compression_options_is_algorithm_enabled( int grpc_compression_options_is_algorithm_enabled(
const grpc_compression_options *opts, const grpc_compression_options *opts,
grpc_compression_algorithm algorithm) nogil grpc_compression_algorithm algorithm) nogil
cdef extern from "grpc/impl/codegen/compression_types.h":
const char *_GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY \
"GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY"

@ -108,6 +108,11 @@ class OperationType:
receive_status_on_client = GRPC_OP_RECV_STATUS_ON_CLIENT receive_status_on_client = GRPC_OP_RECV_STATUS_ON_CLIENT
receive_close_on_server = GRPC_OP_RECV_CLOSE_ON_SERVER receive_close_on_server = GRPC_OP_RECV_CLOSE_ON_SERVER
GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM= (
_GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM)
GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY = (
_GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY)
class CompressionAlgorithm: class CompressionAlgorithm:
none = GRPC_COMPRESS_NONE none = GRPC_COMPRESS_NONE

@ -44,9 +44,9 @@ def service_pipeline(interceptors):
class _ClientCallDetails( class _ClientCallDetails(
collections.namedtuple( collections.namedtuple('_ClientCallDetails',
'_ClientCallDetails', ('method', 'timeout', 'metadata', 'credentials',
('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')), 'wait_for_ready', 'compression')),
grpc.ClientCallDetails): grpc.ClientCallDetails):
pass pass
@ -77,7 +77,12 @@ def _unwrap_client_call_details(call_details, default_details):
except AttributeError: except AttributeError:
wait_for_ready = default_details.wait_for_ready wait_for_ready = default_details.wait_for_ready
return method, timeout, metadata, credentials, wait_for_ready try:
compression = call_details.compression
except AttributeError:
compression = default_details.compression
return method, timeout, metadata, credentials, wait_for_ready, compression
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors
@ -206,13 +211,15 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
response, ignored_call = self._with_call( response, ignored_call = self._with_call(
request, request,
timeout=timeout, timeout=timeout,
metadata=metadata, metadata=metadata,
credentials=credentials, credentials=credentials,
wait_for_ready=wait_for_ready) wait_for_ready=wait_for_ready,
compression=compression)
return response return response
def _with_call(self, def _with_call(self,
@ -220,20 +227,25 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
client_call_details = _ClientCallDetails( compression=None):
self._method, timeout, metadata, credentials, wait_for_ready) client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request): def continuation(new_details, request):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( (new_method, new_timeout, new_metadata, new_credentials,
_unwrap_client_call_details(new_details, client_call_details)) new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
try: try:
response, call = self._thunk(new_method).with_call( response, call = self._thunk(new_method).with_call(
request, request,
timeout=new_timeout, timeout=new_timeout,
metadata=new_metadata, metadata=new_metadata,
credentials=new_credentials, credentials=new_credentials,
wait_for_ready=new_wait_for_ready) wait_for_ready=new_wait_for_ready,
compression=new_compression)
return _UnaryOutcome(response, call) return _UnaryOutcome(response, call)
except grpc.RpcError as rpc_error: except grpc.RpcError as rpc_error:
return rpc_error return rpc_error
@ -249,32 +261,39 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
return self._with_call( return self._with_call(
request, request,
timeout=timeout, timeout=timeout,
metadata=metadata, metadata=metadata,
credentials=credentials, credentials=credentials,
wait_for_ready=wait_for_ready) wait_for_ready=wait_for_ready,
compression=compression)
def future(self, def future(self,
request, request,
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
client_call_details = _ClientCallDetails( compression=None):
self._method, timeout, metadata, credentials, wait_for_ready) client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request): def continuation(new_details, request):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( (new_method, new_timeout, new_metadata, new_credentials,
_unwrap_client_call_details(new_details, client_call_details)) new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method).future( return self._thunk(new_method).future(
request, request,
timeout=new_timeout, timeout=new_timeout,
metadata=new_metadata, metadata=new_metadata,
credentials=new_credentials, credentials=new_credentials,
wait_for_ready=new_wait_for_ready) wait_for_ready=new_wait_for_ready,
compression=new_compression)
try: try:
return self._interceptor.intercept_unary_unary( return self._interceptor.intercept_unary_unary(
@ -295,19 +314,24 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
client_call_details = _ClientCallDetails( compression=None):
self._method, timeout, metadata, credentials, wait_for_ready) client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request): def continuation(new_details, request):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( (new_method, new_timeout, new_metadata, new_credentials,
_unwrap_client_call_details(new_details, client_call_details)) new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method)( return self._thunk(new_method)(
request, request,
timeout=new_timeout, timeout=new_timeout,
metadata=new_metadata, metadata=new_metadata,
credentials=new_credentials, credentials=new_credentials,
wait_for_ready=new_wait_for_ready) wait_for_ready=new_wait_for_ready,
compression=new_compression)
try: try:
return self._interceptor.intercept_unary_stream( return self._interceptor.intercept_unary_stream(
@ -328,13 +352,15 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
response, ignored_call = self._with_call( response, ignored_call = self._with_call(
request_iterator, request_iterator,
timeout=timeout, timeout=timeout,
metadata=metadata, metadata=metadata,
credentials=credentials, credentials=credentials,
wait_for_ready=wait_for_ready) wait_for_ready=wait_for_ready,
compression=compression)
return response return response
def _with_call(self, def _with_call(self,
@ -342,20 +368,25 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
client_call_details = _ClientCallDetails( compression=None):
self._method, timeout, metadata, credentials, wait_for_ready) client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request_iterator): def continuation(new_details, request_iterator):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( (new_method, new_timeout, new_metadata, new_credentials,
_unwrap_client_call_details(new_details, client_call_details)) new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
try: try:
response, call = self._thunk(new_method).with_call( response, call = self._thunk(new_method).with_call(
request_iterator, request_iterator,
timeout=new_timeout, timeout=new_timeout,
metadata=new_metadata, metadata=new_metadata,
credentials=new_credentials, credentials=new_credentials,
wait_for_ready=new_wait_for_ready) wait_for_ready=new_wait_for_ready,
compression=new_compression)
return _UnaryOutcome(response, call) return _UnaryOutcome(response, call)
except grpc.RpcError as rpc_error: except grpc.RpcError as rpc_error:
return rpc_error return rpc_error
@ -371,32 +402,39 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
compression=None):
return self._with_call( return self._with_call(
request_iterator, request_iterator,
timeout=timeout, timeout=timeout,
metadata=metadata, metadata=metadata,
credentials=credentials, credentials=credentials,
wait_for_ready=wait_for_ready) wait_for_ready=wait_for_ready,
compression=compression)
def future(self, def future(self,
request_iterator, request_iterator,
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
client_call_details = _ClientCallDetails( compression=None):
self._method, timeout, metadata, credentials, wait_for_ready) client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request_iterator): def continuation(new_details, request_iterator):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( (new_method, new_timeout, new_metadata, new_credentials,
_unwrap_client_call_details(new_details, client_call_details)) new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method).future( return self._thunk(new_method).future(
request_iterator, request_iterator,
timeout=new_timeout, timeout=new_timeout,
metadata=new_metadata, metadata=new_metadata,
credentials=new_credentials, credentials=new_credentials,
wait_for_ready=new_wait_for_ready) wait_for_ready=new_wait_for_ready,
compression=new_compression)
try: try:
return self._interceptor.intercept_stream_unary( return self._interceptor.intercept_stream_unary(
@ -417,19 +455,24 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
timeout=None, timeout=None,
metadata=None, metadata=None,
credentials=None, credentials=None,
wait_for_ready=None): wait_for_ready=None,
client_call_details = _ClientCallDetails( compression=None):
self._method, timeout, metadata, credentials, wait_for_ready) client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request_iterator): def continuation(new_details, request_iterator):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = ( (new_method, new_timeout, new_metadata, new_credentials,
_unwrap_client_call_details(new_details, client_call_details)) new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method)( return self._thunk(new_method)(
request_iterator, request_iterator,
timeout=new_timeout, timeout=new_timeout,
metadata=new_metadata, metadata=new_metadata,
credentials=new_credentials, credentials=new_credentials,
wait_for_ready=new_wait_for_ready) wait_for_ready=new_wait_for_ready,
compression=new_compression)
try: try:
return self._interceptor.intercept_stream_stream( return self._interceptor.intercept_stream_stream(

@ -24,6 +24,7 @@ import six
import grpc import grpc
from grpc import _common from grpc import _common
from grpc import _compression
from grpc import _interceptor from grpc import _interceptor
from grpc._cython import cygrpc from grpc._cython import cygrpc
@ -94,6 +95,7 @@ class _RPCState(object):
self.request = None self.request = None
self.client = _OPEN self.client = _OPEN
self.initial_metadata_allowed = True self.initial_metadata_allowed = True
self.compression_algorithm = None
self.disable_next_compression = False self.disable_next_compression = False
self.trailing_metadata = None self.trailing_metadata = None
self.code = None self.code = None
@ -129,13 +131,33 @@ def _send_status_from_server(state, token):
return send_status_from_server return send_status_from_server
def _get_initial_metadata(state, metadata):
with state.condition:
if state.compression_algorithm:
compression_metadata = (
_compression.compression_algorithm_to_metadata(
state.compression_algorithm),)
if metadata is None:
return compression_metadata
else:
return compression_metadata + tuple(metadata)
else:
return metadata
def _get_initial_metadata_operation(state, metadata):
operation = cygrpc.SendInitialMetadataOperation(
_get_initial_metadata(state, metadata), _EMPTY_FLAGS)
return operation
def _abort(state, call, code, details): def _abort(state, call, code, details):
if state.client is not _CANCELLED: if state.client is not _CANCELLED:
effective_code = _abortion_code(state, code) effective_code = _abortion_code(state, code)
effective_details = details if state.details is None else state.details effective_details = details if state.details is None else state.details
if state.initial_metadata_allowed: if state.initial_metadata_allowed:
operations = ( operations = (
cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS), _get_initial_metadata_operation(state, None),
cygrpc.SendStatusFromServerOperation( cygrpc.SendStatusFromServerOperation(
state.trailing_metadata, effective_code, effective_details, state.trailing_metadata, effective_code, effective_details,
_EMPTY_FLAGS), _EMPTY_FLAGS),
@ -259,14 +281,18 @@ class _Context(grpc.ServicerContext):
cygrpc.auth_context(self._rpc_event.call)) cygrpc.auth_context(self._rpc_event.call))
} }
def set_compression(self, compression):
with self._state.condition:
self._state.compression_algorithm = compression
def send_initial_metadata(self, initial_metadata): def send_initial_metadata(self, initial_metadata):
with self._state.condition: with self._state.condition:
if self._state.client is _CANCELLED: if self._state.client is _CANCELLED:
_raise_rpc_error(self._state) _raise_rpc_error(self._state)
else: else:
if self._state.initial_metadata_allowed: if self._state.initial_metadata_allowed:
operation = cygrpc.SendInitialMetadataOperation( operation = _get_initial_metadata_operation(
initial_metadata, _EMPTY_FLAGS) self._state, initial_metadata)
self._rpc_event.call.start_server_batch( self._rpc_event.call.start_server_batch(
(operation,), _send_initial_metadata(self._state)) (operation,), _send_initial_metadata(self._state))
self._state.initial_metadata_allowed = False self._state.initial_metadata_allowed = False
@ -400,10 +426,13 @@ def _call_behavior(rpc_event,
with _create_servicer_context(rpc_event, state, with _create_servicer_context(rpc_event, state,
request_deserializer) as context: request_deserializer) as context:
try: try:
response_or_iterator = None
if send_response_callback is not None: if send_response_callback is not None:
return behavior(argument, context, send_response_callback), True response_or_iterator = behavior(argument, context,
send_response_callback)
else: else:
return behavior(argument, context), True response_or_iterator = behavior(argument, context)
return response_or_iterator, True
except Exception as exception: # pylint: disable=broad-except except Exception as exception: # pylint: disable=broad-except
with state.condition: with state.condition:
if state.aborted: if state.aborted:
@ -447,6 +476,18 @@ def _serialize_response(rpc_event, state, response, response_serializer):
return serialized_response return serialized_response
def _get_send_message_op_flags_from_state(state):
if state.disable_next_compression:
return cygrpc.WriteFlag.no_compress
else:
return _EMPTY_FLAGS
def _reset_per_message_state(state):
with state.condition:
state.disable_next_compression = False
def _send_response(rpc_event, state, serialized_response): def _send_response(rpc_event, state, serialized_response):
with state.condition: with state.condition:
if not _is_rpc_state_active(state): if not _is_rpc_state_active(state):
@ -454,19 +495,22 @@ def _send_response(rpc_event, state, serialized_response):
else: else:
if state.initial_metadata_allowed: if state.initial_metadata_allowed:
operations = ( operations = (
cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS), _get_initial_metadata_operation(state, None),
cygrpc.SendMessageOperation(serialized_response, cygrpc.SendMessageOperation(
_EMPTY_FLAGS), serialized_response,
_get_send_message_op_flags_from_state(state)),
) )
state.initial_metadata_allowed = False state.initial_metadata_allowed = False
token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
else: else:
operations = (cygrpc.SendMessageOperation( operations = (cygrpc.SendMessageOperation(
serialized_response, _EMPTY_FLAGS),) serialized_response,
_get_send_message_op_flags_from_state(state)),)
token = _SEND_MESSAGE_TOKEN token = _SEND_MESSAGE_TOKEN
rpc_event.call.start_server_batch(operations, rpc_event.call.start_server_batch(operations,
_send_message(state, token)) _send_message(state, token))
state.due.add(token) state.due.add(token)
_reset_per_message_state(state)
while True: while True:
state.condition.wait() state.condition.wait()
if token not in state.due: if token not in state.due:
@ -483,16 +527,17 @@ def _status(rpc_event, state, serialized_response):
state.trailing_metadata, code, details, _EMPTY_FLAGS), state.trailing_metadata, code, details, _EMPTY_FLAGS),
] ]
if state.initial_metadata_allowed: if state.initial_metadata_allowed:
operations.append( operations.append(_get_initial_metadata_operation(state, None))
cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS))
if serialized_response is not None: if serialized_response is not None:
operations.append( operations.append(
cygrpc.SendMessageOperation(serialized_response, cygrpc.SendMessageOperation(
_EMPTY_FLAGS)) serialized_response,
_get_send_message_op_flags_from_state(state)))
rpc_event.call.start_server_batch( rpc_event.call.start_server_batch(
operations, operations,
_send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN)) _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
state.statused = True state.statused = True
_reset_per_message_state(state)
state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN) state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
@ -639,13 +684,13 @@ def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline):
def _reject_rpc(rpc_event, status, details): def _reject_rpc(rpc_event, status, details):
rpc_state = _RPCState()
operations = ( operations = (
cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS), _get_initial_metadata_operation(rpc_state, None),
cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
cygrpc.SendStatusFromServerOperation(None, status, details, cygrpc.SendStatusFromServerOperation(None, status, details,
_EMPTY_FLAGS), _EMPTY_FLAGS),
) )
rpc_state = _RPCState()
rpc_event.call.start_server_batch(operations, rpc_event.call.start_server_batch(operations,
lambda ignored_event: (rpc_state, (),)) lambda ignored_event: (rpc_state, (),))
return rpc_state return rpc_state
@ -883,13 +928,18 @@ def _validate_generic_rpc_handlers(generic_rpc_handlers):
'not have "service" method!'.format(generic_rpc_handler)) 'not have "service" method!'.format(generic_rpc_handler))
def _augment_options(base_options, compression):
compression_option = _compression.create_channel_option(compression)
return tuple(base_options) + compression_option
class _Server(grpc.Server): class _Server(grpc.Server):
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def __init__(self, thread_pool, generic_handlers, interceptors, options, def __init__(self, thread_pool, generic_handlers, interceptors, options,
maximum_concurrent_rpcs): maximum_concurrent_rpcs, compression):
completion_queue = cygrpc.CompletionQueue() completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server(options) server = cygrpc.Server(_augment_options(options, compression))
server.register_completion_queue(completion_queue) server.register_completion_queue(completion_queue)
self._state = _ServerState(completion_queue, server, generic_handlers, self._state = _ServerState(completion_queue, server, generic_handlers,
_interceptor.service_pipeline(interceptors), _interceptor.service_pipeline(interceptors),
@ -920,7 +970,7 @@ class _Server(grpc.Server):
def create_server(thread_pool, generic_rpc_handlers, interceptors, options, def create_server(thread_pool, generic_rpc_handlers, interceptors, options,
maximum_concurrent_rpcs): maximum_concurrent_rpcs, compression):
_validate_generic_rpc_handlers(generic_rpc_handlers) _validate_generic_rpc_handlers(generic_rpc_handlers)
return _Server(thread_pool, generic_rpc_handlers, interceptors, options, return _Server(thread_pool, generic_rpc_handlers, interceptors, options,
maximum_concurrent_rpcs) maximum_concurrent_rpcs, compression)

@ -56,6 +56,9 @@ class ServicerContext(grpc.ServicerContext):
def auth_context(self): def auth_context(self):
raise NotImplementedError() raise NotImplementedError()
def set_compression(self):
raise NotImplementedError()
def send_initial_metadata(self, initial_metadata): def send_initial_metadata(self, initial_metadata):
initial_metadata_sent = self._rpc.send_initial_metadata( initial_metadata_sent = self._rpc.send_initial_metadata(
_common.fuss_with_metadata(initial_metadata)) _common.fuss_with_metadata(initial_metadata))
@ -63,6 +66,9 @@ class ServicerContext(grpc.ServicerContext):
raise ValueError( raise ValueError(
'ServicerContext.send_initial_metadata called too late!') 'ServicerContext.send_initial_metadata called too late!')
def disable_next_message_compression(self):
raise NotImplementedError()
def set_trailing_metadata(self, trailing_metadata): def set_trailing_metadata(self, trailing_metadata):
self._rpc.set_trailing_metadata( self._rpc.set_trailing_metadata(
_common.fuss_with_metadata(trailing_metadata)) _common.fuss_with_metadata(trailing_metadata))

@ -117,6 +117,7 @@ class TestGevent(setuptools.Command):
# eventually succeed, but need to dig into performance issues. # eventually succeed, but need to dig into performance issues.
'unit._cython._no_messages_server_completion_queue_per_call_test.Test.test_rpcs', 'unit._cython._no_messages_server_completion_queue_per_call_test.Test.test_rpcs',
'unit._cython._no_messages_single_server_completion_queue_test.Test.test_rpcs', 'unit._cython._no_messages_single_server_completion_queue_test.Test.test_rpcs',
'unit._compression_test',
# TODO(https://github.com/grpc/grpc/issues/16890) enable this test # TODO(https://github.com/grpc/grpc/issues/16890) enable this test
'unit._cython._channel_test.ChannelTest.test_multiple_channels_lonely_connectivity', 'unit._cython._channel_test.ChannelTest.test_multiple_channels_lonely_connectivity',
# I have no idea why this doesn't work in gevent, but it shouldn't even be # I have no idea why this doesn't work in gevent, but it shouldn't even be

@ -33,6 +33,11 @@ GRPCIO_TESTS_UNIT = [
"_session_cache_test.py", "_session_cache_test.py",
] ]
py_library(
name = "_tcp_proxy",
srcs = ["_tcp_proxy.py"],
)
py_library( py_library(
name = "resources", name = "resources",
srcs = ["resources.py"], srcs = ["resources.py"],
@ -80,6 +85,7 @@ py_library(
":_exit_scenarios", ":_exit_scenarios",
":_server_shutdown_scenarios", ":_server_shutdown_scenarios",
":_from_grpc_import_star", ":_from_grpc_import_star",
":_tcp_proxy",
"//src/python/grpcio_tests/tests/unit/framework/common", "//src/python/grpcio_tests/tests/unit/framework/common",
"//src/python/grpcio_tests/tests/testing", "//src/python/grpcio_tests/tests/testing",
requirement('six'), requirement('six'),

@ -31,6 +31,7 @@ class AllTest(unittest.TestCase):
'FutureCancelledError', 'FutureCancelledError',
'Future', 'Future',
'ChannelConnectivity', 'ChannelConnectivity',
'Compression',
'StatusCode', 'StatusCode',
'Status', 'Status',
'RpcError', 'RpcError',

@ -13,37 +13,130 @@
# limitations under the License. # limitations under the License.
"""Tests server and client side compression.""" """Tests server and client side compression."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest import unittest
import contextlib
from concurrent import futures
import functools
import itertools
import logging import logging
import os
import grpc import grpc
from grpc import _grpcio_metadata from grpc import _grpcio_metadata
from tests.unit import test_common from tests.unit import test_common
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from tests.unit import _tcp_proxy
_UNARY_UNARY = '/test/UnaryUnary' _UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream' _STREAM_STREAM = '/test/StreamStream'
# Cut down on test time.
_STREAM_LENGTH = test_constants.STREAM_LENGTH // 8
_HOST = 'localhost'
_REQUEST = b'\x00' * 100
_COMPRESSION_RATIO_THRESHOLD = 0.1
_COMPRESSION_METHODS = (
None,
# Disabled for test tractability.
# grpc.Compression.NoCompression,
grpc.Compression.Deflate,
grpc.Compression.Gzip,
)
_COMPRESSION_NAMES = {
None: 'Uncompressed',
grpc.Compression.NoCompression: 'NoCompression',
grpc.Compression.Deflate: 'DeflateCompression',
grpc.Compression.Gzip: 'GzipCompression',
}
_TEST_OPTIONS = {
'client_streaming': (True, False),
'server_streaming': (True, False),
'channel_compression': _COMPRESSION_METHODS,
'multicallable_compression': _COMPRESSION_METHODS,
'server_compression': _COMPRESSION_METHODS,
'server_call_compression': _COMPRESSION_METHODS,
}
def _make_handle_unary_unary(pre_response_callback):
def _handle_unary(request, servicer_context):
if pre_response_callback:
pre_response_callback(request, servicer_context)
return request
return _handle_unary
def _make_handle_unary_stream(pre_response_callback):
def _handle_unary_stream(request, servicer_context):
if pre_response_callback:
pre_response_callback(request, servicer_context)
for _ in range(_STREAM_LENGTH):
yield request
return _handle_unary_stream
def _make_handle_stream_unary(pre_response_callback):
def _handle_stream_unary(request_iterator, servicer_context):
if pre_response_callback:
pre_response_callback(request_iterator, servicer_context)
response = None
for request in request_iterator:
if not response:
response = request
return response
def handle_unary(request, servicer_context): return _handle_stream_unary
servicer_context.send_initial_metadata([('grpc-internal-encoding-request',
'gzip')])
return request
def handle_stream(request_iterator, servicer_context): def _make_handle_stream_stream(pre_response_callback):
# TODO(issue:#6891) We should be able to remove this loop,
# and replace with return; yield def _handle_stream(request_iterator, servicer_context):
servicer_context.send_initial_metadata([('grpc-internal-encoding-request', # TODO(issue:#6891) We should be able to remove this loop,
'gzip')]) # and replace with return; yield
for request in request_iterator: for request in request_iterator:
yield request if pre_response_callback:
pre_response_callback(request, servicer_context)
yield request
return _handle_stream
def set_call_compression(compression_method, request_or_iterator,
servicer_context):
del request_or_iterator
servicer_context.set_compression(compression_method)
def disable_next_compression(request, servicer_context):
del request
servicer_context.disable_next_message_compression()
def disable_first_compression(request, servicer_context):
if int(request.decode('ascii')) == 0:
servicer_context.disable_next_message_compression()
class _MethodHandler(grpc.RpcMethodHandler): class _MethodHandler(grpc.RpcMethodHandler):
def __init__(self, request_streaming, response_streaming): def __init__(self, request_streaming, response_streaming,
pre_response_callback):
self.request_streaming = request_streaming self.request_streaming = request_streaming
self.response_streaming = response_streaming self.response_streaming = response_streaming
self.request_deserializer = None self.request_deserializer = None
@ -52,75 +145,239 @@ class _MethodHandler(grpc.RpcMethodHandler):
self.unary_stream = None self.unary_stream = None
self.stream_unary = None self.stream_unary = None
self.stream_stream = None self.stream_stream = None
if self.request_streaming and self.response_streaming: if self.request_streaming and self.response_streaming:
self.stream_stream = handle_stream self.stream_stream = _make_handle_stream_stream(
pre_response_callback)
elif not self.request_streaming and not self.response_streaming: elif not self.request_streaming and not self.response_streaming:
self.unary_unary = handle_unary self.unary_unary = _make_handle_unary_unary(pre_response_callback)
elif not self.request_streaming and self.response_streaming:
self.unary_stream = _make_handle_unary_stream(pre_response_callback)
else:
self.stream_unary = _make_handle_stream_unary(pre_response_callback)
class _GenericHandler(grpc.GenericRpcHandler): class _GenericHandler(grpc.GenericRpcHandler):
def __init__(self, pre_response_callback):
self._pre_response_callback = pre_response_callback
def service(self, handler_call_details): def service(self, handler_call_details):
if handler_call_details.method == _UNARY_UNARY: if handler_call_details.method == _UNARY_UNARY:
return _MethodHandler(False, False) return _MethodHandler(False, False, self._pre_response_callback)
elif handler_call_details.method == _UNARY_STREAM:
return _MethodHandler(False, True, self._pre_response_callback)
elif handler_call_details.method == _STREAM_UNARY:
return _MethodHandler(True, False, self._pre_response_callback)
elif handler_call_details.method == _STREAM_STREAM: elif handler_call_details.method == _STREAM_STREAM:
return _MethodHandler(True, True) return _MethodHandler(True, True, self._pre_response_callback)
else: else:
return None return None
@contextlib.contextmanager
def _instrumented_client_server_pair(channel_kwargs, server_kwargs,
server_handler):
server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs)
server.add_generic_rpc_handlers((server_handler,))
server_port = server.add_insecure_port('{}:0'.format(_HOST))
server.start()
with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy:
proxy_port = proxy.get_port()
with grpc.insecure_channel('{}:{}'.format(_HOST, proxy_port),
**channel_kwargs) as client_channel:
try:
yield client_channel, proxy, server
finally:
server.stop(None)
def _get_byte_counts(channel_kwargs, multicallable_kwargs, client_function,
server_kwargs, server_handler, message):
with _instrumented_client_server_pair(channel_kwargs, server_kwargs,
server_handler) as pipeline:
client_channel, proxy, server = pipeline
client_function(client_channel, multicallable_kwargs, message)
return proxy.get_byte_count()
def _get_compression_ratios(client_function, first_channel_kwargs,
first_multicallable_kwargs, first_server_kwargs,
first_server_handler, second_channel_kwargs,
second_multicallable_kwargs, second_server_kwargs,
second_server_handler, message):
try:
# This test requires the byte length of each connection to be deterministic. As
# it turns out, flow control puts bytes on the wire in a nondeterministic
# manner. We disable it here in order to measure compression ratios
# deterministically.
os.environ['GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL'] = 'true'
first_bytes_sent, first_bytes_received = _get_byte_counts(
first_channel_kwargs, first_multicallable_kwargs, client_function,
first_server_kwargs, first_server_handler, message)
second_bytes_sent, second_bytes_received = _get_byte_counts(
second_channel_kwargs, second_multicallable_kwargs, client_function,
second_server_kwargs, second_server_handler, message)
return ((
second_bytes_sent - first_bytes_sent) / float(first_bytes_sent),
(second_bytes_received - first_bytes_received) /
float(first_bytes_received))
finally:
del os.environ['GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL']
def _unary_unary_client(channel, multicallable_kwargs, message):
multi_callable = channel.unary_unary(_UNARY_UNARY)
response = multi_callable(message, **multicallable_kwargs)
if response != message:
raise RuntimeError("Request '{}' != Response '{}'".format(
message, response))
def _unary_stream_client(channel, multicallable_kwargs, message):
multi_callable = channel.unary_stream(_UNARY_STREAM)
response_iterator = multi_callable(message, **multicallable_kwargs)
for response in response_iterator:
if response != message:
raise RuntimeError("Request '{}' != Response '{}'".format(
message, response))
def _stream_unary_client(channel, multicallable_kwargs, message):
multi_callable = channel.stream_unary(_STREAM_UNARY)
requests = (_REQUEST for _ in range(_STREAM_LENGTH))
response = multi_callable(requests, **multicallable_kwargs)
if response != message:
raise RuntimeError("Request '{}' != Response '{}'".format(
message, response))
def _stream_stream_client(channel, multicallable_kwargs, message):
multi_callable = channel.stream_stream(_STREAM_STREAM)
request_prefix = str(0).encode('ascii') * 100
requests = (
request_prefix + str(i).encode('ascii') for i in range(_STREAM_LENGTH))
response_iterator = multi_callable(requests, **multicallable_kwargs)
for i, response in enumerate(response_iterator):
if int(response.decode('ascii')) != i:
raise RuntimeError("Request '{}' != Response '{}'".format(
i, response))
class CompressionTest(unittest.TestCase): class CompressionTest(unittest.TestCase):
def setUp(self): def assertCompressed(self, compression_ratio):
self._server = test_common.test_server() self.assertLess(
self._server.add_generic_rpc_handlers((_GenericHandler(),)) compression_ratio,
self._port = self._server.add_insecure_port('[::]:0') -1.0 * _COMPRESSION_RATIO_THRESHOLD,
self._server.start() msg='Actual compression ratio: {}'.format(compression_ratio))
def tearDown(self): def assertNotCompressed(self, compression_ratio):
self._server.stop(None) self.assertGreaterEqual(
compression_ratio,
def testUnary(self): -1.0 * _COMPRESSION_RATIO_THRESHOLD,
request = b'\x00' * 100 msg='Actual compession ratio: {}'.format(compression_ratio))
# Client -> server compressed through default client channel compression def assertConfigurationCompressed(
# settings. Server -> client compressed via server-side metadata setting. self, client_streaming, server_streaming, channel_compression,
# TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer multicallable_compression, server_compression,
# literal with proper use of the public API. server_call_compression):
compressed_channel = grpc.insecure_channel( client_side_compressed = channel_compression or multicallable_compression
'localhost:%d' % self._port, server_side_compressed = server_compression or server_call_compression
options=[('grpc.default_compression_algorithm', 1)]) channel_kwargs = {
multi_callable = compressed_channel.unary_unary(_UNARY_UNARY) 'compression': channel_compression,
response = multi_callable(request) } if channel_compression else {}
self.assertEqual(request, response) multicallable_kwargs = {
'compression': multicallable_compression,
# Client -> server compressed through client metadata setting. Server -> } if multicallable_compression else {}
# client compressed via server-side metadata setting.
# TODO(https://github.com/grpc/grpc/issues/4078): replace the "0" integer client_function = None
# literal with proper use of the public API. if not client_streaming and not server_streaming:
uncompressed_channel = grpc.insecure_channel( client_function = _unary_unary_client
'localhost:%d' % self._port, elif not client_streaming and server_streaming:
options=[('grpc.default_compression_algorithm', 0)]) client_function = _unary_stream_client
multi_callable = compressed_channel.unary_unary(_UNARY_UNARY) elif client_streaming and not server_streaming:
response = multi_callable( client_function = _stream_unary_client
request, metadata=[('grpc-internal-encoding-request', 'gzip')]) else:
self.assertEqual(request, response) client_function = _stream_stream_client
compressed_channel.close()
server_kwargs = {
def testStreaming(self): 'compression': server_compression,
request = b'\x00' * 100 } if server_compression else {}
server_handler = _GenericHandler(
# TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer functools.partial(set_call_compression, grpc.Compression.Gzip)
# literal with proper use of the public API. ) if server_call_compression else _GenericHandler(None)
compressed_channel = grpc.insecure_channel( sent_ratio, received_ratio = _get_compression_ratios(
'localhost:%d' % self._port, client_function, {}, {}, {}, _GenericHandler(None), channel_kwargs,
options=[('grpc.default_compression_algorithm', 1)]) multicallable_kwargs, server_kwargs, server_handler, _REQUEST)
multi_callable = compressed_channel.stream_stream(_STREAM_STREAM)
call = multi_callable(iter([request] * test_constants.STREAM_LENGTH)) if client_side_compressed:
for response in call: self.assertCompressed(sent_ratio)
self.assertEqual(request, response) else:
compressed_channel.close() self.assertNotCompressed(sent_ratio)
if server_side_compressed:
self.assertCompressed(received_ratio)
else:
self.assertNotCompressed(received_ratio)
def testDisableNextCompressionStreaming(self):
server_kwargs = {
'compression': grpc.Compression.Deflate,
}
_, received_ratio = _get_compression_ratios(
_stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
server_kwargs, _GenericHandler(disable_next_compression), _REQUEST)
self.assertNotCompressed(received_ratio)
def testDisableNextCompressionStreamingResets(self):
server_kwargs = {
'compression': grpc.Compression.Deflate,
}
_, received_ratio = _get_compression_ratios(
_stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
server_kwargs, _GenericHandler(disable_first_compression), _REQUEST)
self.assertCompressed(received_ratio)
def _get_compression_str(name, value):
return '{}{}'.format(name, _COMPRESSION_NAMES[value])
def _get_compression_test_name(client_streaming, server_streaming,
channel_compression, multicallable_compression,
server_compression, server_call_compression):
client_arity = 'Stream' if client_streaming else 'Unary'
server_arity = 'Stream' if server_streaming else 'Unary'
arity = '{}{}'.format(client_arity, server_arity)
channel_compression_str = _get_compression_str('Channel',
channel_compression)
multicallable_compression_str = _get_compression_str(
'Multicallable', multicallable_compression)
server_compression_str = _get_compression_str('Server', server_compression)
server_call_compression_str = _get_compression_str('ServerCall',
server_call_compression)
return 'test{}{}{}{}{}'.format(
arity, channel_compression_str, multicallable_compression_str,
server_compression_str, server_call_compression_str)
def _test_options():
for test_parameters in itertools.product(*_TEST_OPTIONS.values()):
yield dict(zip(_TEST_OPTIONS.keys(), test_parameters))
for options in _test_options():
def test_compression(**kwargs):
def _test_compression(self):
self.assertConfigurationCompressed(**kwargs)
return _test_compression
setattr(CompressionTest, _get_compression_test_name(**options),
test_compression(**options))
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig() logging.basicConfig()

@ -0,0 +1,164 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Proxies a TCP connection between a single client-server pair.
This proxy is not suitable for production, but should work well for cases in
which a test needs to spy on the bytes put on the wire between a server and
a client.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime
import select
import socket
import threading
_TCP_PROXY_BUFFER_SIZE = 1024
_TCP_PROXY_TIMEOUT = datetime.timedelta(milliseconds=500)
def _create_socket_ipv6(bind_address):
listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
listen_socket.bind((bind_address, 0, 0, 0))
return listen_socket
def _create_socket_ipv4(bind_address):
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_socket.bind((bind_address, 0))
return listen_socket
def _init_listen_socket(bind_address):
listen_socket = None
if socket.has_ipv6:
try:
listen_socket = _create_socket_ipv6(bind_address)
except socket.error:
listen_socket = _create_socket_ipv4(bind_address)
else:
listen_socket = _create_socket_ipv4(bind_address)
listen_socket.listen(1)
return listen_socket, listen_socket.getsockname()[1]
def _init_proxy_socket(gateway_address, gateway_port):
proxy_socket = socket.create_connection((gateway_address, gateway_port))
return proxy_socket
class TcpProxy(object):
"""Proxies a TCP connection between one client and one server."""
def __init__(self, bind_address, gateway_address, gateway_port):
self._bind_address = bind_address
self._gateway_address = gateway_address
self._gateway_port = gateway_port
self._byte_count_lock = threading.RLock()
self._sent_byte_count = 0
self._received_byte_count = 0
self._stop_event = threading.Event()
self._port = None
self._listen_socket = None
self._proxy_socket = None
# The following three attributes are owned by the serving thread.
self._northbound_data = b""
self._southbound_data = b""
self._client_sockets = []
self._thread = threading.Thread(target=self._run_proxy)
def start(self):
self._listen_socket, self._port = _init_listen_socket(
self._bind_address)
self._proxy_socket = _init_proxy_socket(self._gateway_address,
self._gateway_port)
self._thread.start()
def get_port(self):
return self._port
def _handle_reads(self, sockets_to_read):
for socket_to_read in sockets_to_read:
if socket_to_read is self._listen_socket:
client_socket, client_address = socket_to_read.accept()
self._client_sockets.append(client_socket)
elif socket_to_read is self._proxy_socket:
data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE)
with self._byte_count_lock:
self._received_byte_count += len(data)
self._northbound_data += data
elif socket_to_read in self._client_sockets:
data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE)
if data:
with self._byte_count_lock:
self._sent_byte_count += len(data)
self._southbound_data += data
else:
self._client_sockets.remove(socket_to_read)
else:
raise RuntimeError('Unidentified socket appeared in read set.')
def _handle_writes(self, sockets_to_write):
for socket_to_write in sockets_to_write:
if socket_to_write is self._proxy_socket:
if self._southbound_data:
self._proxy_socket.sendall(self._southbound_data)
self._southbound_data = b""
elif socket_to_write in self._client_sockets:
if self._northbound_data:
socket_to_write.sendall(self._northbound_data)
self._northbound_data = b""
def _run_proxy(self):
while not self._stop_event.is_set():
expected_reads = (self._listen_socket, self._proxy_socket) + tuple(
self._client_sockets)
expected_writes = expected_reads
sockets_to_read, sockets_to_write, _ = select.select(
expected_reads, expected_writes, (),
_TCP_PROXY_TIMEOUT.total_seconds())
self._handle_reads(sockets_to_read)
self._handle_writes(sockets_to_write)
for client_socket in self._client_sockets:
client_socket.close()
def stop(self):
self._stop_event.set()
self._thread.join()
self._listen_socket.close()
self._proxy_socket.close()
def get_byte_count(self):
with self._byte_count_lock:
return self._sent_byte_count, self._received_byte_count
def reset_byte_count(self):
with self._byte_count_lock:
self._byte_count = 0
self._received_byte_count = 0
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
Loading…
Cancel
Save