Implement compression API within gRPC Python.

pull/18564/head
Richard Belleville 6 years ago
parent 10e39e316c
commit 5afd77398e
  1. 4
      .pylintrc
  2. 6
      doc/python/sphinx/grpc.rst
  3. 8
      src/python/grpcio/grpc/BUILD.bazel
  4. 96
      src/python/grpcio/grpc/__init__.py
  5. 114
      src/python/grpcio/grpc/_channel.py
  6. 55
      src/python/grpcio/grpc/_compression.py
  7. 8
      src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
  8. 5
      src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
  9. 139
      src/python/grpcio/grpc/_interceptor.py
  10. 88
      src/python/grpcio/grpc/_server.py
  11. 6
      src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py
  12. 1
      src/python/grpcio_tests/commands.py
  13. 6
      src/python/grpcio_tests/tests/unit/BUILD.bazel
  14. 1
      src/python/grpcio_tests/tests/unit/_api_test.py
  15. 387
      src/python/grpcio_tests/tests/unit/_compression_test.py
  16. 164
      src/python/grpcio_tests/tests/unit/_tcp_proxy.py

@ -6,6 +6,8 @@ ignore=
src/python/grpcio/grpc/framework/foundation,
src/python/grpcio/grpc/framework/interfaces,
extension-pkg-whitelist=grpc._cython.cygrpc
[VARIABLES]
# TODO(https://github.com/PyCQA/pylint/issues/1345): How does the inspection
@ -17,7 +19,7 @@ dummy-variables-rgx=^ignored_|^unused_
# NOTE(nathaniel): Not particularly attached to this value; it just seems to
# be what works for us at the moment (excepting the dead-code-walking Beta
# API).
max-args=6
max-args=7
[MISCELLANEOUS]

@ -172,3 +172,9 @@ Future Interfaces
.. autoexception:: FutureTimeoutError
.. autoexception:: FutureCancelledError
.. autoclass:: Future
Compression
^^^^^^^^^^^
.. autoclass:: Compression

@ -12,6 +12,7 @@ py_library(
":channel",
":interceptor",
":server",
":compression",
"//src/python/grpcio/grpc/_cython:cygrpc",
"//src/python/grpcio/grpc/experimental",
"//src/python/grpcio/grpc/framework",
@ -31,12 +32,18 @@ py_library(
srcs = ["_auth.py"],
)
py_library(
name = "compression",
srcs = ["_compression.py"],
)
py_library(
name = "channel",
srcs = ["_channel.py"],
deps = [
":common",
":grpcio_metadata",
":compression",
],
)
@ -68,6 +75,7 @@ py_library(
srcs = ["_server.py"],
deps = [
":common",
":compression",
":interceptor",
],
)

@ -21,6 +21,7 @@ import sys
import six
from grpc._cython import cygrpc as _cygrpc
from grpc import _compression
logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -413,6 +414,8 @@ class ClientCallDetails(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional flag t
enable wait for ready mechanism.
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
"""
@ -669,7 +672,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
"""Synchronously invokes the underlying RPC.
Args:
@ -681,6 +685,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
The response value for the RPC.
@ -698,7 +704,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
"""Synchronously invokes the underlying RPC.
Args:
@ -710,6 +717,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
The response value for the RPC and a Call value for the RPC.
@ -727,7 +736,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
"""Asynchronously invokes the underlying RPC.
Args:
@ -739,6 +749,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
An object that is both a Call for the RPC and a Future.
@ -759,7 +771,8 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
"""Invokes the underlying RPC.
Args:
@ -771,6 +784,8 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
An object that is both a Call for the RPC and an iterator of
@ -790,7 +805,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
"""Synchronously invokes the underlying RPC.
Args:
@ -803,6 +819,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
The response value for the RPC.
@ -820,7 +838,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
"""Synchronously invokes the underlying RPC on the client.
Args:
@ -833,6 +852,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
The response value for the RPC and a Call object for the RPC.
@ -850,7 +871,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
"""Asynchronously invokes the underlying RPC on the client.
Args:
@ -862,6 +884,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
An object that is both a Call for the RPC and a Future.
@ -882,7 +906,8 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
"""Invokes the underlying RPC on the client.
Args:
@ -894,6 +919,8 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
credentials: An optional CallCredentials for the RPC.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
An object that is both a Call for the RPC and an iterator of
@ -1097,6 +1124,17 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
"""
raise NotImplementedError()
def set_compression(self, compression):
"""Set the compression algorithm to be used for the entire call.
This is an EXPERIMENTAL method.
Args:
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip.
"""
raise NotImplementedError()
@abc.abstractmethod
def send_initial_metadata(self, initial_metadata):
"""Sends the initial metadata value to the client.
@ -1184,6 +1222,16 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
"""
raise NotImplementedError()
def disable_next_message_compression(self):
"""Disables compression for the next response message.
This is an EXPERIMENTAL method.
This method will override any compression configuration set during
server creation or set on the call.
"""
raise NotImplementedError()
##################### Service-Side Handler Interfaces ########################
@ -1682,7 +1730,7 @@ def channel_ready_future(channel):
return _utilities.channel_ready_future(channel)
def insecure_channel(target, options=None):
def insecure_channel(target, options=None, compression=None):
"""Creates an insecure Channel to a server.
The returned Channel is thread-safe.
@ -1691,15 +1739,18 @@ def insecure_channel(target, options=None):
target: The server address
options: An optional list of key-value pairs (channel args
in gRPC Core runtime) to configure the channel.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel. This is an EXPERIMENTAL option.
Returns:
A Channel.
"""
from grpc import _channel # pylint: disable=cyclic-import
return _channel.Channel(target, () if options is None else options, None)
return _channel.Channel(target, ()
if options is None else options, None, compression)
def secure_channel(target, credentials, options=None):
def secure_channel(target, credentials, options=None, compression=None):
"""Creates a secure Channel to a server.
The returned Channel is thread-safe.
@ -1709,13 +1760,15 @@ def secure_channel(target, credentials, options=None):
credentials: A ChannelCredentials instance.
options: An optional list of key-value pairs (channel args
in gRPC Core runtime) to configure the channel.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel. This is an EXPERIMENTAL option.
Returns:
A Channel.
"""
from grpc import _channel # pylint: disable=cyclic-import
return _channel.Channel(target, () if options is None else options,
credentials._credentials)
credentials._credentials, compression)
def intercept_channel(channel, *interceptors):
@ -1750,7 +1803,8 @@ def server(thread_pool,
handlers=None,
interceptors=None,
options=None,
maximum_concurrent_rpcs=None):
maximum_concurrent_rpcs=None,
compression=None):
"""Creates a Server with which RPCs can be serviced.
Args:
@ -1768,6 +1822,9 @@ def server(thread_pool,
maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server
will service before returning RESOURCE_EXHAUSTED status, or None to
indicate no limit.
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This compression algorithm will be used for the
lifetime of the server unless overridden. This is an EXPERIMENTAL option.
Returns:
A Server object.
@ -1777,7 +1834,7 @@ def server(thread_pool,
if handlers is None else handlers, ()
if interceptors is None else interceptors, ()
if options is None else options,
maximum_concurrent_rpcs)
maximum_concurrent_rpcs, compression)
@contextlib.contextmanager
@ -1788,6 +1845,16 @@ def _create_servicer_context(rpc_event, state, request_deserializer):
context._finalize_state() # pylint: disable=protected-access
class Compression(enum.IntEnum):
"""Indicates the compression method to be used for an RPC.
This enumeration is part of an EXPERIMENTAL API.
"""
NoCompression = _compression.NoCompression
Deflate = _compression.Deflate
Gzip = _compression.Gzip
################################### __all__ #################################
__all__ = (
@ -1805,6 +1872,7 @@ __all__ = (
'AuthMetadataContext',
'AuthMetadataPluginCallback',
'AuthMetadataPlugin',
'Compression',
'ClientCallDetails',
'ServerCertificateConfiguration',
'ServerCredentials',

@ -19,6 +19,7 @@ import threading
import time
import grpc
from grpc import _compression
from grpc import _common
from grpc import _grpcio_metadata
from grpc._cython import cygrpc
@ -512,17 +513,19 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
def _prepare(self, request, timeout, metadata, wait_for_ready):
def _prepare(self, request, timeout, metadata, wait_for_ready, compression):
deadline, serialized_request, rendezvous = _start_unary_request(
request, timeout, self._request_serializer)
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
wait_for_ready)
augmented_metadata = _compression.augment_metadata(
metadata, compression)
if serialized_request is None:
return None, None, None, rendezvous
else:
state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None)
operations = (
cygrpc.SendInitialMetadataOperation(metadata,
cygrpc.SendInitialMetadataOperation(augmented_metadata,
initial_metadata_flags),
cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
@ -532,18 +535,17 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
)
return state, operations, deadline, None
def _blocking(self, request, timeout, metadata, credentials,
wait_for_ready):
def _blocking(self, request, timeout, metadata, credentials, wait_for_ready,
compression):
state, operations, deadline, rendezvous = self._prepare(
request, timeout, metadata, wait_for_ready)
request, timeout, metadata, wait_for_ready, compression)
if state is None:
raise rendezvous # pylint: disable-msg=raising-bad-type
else:
deadline_to_propagate = _determine_deadline(deadline)
call = self._channel.segregated_call(
cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
self._method, None, deadline_to_propagate, metadata, None
if credentials is None else credentials._credentials, ((
self._method, None, _determine_deadline(deadline), metadata,
None if credentials is None else credentials._credentials, ((
operations,
None,
),), self._context)
@ -556,9 +558,10 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
state, call, = self._blocking(request, timeout, metadata, credentials,
wait_for_ready)
wait_for_ready, compression)
return _end_unary_response_blocking(state, call, False, None)
def with_call(self,
@ -566,9 +569,10 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
state, call, = self._blocking(request, timeout, metadata, credentials,
wait_for_ready)
wait_for_ready, compression)
return _end_unary_response_blocking(state, call, True, None)
def future(self,
@ -576,9 +580,10 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
state, operations, deadline, rendezvous = self._prepare(
request, timeout, metadata, wait_for_ready)
request, timeout, metadata, wait_for_ready, compression)
if state is None:
raise rendezvous # pylint: disable-msg=raising-bad-type
else:
@ -604,12 +609,14 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
def __call__(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
def __call__( # pylint: disable=too-many-locals
self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
deadline, serialized_request, rendezvous = _start_unary_request(
request, timeout, self._request_serializer)
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
@ -617,10 +624,12 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
if serialized_request is None:
raise rendezvous # pylint: disable-msg=raising-bad-type
else:
augmented_metadata = _compression.augment_metadata(
metadata, compression)
state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None)
operationses = (
(
cygrpc.SendInitialMetadataOperation(metadata,
cygrpc.SendInitialMetadataOperation(augmented_metadata,
initial_metadata_flags),
cygrpc.SendMessageOperation(serialized_request,
_EMPTY_FLAGS),
@ -629,12 +638,13 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
),
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
)
event_handler = _event_handler(state, self._response_deserializer)
call = self._managed_call(
cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
self._method, None, _determine_deadline(deadline), metadata,
None if credentials is None else credentials._credentials,
operationses, event_handler, self._context)
None if credentials is None else
credentials._credentials, operationses,
_event_handler(state,
self._response_deserializer), self._context)
return _Rendezvous(state, call, self._response_deserializer,
deadline)
@ -652,18 +662,19 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
self._context = cygrpc.build_census_context()
def _blocking(self, request_iterator, timeout, metadata, credentials,
wait_for_ready):
wait_for_ready, compression):
deadline = _deadline(timeout)
state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
wait_for_ready)
deadline_to_propagate = _determine_deadline(deadline)
augmented_metadata = _compression.augment_metadata(
metadata, compression)
call = self._channel.segregated_call(
cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method,
None, deadline_to_propagate, metadata, None
None, _determine_deadline(deadline), augmented_metadata, None
if credentials is None else credentials._credentials,
_stream_unary_invocation_operationses_and_tags(
metadata, initial_metadata_flags), self._context)
augmented_metadata, initial_metadata_flags), self._context)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer, None)
while True:
@ -680,9 +691,10 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
state, call, = self._blocking(request_iterator, timeout, metadata,
credentials, wait_for_ready)
credentials, wait_for_ready, compression)
return _end_unary_response_blocking(state, call, False, None)
def with_call(self,
@ -690,9 +702,10 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
state, call, = self._blocking(request_iterator, timeout, metadata,
credentials, wait_for_ready)
credentials, wait_for_ready, compression)
return _end_unary_response_blocking(state, call, True, None)
def future(self,
@ -700,15 +713,18 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
deadline = _deadline(timeout)
state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
event_handler = _event_handler(state, self._response_deserializer)
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
wait_for_ready)
augmented_metadata = _compression.augment_metadata(
metadata, compression)
call = self._managed_call(
cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method,
None, deadline, metadata, None
None, deadline, augmented_metadata, None
if credentials is None else credentials._credentials,
_stream_unary_invocation_operationses(
metadata, initial_metadata_flags), event_handler, self._context)
@ -734,24 +750,26 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
deadline = _deadline(timeout)
state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None)
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
wait_for_ready)
augmented_metadata = _compression.augment_metadata(
metadata, compression)
operationses = (
(
cygrpc.SendInitialMetadataOperation(metadata,
cygrpc.SendInitialMetadataOperation(augmented_metadata,
initial_metadata_flags),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
),
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
)
event_handler = _event_handler(state, self._response_deserializer)
deadline_to_propagate = _determine_deadline(deadline)
call = self._managed_call(
cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method,
None, deadline_to_propagate, metadata, None
None, _determine_deadline(deadline), augmented_metadata, None
if credentials is None else credentials._credentials, operationses,
event_handler, self._context)
_consume_request_iterator(request_iterator, state, call,
@ -982,28 +1000,30 @@ def _unsubscribe(state, callback):
break
def _options(options):
return list(options) + [
(
cygrpc.ChannelArgKey.primary_user_agent_string,
_USER_AGENT,
),
]
def _augment_options(base_options, compression):
compression_option = _compression.create_channel_option(compression)
return tuple(base_options) + compression_option + ((
cygrpc.ChannelArgKey.primary_user_agent_string,
_USER_AGENT,
),)
class Channel(grpc.Channel):
"""A cygrpc.Channel-backed implementation of grpc.Channel."""
def __init__(self, target, options, credentials):
def __init__(self, target, options, credentials, compression):
"""Constructor.
Args:
target: The target to which to connect.
options: Configuration options for the channel.
credentials: A cygrpc.ChannelCredentials or None.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel.
"""
self._channel = cygrpc.Channel(
_common.encode(target), _options(options), credentials)
_common.encode(target), _augment_options(options, compression),
credentials)
self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel)
cygrpc.fork_register_channel(self)

@ -0,0 +1,55 @@
# Copyright 2019 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from grpc._cython import cygrpc
NoCompression = cygrpc.CompressionAlgorithm.none
Deflate = cygrpc.CompressionAlgorithm.deflate
Gzip = cygrpc.CompressionAlgorithm.gzip
_METADATA_STRING_MAPPING = {
NoCompression: 'identity',
Deflate: 'deflate',
Gzip: 'gzip',
}
def _compression_algorithm_to_metadata_value(compression):
return _METADATA_STRING_MAPPING[compression]
def compression_algorithm_to_metadata(compression):
return (cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY,
_compression_algorithm_to_metadata_value(compression))
def create_channel_option(compression):
return ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM,
int(compression)),) if compression else ()
def augment_metadata(metadata, compression):
if not metadata and not compression:
return None
base_metadata = tuple(metadata) if metadata else ()
compression_metadata = (
compression_algorithm_to_metadata(compression),) if compression else ()
return base_metadata + compression_metadata
__all__ = (
"NoCompression",
"Deflate",
"Gzip",
)

@ -140,7 +140,8 @@ cdef extern from "grpc/grpc.h":
const char *GRPC_ARG_SECONDARY_USER_AGENT_STRING
const char *GRPC_SSL_TARGET_NAME_OVERRIDE_ARG
const char *GRPC_SSL_SESSION_CACHE_ARG
const char *GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM
const char *_GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM \
"GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM"
const char *GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL
const char *GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET
@ -618,3 +619,8 @@ cdef extern from "grpc/compression.h":
int grpc_compression_options_is_algorithm_enabled(
const grpc_compression_options *opts,
grpc_compression_algorithm algorithm) nogil
cdef extern from "grpc/impl/codegen/compression_types.h":
const char *_GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY \
"GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY"

@ -108,6 +108,11 @@ class OperationType:
receive_status_on_client = GRPC_OP_RECV_STATUS_ON_CLIENT
receive_close_on_server = GRPC_OP_RECV_CLOSE_ON_SERVER
GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM= (
_GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM)
GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY = (
_GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY)
class CompressionAlgorithm:
none = GRPC_COMPRESS_NONE

@ -44,9 +44,9 @@ def service_pipeline(interceptors):
class _ClientCallDetails(
collections.namedtuple(
'_ClientCallDetails',
('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')),
collections.namedtuple('_ClientCallDetails',
('method', 'timeout', 'metadata', 'credentials',
'wait_for_ready', 'compression')),
grpc.ClientCallDetails):
pass
@ -77,7 +77,12 @@ def _unwrap_client_call_details(call_details, default_details):
except AttributeError:
wait_for_ready = default_details.wait_for_ready
return method, timeout, metadata, credentials, wait_for_ready
try:
compression = call_details.compression
except AttributeError:
compression = default_details.compression
return method, timeout, metadata, credentials, wait_for_ready, compression
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors
@ -206,13 +211,15 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
response, ignored_call = self._with_call(
request,
timeout=timeout,
metadata=metadata,
credentials=credentials,
wait_for_ready=wait_for_ready)
wait_for_ready=wait_for_ready,
compression=compression)
return response
def _with_call(self,
@ -220,20 +227,25 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
client_call_details = _ClientCallDetails(
self._method, timeout, metadata, credentials, wait_for_ready)
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = (
_unwrap_client_call_details(new_details, client_call_details))
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
try:
response, call = self._thunk(new_method).with_call(
request,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready)
wait_for_ready=new_wait_for_ready,
compression=new_compression)
return _UnaryOutcome(response, call)
except grpc.RpcError as rpc_error:
return rpc_error
@ -249,32 +261,39 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
return self._with_call(
request,
timeout=timeout,
metadata=metadata,
credentials=credentials,
wait_for_ready=wait_for_ready)
wait_for_ready=wait_for_ready,
compression=compression)
def future(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
client_call_details = _ClientCallDetails(
self._method, timeout, metadata, credentials, wait_for_ready)
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = (
_unwrap_client_call_details(new_details, client_call_details))
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method).future(
request,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready)
wait_for_ready=new_wait_for_ready,
compression=new_compression)
try:
return self._interceptor.intercept_unary_unary(
@ -295,19 +314,24 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
client_call_details = _ClientCallDetails(
self._method, timeout, metadata, credentials, wait_for_ready)
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = (
_unwrap_client_call_details(new_details, client_call_details))
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method)(
request,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready)
wait_for_ready=new_wait_for_ready,
compression=new_compression)
try:
return self._interceptor.intercept_unary_stream(
@ -328,13 +352,15 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
response, ignored_call = self._with_call(
request_iterator,
timeout=timeout,
metadata=metadata,
credentials=credentials,
wait_for_ready=wait_for_ready)
wait_for_ready=wait_for_ready,
compression=compression)
return response
def _with_call(self,
@ -342,20 +368,25 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
client_call_details = _ClientCallDetails(
self._method, timeout, metadata, credentials, wait_for_ready)
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request_iterator):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = (
_unwrap_client_call_details(new_details, client_call_details))
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
try:
response, call = self._thunk(new_method).with_call(
request_iterator,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready)
wait_for_ready=new_wait_for_ready,
compression=new_compression)
return _UnaryOutcome(response, call)
except grpc.RpcError as rpc_error:
return rpc_error
@ -371,32 +402,39 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
wait_for_ready=None,
compression=None):
return self._with_call(
request_iterator,
timeout=timeout,
metadata=metadata,
credentials=credentials,
wait_for_ready=wait_for_ready)
wait_for_ready=wait_for_ready,
compression=compression)
def future(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
client_call_details = _ClientCallDetails(
self._method, timeout, metadata, credentials, wait_for_ready)
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request_iterator):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = (
_unwrap_client_call_details(new_details, client_call_details))
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method).future(
request_iterator,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready)
wait_for_ready=new_wait_for_ready,
compression=new_compression)
try:
return self._interceptor.intercept_stream_unary(
@ -417,19 +455,24 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None):
client_call_details = _ClientCallDetails(
self._method, timeout, metadata, credentials, wait_for_ready)
wait_for_ready=None,
compression=None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
def continuation(new_details, request_iterator):
new_method, new_timeout, new_metadata, new_credentials, new_wait_for_ready = (
_unwrap_client_call_details(new_details, client_call_details))
(new_method, new_timeout, new_metadata, new_credentials,
new_wait_for_ready,
new_compression) = (_unwrap_client_call_details(
new_details, client_call_details))
return self._thunk(new_method)(
request_iterator,
timeout=new_timeout,
metadata=new_metadata,
credentials=new_credentials,
wait_for_ready=new_wait_for_ready)
wait_for_ready=new_wait_for_ready,
compression=new_compression)
try:
return self._interceptor.intercept_stream_stream(

@ -24,6 +24,7 @@ import six
import grpc
from grpc import _common
from grpc import _compression
from grpc import _interceptor
from grpc._cython import cygrpc
@ -94,6 +95,7 @@ class _RPCState(object):
self.request = None
self.client = _OPEN
self.initial_metadata_allowed = True
self.compression_algorithm = None
self.disable_next_compression = False
self.trailing_metadata = None
self.code = None
@ -129,13 +131,33 @@ def _send_status_from_server(state, token):
return send_status_from_server
def _get_initial_metadata(state, metadata):
with state.condition:
if state.compression_algorithm:
compression_metadata = (
_compression.compression_algorithm_to_metadata(
state.compression_algorithm),)
if metadata is None:
return compression_metadata
else:
return compression_metadata + tuple(metadata)
else:
return metadata
def _get_initial_metadata_operation(state, metadata):
operation = cygrpc.SendInitialMetadataOperation(
_get_initial_metadata(state, metadata), _EMPTY_FLAGS)
return operation
def _abort(state, call, code, details):
if state.client is not _CANCELLED:
effective_code = _abortion_code(state, code)
effective_details = details if state.details is None else state.details
if state.initial_metadata_allowed:
operations = (
cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS),
_get_initial_metadata_operation(state, None),
cygrpc.SendStatusFromServerOperation(
state.trailing_metadata, effective_code, effective_details,
_EMPTY_FLAGS),
@ -259,14 +281,18 @@ class _Context(grpc.ServicerContext):
cygrpc.auth_context(self._rpc_event.call))
}
def set_compression(self, compression):
with self._state.condition:
self._state.compression_algorithm = compression
def send_initial_metadata(self, initial_metadata):
with self._state.condition:
if self._state.client is _CANCELLED:
_raise_rpc_error(self._state)
else:
if self._state.initial_metadata_allowed:
operation = cygrpc.SendInitialMetadataOperation(
initial_metadata, _EMPTY_FLAGS)
operation = _get_initial_metadata_operation(
self._state, initial_metadata)
self._rpc_event.call.start_server_batch(
(operation,), _send_initial_metadata(self._state))
self._state.initial_metadata_allowed = False
@ -400,10 +426,13 @@ def _call_behavior(rpc_event,
with _create_servicer_context(rpc_event, state,
request_deserializer) as context:
try:
response_or_iterator = None
if send_response_callback is not None:
return behavior(argument, context, send_response_callback), True
response_or_iterator = behavior(argument, context,
send_response_callback)
else:
return behavior(argument, context), True
response_or_iterator = behavior(argument, context)
return response_or_iterator, True
except Exception as exception: # pylint: disable=broad-except
with state.condition:
if state.aborted:
@ -447,6 +476,18 @@ def _serialize_response(rpc_event, state, response, response_serializer):
return serialized_response
def _get_send_message_op_flags_from_state(state):
if state.disable_next_compression:
return cygrpc.WriteFlag.no_compress
else:
return _EMPTY_FLAGS
def _reset_per_message_state(state):
with state.condition:
state.disable_next_compression = False
def _send_response(rpc_event, state, serialized_response):
with state.condition:
if not _is_rpc_state_active(state):
@ -454,19 +495,22 @@ def _send_response(rpc_event, state, serialized_response):
else:
if state.initial_metadata_allowed:
operations = (
cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS),
cygrpc.SendMessageOperation(serialized_response,
_EMPTY_FLAGS),
_get_initial_metadata_operation(state, None),
cygrpc.SendMessageOperation(
serialized_response,
_get_send_message_op_flags_from_state(state)),
)
state.initial_metadata_allowed = False
token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
else:
operations = (cygrpc.SendMessageOperation(
serialized_response, _EMPTY_FLAGS),)
serialized_response,
_get_send_message_op_flags_from_state(state)),)
token = _SEND_MESSAGE_TOKEN
rpc_event.call.start_server_batch(operations,
_send_message(state, token))
state.due.add(token)
_reset_per_message_state(state)
while True:
state.condition.wait()
if token not in state.due:
@ -483,16 +527,17 @@ def _status(rpc_event, state, serialized_response):
state.trailing_metadata, code, details, _EMPTY_FLAGS),
]
if state.initial_metadata_allowed:
operations.append(
cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS))
operations.append(_get_initial_metadata_operation(state, None))
if serialized_response is not None:
operations.append(
cygrpc.SendMessageOperation(serialized_response,
_EMPTY_FLAGS))
cygrpc.SendMessageOperation(
serialized_response,
_get_send_message_op_flags_from_state(state)))
rpc_event.call.start_server_batch(
operations,
_send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
state.statused = True
_reset_per_message_state(state)
state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
@ -639,13 +684,13 @@ def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline):
def _reject_rpc(rpc_event, status, details):
rpc_state = _RPCState()
operations = (
cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS),
_get_initial_metadata_operation(rpc_state, None),
cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
cygrpc.SendStatusFromServerOperation(None, status, details,
_EMPTY_FLAGS),
)
rpc_state = _RPCState()
rpc_event.call.start_server_batch(operations,
lambda ignored_event: (rpc_state, (),))
return rpc_state
@ -883,13 +928,18 @@ def _validate_generic_rpc_handlers(generic_rpc_handlers):
'not have "service" method!'.format(generic_rpc_handler))
def _augment_options(base_options, compression):
compression_option = _compression.create_channel_option(compression)
return tuple(base_options) + compression_option
class _Server(grpc.Server):
# pylint: disable=too-many-arguments
def __init__(self, thread_pool, generic_handlers, interceptors, options,
maximum_concurrent_rpcs):
maximum_concurrent_rpcs, compression):
completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server(options)
server = cygrpc.Server(_augment_options(options, compression))
server.register_completion_queue(completion_queue)
self._state = _ServerState(completion_queue, server, generic_handlers,
_interceptor.service_pipeline(interceptors),
@ -920,7 +970,7 @@ class _Server(grpc.Server):
def create_server(thread_pool, generic_rpc_handlers, interceptors, options,
maximum_concurrent_rpcs):
maximum_concurrent_rpcs, compression):
_validate_generic_rpc_handlers(generic_rpc_handlers)
return _Server(thread_pool, generic_rpc_handlers, interceptors, options,
maximum_concurrent_rpcs)
maximum_concurrent_rpcs, compression)

@ -56,6 +56,9 @@ class ServicerContext(grpc.ServicerContext):
def auth_context(self):
raise NotImplementedError()
def set_compression(self):
raise NotImplementedError()
def send_initial_metadata(self, initial_metadata):
initial_metadata_sent = self._rpc.send_initial_metadata(
_common.fuss_with_metadata(initial_metadata))
@ -63,6 +66,9 @@ class ServicerContext(grpc.ServicerContext):
raise ValueError(
'ServicerContext.send_initial_metadata called too late!')
def disable_next_message_compression(self):
raise NotImplementedError()
def set_trailing_metadata(self, trailing_metadata):
self._rpc.set_trailing_metadata(
_common.fuss_with_metadata(trailing_metadata))

@ -117,6 +117,7 @@ class TestGevent(setuptools.Command):
# eventually succeed, but need to dig into performance issues.
'unit._cython._no_messages_server_completion_queue_per_call_test.Test.test_rpcs',
'unit._cython._no_messages_single_server_completion_queue_test.Test.test_rpcs',
'unit._compression_test',
# TODO(https://github.com/grpc/grpc/issues/16890) enable this test
'unit._cython._channel_test.ChannelTest.test_multiple_channels_lonely_connectivity',
# I have no idea why this doesn't work in gevent, but it shouldn't even be

@ -33,6 +33,11 @@ GRPCIO_TESTS_UNIT = [
"_session_cache_test.py",
]
py_library(
name = "_tcp_proxy",
srcs = ["_tcp_proxy.py"],
)
py_library(
name = "resources",
srcs = ["resources.py"],
@ -80,6 +85,7 @@ py_library(
":_exit_scenarios",
":_server_shutdown_scenarios",
":_from_grpc_import_star",
":_tcp_proxy",
"//src/python/grpcio_tests/tests/unit/framework/common",
"//src/python/grpcio_tests/tests/testing",
requirement('six'),

@ -31,6 +31,7 @@ class AllTest(unittest.TestCase):
'FutureCancelledError',
'Future',
'ChannelConnectivity',
'Compression',
'StatusCode',
'Status',
'RpcError',

@ -13,37 +13,130 @@
# limitations under the License.
"""Tests server and client side compression."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import contextlib
from concurrent import futures
import functools
import itertools
import logging
import os
import grpc
from grpc import _grpcio_metadata
from tests.unit import test_common
from tests.unit.framework.common import test_constants
from tests.unit import _tcp_proxy
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'
# Cut down on test time.
_STREAM_LENGTH = test_constants.STREAM_LENGTH // 8
_HOST = 'localhost'
_REQUEST = b'\x00' * 100
_COMPRESSION_RATIO_THRESHOLD = 0.1
_COMPRESSION_METHODS = (
None,
# Disabled for test tractability.
# grpc.Compression.NoCompression,
grpc.Compression.Deflate,
grpc.Compression.Gzip,
)
_COMPRESSION_NAMES = {
None: 'Uncompressed',
grpc.Compression.NoCompression: 'NoCompression',
grpc.Compression.Deflate: 'DeflateCompression',
grpc.Compression.Gzip: 'GzipCompression',
}
_TEST_OPTIONS = {
'client_streaming': (True, False),
'server_streaming': (True, False),
'channel_compression': _COMPRESSION_METHODS,
'multicallable_compression': _COMPRESSION_METHODS,
'server_compression': _COMPRESSION_METHODS,
'server_call_compression': _COMPRESSION_METHODS,
}
def _make_handle_unary_unary(pre_response_callback):
def _handle_unary(request, servicer_context):
if pre_response_callback:
pre_response_callback(request, servicer_context)
return request
return _handle_unary
def _make_handle_unary_stream(pre_response_callback):
def _handle_unary_stream(request, servicer_context):
if pre_response_callback:
pre_response_callback(request, servicer_context)
for _ in range(_STREAM_LENGTH):
yield request
return _handle_unary_stream
def _make_handle_stream_unary(pre_response_callback):
def _handle_stream_unary(request_iterator, servicer_context):
if pre_response_callback:
pre_response_callback(request_iterator, servicer_context)
response = None
for request in request_iterator:
if not response:
response = request
return response
def handle_unary(request, servicer_context):
servicer_context.send_initial_metadata([('grpc-internal-encoding-request',
'gzip')])
return request
return _handle_stream_unary
def handle_stream(request_iterator, servicer_context):
# TODO(issue:#6891) We should be able to remove this loop,
# and replace with return; yield
servicer_context.send_initial_metadata([('grpc-internal-encoding-request',
'gzip')])
for request in request_iterator:
yield request
def _make_handle_stream_stream(pre_response_callback):
def _handle_stream(request_iterator, servicer_context):
# TODO(issue:#6891) We should be able to remove this loop,
# and replace with return; yield
for request in request_iterator:
if pre_response_callback:
pre_response_callback(request, servicer_context)
yield request
return _handle_stream
def set_call_compression(compression_method, request_or_iterator,
servicer_context):
del request_or_iterator
servicer_context.set_compression(compression_method)
def disable_next_compression(request, servicer_context):
del request
servicer_context.disable_next_message_compression()
def disable_first_compression(request, servicer_context):
if int(request.decode('ascii')) == 0:
servicer_context.disable_next_message_compression()
class _MethodHandler(grpc.RpcMethodHandler):
def __init__(self, request_streaming, response_streaming):
def __init__(self, request_streaming, response_streaming,
pre_response_callback):
self.request_streaming = request_streaming
self.response_streaming = response_streaming
self.request_deserializer = None
@ -52,75 +145,239 @@ class _MethodHandler(grpc.RpcMethodHandler):
self.unary_stream = None
self.stream_unary = None
self.stream_stream = None
if self.request_streaming and self.response_streaming:
self.stream_stream = handle_stream
self.stream_stream = _make_handle_stream_stream(
pre_response_callback)
elif not self.request_streaming and not self.response_streaming:
self.unary_unary = handle_unary
self.unary_unary = _make_handle_unary_unary(pre_response_callback)
elif not self.request_streaming and self.response_streaming:
self.unary_stream = _make_handle_unary_stream(pre_response_callback)
else:
self.stream_unary = _make_handle_stream_unary(pre_response_callback)
class _GenericHandler(grpc.GenericRpcHandler):
def __init__(self, pre_response_callback):
self._pre_response_callback = pre_response_callback
def service(self, handler_call_details):
if handler_call_details.method == _UNARY_UNARY:
return _MethodHandler(False, False)
return _MethodHandler(False, False, self._pre_response_callback)
elif handler_call_details.method == _UNARY_STREAM:
return _MethodHandler(False, True, self._pre_response_callback)
elif handler_call_details.method == _STREAM_UNARY:
return _MethodHandler(True, False, self._pre_response_callback)
elif handler_call_details.method == _STREAM_STREAM:
return _MethodHandler(True, True)
return _MethodHandler(True, True, self._pre_response_callback)
else:
return None
@contextlib.contextmanager
def _instrumented_client_server_pair(channel_kwargs, server_kwargs,
server_handler):
server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs)
server.add_generic_rpc_handlers((server_handler,))
server_port = server.add_insecure_port('{}:0'.format(_HOST))
server.start()
with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy:
proxy_port = proxy.get_port()
with grpc.insecure_channel('{}:{}'.format(_HOST, proxy_port),
**channel_kwargs) as client_channel:
try:
yield client_channel, proxy, server
finally:
server.stop(None)
def _get_byte_counts(channel_kwargs, multicallable_kwargs, client_function,
server_kwargs, server_handler, message):
with _instrumented_client_server_pair(channel_kwargs, server_kwargs,
server_handler) as pipeline:
client_channel, proxy, server = pipeline
client_function(client_channel, multicallable_kwargs, message)
return proxy.get_byte_count()
def _get_compression_ratios(client_function, first_channel_kwargs,
first_multicallable_kwargs, first_server_kwargs,
first_server_handler, second_channel_kwargs,
second_multicallable_kwargs, second_server_kwargs,
second_server_handler, message):
try:
# This test requires the byte length of each connection to be deterministic. As
# it turns out, flow control puts bytes on the wire in a nondeterministic
# manner. We disable it here in order to measure compression ratios
# deterministically.
os.environ['GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL'] = 'true'
first_bytes_sent, first_bytes_received = _get_byte_counts(
first_channel_kwargs, first_multicallable_kwargs, client_function,
first_server_kwargs, first_server_handler, message)
second_bytes_sent, second_bytes_received = _get_byte_counts(
second_channel_kwargs, second_multicallable_kwargs, client_function,
second_server_kwargs, second_server_handler, message)
return ((
second_bytes_sent - first_bytes_sent) / float(first_bytes_sent),
(second_bytes_received - first_bytes_received) /
float(first_bytes_received))
finally:
del os.environ['GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL']
def _unary_unary_client(channel, multicallable_kwargs, message):
multi_callable = channel.unary_unary(_UNARY_UNARY)
response = multi_callable(message, **multicallable_kwargs)
if response != message:
raise RuntimeError("Request '{}' != Response '{}'".format(
message, response))
def _unary_stream_client(channel, multicallable_kwargs, message):
multi_callable = channel.unary_stream(_UNARY_STREAM)
response_iterator = multi_callable(message, **multicallable_kwargs)
for response in response_iterator:
if response != message:
raise RuntimeError("Request '{}' != Response '{}'".format(
message, response))
def _stream_unary_client(channel, multicallable_kwargs, message):
multi_callable = channel.stream_unary(_STREAM_UNARY)
requests = (_REQUEST for _ in range(_STREAM_LENGTH))
response = multi_callable(requests, **multicallable_kwargs)
if response != message:
raise RuntimeError("Request '{}' != Response '{}'".format(
message, response))
def _stream_stream_client(channel, multicallable_kwargs, message):
multi_callable = channel.stream_stream(_STREAM_STREAM)
request_prefix = str(0).encode('ascii') * 100
requests = (
request_prefix + str(i).encode('ascii') for i in range(_STREAM_LENGTH))
response_iterator = multi_callable(requests, **multicallable_kwargs)
for i, response in enumerate(response_iterator):
if int(response.decode('ascii')) != i:
raise RuntimeError("Request '{}' != Response '{}'".format(
i, response))
class CompressionTest(unittest.TestCase):
def setUp(self):
self._server = test_common.test_server()
self._server.add_generic_rpc_handlers((_GenericHandler(),))
self._port = self._server.add_insecure_port('[::]:0')
self._server.start()
def tearDown(self):
self._server.stop(None)
def testUnary(self):
request = b'\x00' * 100
# Client -> server compressed through default client channel compression
# settings. Server -> client compressed via server-side metadata setting.
# TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer
# literal with proper use of the public API.
compressed_channel = grpc.insecure_channel(
'localhost:%d' % self._port,
options=[('grpc.default_compression_algorithm', 1)])
multi_callable = compressed_channel.unary_unary(_UNARY_UNARY)
response = multi_callable(request)
self.assertEqual(request, response)
# Client -> server compressed through client metadata setting. Server ->
# client compressed via server-side metadata setting.
# TODO(https://github.com/grpc/grpc/issues/4078): replace the "0" integer
# literal with proper use of the public API.
uncompressed_channel = grpc.insecure_channel(
'localhost:%d' % self._port,
options=[('grpc.default_compression_algorithm', 0)])
multi_callable = compressed_channel.unary_unary(_UNARY_UNARY)
response = multi_callable(
request, metadata=[('grpc-internal-encoding-request', 'gzip')])
self.assertEqual(request, response)
compressed_channel.close()
def testStreaming(self):
request = b'\x00' * 100
# TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer
# literal with proper use of the public API.
compressed_channel = grpc.insecure_channel(
'localhost:%d' % self._port,
options=[('grpc.default_compression_algorithm', 1)])
multi_callable = compressed_channel.stream_stream(_STREAM_STREAM)
call = multi_callable(iter([request] * test_constants.STREAM_LENGTH))
for response in call:
self.assertEqual(request, response)
compressed_channel.close()
def assertCompressed(self, compression_ratio):
self.assertLess(
compression_ratio,
-1.0 * _COMPRESSION_RATIO_THRESHOLD,
msg='Actual compression ratio: {}'.format(compression_ratio))
def assertNotCompressed(self, compression_ratio):
self.assertGreaterEqual(
compression_ratio,
-1.0 * _COMPRESSION_RATIO_THRESHOLD,
msg='Actual compession ratio: {}'.format(compression_ratio))
def assertConfigurationCompressed(
self, client_streaming, server_streaming, channel_compression,
multicallable_compression, server_compression,
server_call_compression):
client_side_compressed = channel_compression or multicallable_compression
server_side_compressed = server_compression or server_call_compression
channel_kwargs = {
'compression': channel_compression,
} if channel_compression else {}
multicallable_kwargs = {
'compression': multicallable_compression,
} if multicallable_compression else {}
client_function = None
if not client_streaming and not server_streaming:
client_function = _unary_unary_client
elif not client_streaming and server_streaming:
client_function = _unary_stream_client
elif client_streaming and not server_streaming:
client_function = _stream_unary_client
else:
client_function = _stream_stream_client
server_kwargs = {
'compression': server_compression,
} if server_compression else {}
server_handler = _GenericHandler(
functools.partial(set_call_compression, grpc.Compression.Gzip)
) if server_call_compression else _GenericHandler(None)
sent_ratio, received_ratio = _get_compression_ratios(
client_function, {}, {}, {}, _GenericHandler(None), channel_kwargs,
multicallable_kwargs, server_kwargs, server_handler, _REQUEST)
if client_side_compressed:
self.assertCompressed(sent_ratio)
else:
self.assertNotCompressed(sent_ratio)
if server_side_compressed:
self.assertCompressed(received_ratio)
else:
self.assertNotCompressed(received_ratio)
def testDisableNextCompressionStreaming(self):
server_kwargs = {
'compression': grpc.Compression.Deflate,
}
_, received_ratio = _get_compression_ratios(
_stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
server_kwargs, _GenericHandler(disable_next_compression), _REQUEST)
self.assertNotCompressed(received_ratio)
def testDisableNextCompressionStreamingResets(self):
server_kwargs = {
'compression': grpc.Compression.Deflate,
}
_, received_ratio = _get_compression_ratios(
_stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
server_kwargs, _GenericHandler(disable_first_compression), _REQUEST)
self.assertCompressed(received_ratio)
def _get_compression_str(name, value):
return '{}{}'.format(name, _COMPRESSION_NAMES[value])
def _get_compression_test_name(client_streaming, server_streaming,
channel_compression, multicallable_compression,
server_compression, server_call_compression):
client_arity = 'Stream' if client_streaming else 'Unary'
server_arity = 'Stream' if server_streaming else 'Unary'
arity = '{}{}'.format(client_arity, server_arity)
channel_compression_str = _get_compression_str('Channel',
channel_compression)
multicallable_compression_str = _get_compression_str(
'Multicallable', multicallable_compression)
server_compression_str = _get_compression_str('Server', server_compression)
server_call_compression_str = _get_compression_str('ServerCall',
server_call_compression)
return 'test{}{}{}{}{}'.format(
arity, channel_compression_str, multicallable_compression_str,
server_compression_str, server_call_compression_str)
def _test_options():
for test_parameters in itertools.product(*_TEST_OPTIONS.values()):
yield dict(zip(_TEST_OPTIONS.keys(), test_parameters))
for options in _test_options():
def test_compression(**kwargs):
def _test_compression(self):
self.assertConfigurationCompressed(**kwargs)
return _test_compression
setattr(CompressionTest, _get_compression_test_name(**options),
test_compression(**options))
if __name__ == '__main__':
logging.basicConfig()

@ -0,0 +1,164 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Proxies a TCP connection between a single client-server pair.
This proxy is not suitable for production, but should work well for cases in
which a test needs to spy on the bytes put on the wire between a server and
a client.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime
import select
import socket
import threading
_TCP_PROXY_BUFFER_SIZE = 1024
_TCP_PROXY_TIMEOUT = datetime.timedelta(milliseconds=500)
def _create_socket_ipv6(bind_address):
listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
listen_socket.bind((bind_address, 0, 0, 0))
return listen_socket
def _create_socket_ipv4(bind_address):
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_socket.bind((bind_address, 0))
return listen_socket
def _init_listen_socket(bind_address):
listen_socket = None
if socket.has_ipv6:
try:
listen_socket = _create_socket_ipv6(bind_address)
except socket.error:
listen_socket = _create_socket_ipv4(bind_address)
else:
listen_socket = _create_socket_ipv4(bind_address)
listen_socket.listen(1)
return listen_socket, listen_socket.getsockname()[1]
def _init_proxy_socket(gateway_address, gateway_port):
proxy_socket = socket.create_connection((gateway_address, gateway_port))
return proxy_socket
class TcpProxy(object):
"""Proxies a TCP connection between one client and one server."""
def __init__(self, bind_address, gateway_address, gateway_port):
self._bind_address = bind_address
self._gateway_address = gateway_address
self._gateway_port = gateway_port
self._byte_count_lock = threading.RLock()
self._sent_byte_count = 0
self._received_byte_count = 0
self._stop_event = threading.Event()
self._port = None
self._listen_socket = None
self._proxy_socket = None
# The following three attributes are owned by the serving thread.
self._northbound_data = b""
self._southbound_data = b""
self._client_sockets = []
self._thread = threading.Thread(target=self._run_proxy)
def start(self):
self._listen_socket, self._port = _init_listen_socket(
self._bind_address)
self._proxy_socket = _init_proxy_socket(self._gateway_address,
self._gateway_port)
self._thread.start()
def get_port(self):
return self._port
def _handle_reads(self, sockets_to_read):
for socket_to_read in sockets_to_read:
if socket_to_read is self._listen_socket:
client_socket, client_address = socket_to_read.accept()
self._client_sockets.append(client_socket)
elif socket_to_read is self._proxy_socket:
data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE)
with self._byte_count_lock:
self._received_byte_count += len(data)
self._northbound_data += data
elif socket_to_read in self._client_sockets:
data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE)
if data:
with self._byte_count_lock:
self._sent_byte_count += len(data)
self._southbound_data += data
else:
self._client_sockets.remove(socket_to_read)
else:
raise RuntimeError('Unidentified socket appeared in read set.')
def _handle_writes(self, sockets_to_write):
for socket_to_write in sockets_to_write:
if socket_to_write is self._proxy_socket:
if self._southbound_data:
self._proxy_socket.sendall(self._southbound_data)
self._southbound_data = b""
elif socket_to_write in self._client_sockets:
if self._northbound_data:
socket_to_write.sendall(self._northbound_data)
self._northbound_data = b""
def _run_proxy(self):
while not self._stop_event.is_set():
expected_reads = (self._listen_socket, self._proxy_socket) + tuple(
self._client_sockets)
expected_writes = expected_reads
sockets_to_read, sockets_to_write, _ = select.select(
expected_reads, expected_writes, (),
_TCP_PROXY_TIMEOUT.total_seconds())
self._handle_reads(sockets_to_read)
self._handle_writes(sockets_to_write)
for client_socket in self._client_sockets:
client_socket.close()
def stop(self):
self._stop_event.set()
self._thread.join()
self._listen_socket.close()
self._proxy_socket.close()
def get_byte_count(self):
with self._byte_count_lock:
return self._sent_byte_count, self._received_byte_count
def reset_byte_count(self):
with self._byte_count_lock:
self._byte_count = 0
self._received_byte_count = 0
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
Loading…
Cancel
Save