Support compression for both client and server

pull/21809/head
Lidi Zheng 5 years ago
parent ffb41a2368
commit cd76b79e7f
  1. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 15
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  3. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  4. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  5. 63
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  6. 51
      src/python/grpcio/grpc/experimental/aio/_channel.py
  7. 47
      src/python/grpcio/grpc/experimental/aio/_server.py
  8. 1
      src/python/grpcio_tests/tests_aio/tests.json
  9. 174
      src/python/grpcio_tests/tests_aio/unit/compression_test.py

@ -367,7 +367,8 @@ cdef class _AioCall(GrpcCallWrapper):
"""Sends one single raw message in bytes."""
await _send_message(self,
message,
True,
None,
False,
self._loop)
async def send_receive_close(self):

@ -153,12 +153,13 @@ async def _receive_message(GrpcCallWrapper grpc_call_wrapper,
async def _send_message(GrpcCallWrapper grpc_call_wrapper,
bytes message,
bint metadata_sent,
Operation send_initial_metadata_op,
int write_flag,
object loop):
cdef SendMessageOperation op = SendMessageOperation(message, _EMPTY_FLAG)
cdef SendMessageOperation op = SendMessageOperation(message, write_flag)
cdef tuple ops = (op,)
if not metadata_sent:
ops = prepend_send_initial_metadata_op(ops, None)
if send_initial_metadata_op is not None:
ops = (send_initial_metadata_op,) + ops
await execute_batch(grpc_call_wrapper, ops, loop)
@ -184,7 +185,7 @@ async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper,
grpc_status_code code,
str details,
tuple trailing_metadata,
bint metadata_sent,
Operation send_initial_metadata_op,
object loop):
assert code != StatusCode.ok, 'Expecting non-ok status code.'
cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
@ -194,6 +195,6 @@ async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper,
_EMPTY_FLAGS,
)
cdef tuple ops = (op,)
if not metadata_sent:
ops = prepend_send_initial_metadata_op(ops, None)
if send_initial_metadata_op is not None:
ops = (send_initial_metadata_op,) + ops
await execute_batch(grpc_call_wrapper, ops, loop)

@ -67,3 +67,9 @@ class _EOF:
EOF = _EOF()
_COMPRESSION_METADATA_STRING_MAPPING = {
CompressionAlgorithm.none: 'identity',
CompressionAlgorithm.deflate: 'deflate',
CompressionAlgorithm.gzip: 'gzip',
}

@ -31,10 +31,14 @@ cdef class RPCState(GrpcCallWrapper):
cdef grpc_status_code status_code
cdef str status_details
cdef tuple trailing_metadata
cdef object compression_algorithm
cdef bint disable_next_compression
cdef bytes method(self)
cdef tuple invocation_metadata(self)
cdef void raise_for_termination(self) except *
cdef int get_write_flag(self)
cdef Operation create_send_initial_metadata_op_if_not_sent(self)
cdef enum AioServerStatus:

@ -21,6 +21,16 @@ cdef int _EMPTY_FLAG = 0
cdef str _RPC_FINISHED_DETAILS = 'RPC already finished.'
cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.'
cdef _augment_metadata(tuple metadata, object compression):
if compression is None:
return metadata
else:
return ((
GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY,
_COMPRESSION_METADATA_STRING_MAPPING[compression]
),) + metadata
cdef class _HandlerCallDetails:
def __cinit__(self, str method, tuple invocation_metadata):
self.method = method
@ -45,6 +55,8 @@ cdef class RPCState:
self.status_code = StatusCode.ok
self.status_details = ''
self.trailing_metadata = _IMMUTABLE_EMPTY_METADATA
self.compression_algorithm = None
self.disable_next_compression = False
cdef bytes method(self):
return _slice_bytes(self.details.method)
@ -69,6 +81,24 @@ cdef class RPCState:
if self.server._status == AIO_SERVER_STATUS_STOPPED:
raise _ServerStoppedError(_SERVER_STOPPED_DETAILS)
cdef int get_write_flag(self):
if self.disable_next_compression:
self.disable_next_compression = False
return WriteFlag.no_compress
else:
return _EMPTY_FLAG
cdef Operation create_send_initial_metadata_op_if_not_sent(self):
if self.metadata_sent:
return None
cdef SendInitialMetadataOperation op = SendInitialMetadataOperation(
_augment_metadata(_IMMUTABLE_EMPTY_METADATA, self.compression_algorithm),
_EMPTY_FLAG
)
self.metadata_sent = True
return op
def __dealloc__(self):
"""Cleans the Core objects."""
grpc_call_details_destroy(&self.details)
@ -116,10 +146,9 @@ cdef class _ServicerContext:
await _send_message(self._rpc_state,
serialize(self._response_serializer, message),
self._rpc_state.metadata_sent,
self._rpc_state.create_send_initial_metadata_op_if_not_sent(),
self._rpc_state.get_write_flag(),
self._loop)
if not self._rpc_state.metadata_sent:
self._rpc_state.metadata_sent = True
async def send_initial_metadata(self, tuple metadata):
self._rpc_state.raise_for_termination()
@ -127,7 +156,12 @@ cdef class _ServicerContext:
if self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent')
else:
await _send_initial_metadata(self._rpc_state, metadata, _EMPTY_FLAG, self._loop)
await _send_initial_metadata(
self._rpc_state,
_augment_metadata(metadata, self._rpc_state.compression_algorithm),
_EMPTY_FLAG,
self._loop
)
self._rpc_state.metadata_sent = True
async def abort(self,
@ -156,7 +190,7 @@ cdef class _ServicerContext:
actual_code,
details,
trailing_metadata,
self._rpc_state.metadata_sent,
self._rpc_state.create_send_initial_metadata_op_if_not_sent(),
self._loop
)
@ -174,6 +208,15 @@ cdef class _ServicerContext:
def set_details(self, str details):
self._rpc_state.status_details = details
def set_compression(self, object compression):
if self._rpc_state.metadata_sent:
raise RuntimeError('Compression setting must be specified before sending initial metadata')
else:
self._rpc_state.compression_algorithm = compression
def disable_next_message_compression(self):
self._rpc_state.disable_next_compression = True
cdef _find_method_handler(str method, tuple metadata, list generic_handlers):
cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method,
@ -217,7 +260,7 @@ async def _finish_handler_with_unary_response(RPCState rpc_state,
# Assembles the batch operations
cdef tuple finish_ops
finish_ops = (
SendMessageOperation(response_raw, _EMPTY_FLAGS),
SendMessageOperation(response_raw, rpc_state.get_write_flag()),
SendStatusFromServerOperation(
rpc_state.trailing_metadata,
rpc_state.status_code,
@ -446,7 +489,7 @@ async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
status_code,
'Unexpected %s: %s' % (type(e), e),
rpc_state.trailing_metadata,
rpc_state.metadata_sent,
rpc_state.create_send_initial_metadata_op_if_not_sent(),
loop
)
@ -492,7 +535,7 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
StatusCode.unimplemented,
'Method not found!',
_IMMUTABLE_EMPTY_METADATA,
rpc_state.metadata_sent,
rpc_state.create_send_initial_metadata_op_if_not_sent(),
loop
)
return
@ -541,7 +584,7 @@ cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHan
cdef class AioServer:
def __init__(self, loop, thread_pool, generic_handlers, interceptors,
options, maximum_concurrent_rpcs, compression):
options, maximum_concurrent_rpcs):
# NOTE(lidiz) Core objects won't be deallocated automatically.
# If AioServer.shutdown is not called, those objects will leak.
self._server = Server(options)
@ -570,8 +613,6 @@ cdef class AioServer:
raise NotImplementedError()
if maximum_concurrent_rpcs:
raise NotImplementedError()
if compression:
raise NotImplementedError()
if thread_pool:
raise NotImplementedError()

@ -20,6 +20,8 @@ import logging
import grpc
from grpc import _common
from grpc._cython import cygrpc
from grpc import _compression
from grpc import _grpcio_metadata
from . import _base_call
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
@ -31,6 +33,19 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple()
_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)
def _augment_channel_arguments(base_options: ChannelArgumentType,
compression: Optional[grpc.Compression]):
compression_channel_argument = _compression.create_channel_option(
compression)
user_agent_channel_argument = ((
cygrpc.ChannelArgKey.primary_user_agent_string,
_USER_AGENT,
),)
return tuple(base_options
) + compression_channel_argument + user_agent_channel_argument
_LOGGER = logging.getLogger(__name__)
@ -135,12 +150,12 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
if compression:
metadata = _compression.augment_metadata(metadata, compression)
if not self._interceptors:
call = UnaryUnaryCall(request, _timeout_to_deadline(timeout),
metadata, credentials, wait_for_ready,
@ -188,12 +203,13 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
Returns:
A Call object instance which is an awaitable object.
"""
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
metadata = _compression.augment_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = UnaryStreamCall(request, deadline, metadata, credentials,
wait_for_ready, self._channel, self._method,
@ -237,12 +253,13 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
metadata = _compression.augment_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = StreamUnaryCall(request_async_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel,
@ -286,12 +303,13 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
metadata = _compression.augment_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = StreamStreamCall(request_async_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel,
@ -311,7 +329,7 @@ class Channel:
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_ongoing_calls: _OngoingCalls
def __init__(self, target: Text, options: Optional[ChannelArgumentType],
def __init__(self, target: Text, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression],
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]):
@ -326,10 +344,6 @@ class Channel:
interceptors: An optional list of interceptors that would be used for
intercepting any RPC executed with that channel.
"""
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
if interceptors is None:
self._unary_unary_interceptors = None
else:
@ -349,7 +363,8 @@ class Channel:
.format(invalid_interceptors))
self._loop = asyncio.get_event_loop()
self._channel = cygrpc.AioChannel(_common.encode(target), options,
self._channel = cygrpc.AioChannel(_common.encode(target),
_augment_channel_arguments(options, compression),
credentials, self._loop)
self._ongoing_calls = _OngoingCalls()

@ -13,34 +13,47 @@
# limitations under the License.
"""Server-side implementation of gRPC Asyncio Python."""
from typing import Text, Optional
import asyncio
from concurrent.futures import Executor
from typing import Any, Optional, Sequence, Text
import grpc
from grpc import _common
from grpc import _common, _compression
from grpc._cython import cygrpc
from ._typing import ChannelArgumentType
def _augment_channel_arguments(base_options: ChannelArgumentType,
compression: Optional[grpc.Compression]):
compression_option = _compression.create_channel_option(compression)
return tuple(base_options) + compression_option
class Server:
"""Serves RPCs."""
def __init__(self, thread_pool, generic_handlers, interceptors, options,
maximum_concurrent_rpcs, compression):
def __init__(self, thread_pool: Optional[Executor],
generic_handlers: Optional[Sequence[grpc.GenericRpcHandler]],
interceptors: Optional[Sequence[Any]],
options: ChannelArgumentType,
maximum_concurrent_rpcs: Optional[int],
compression: Optional[grpc.Compression]):
self._loop = asyncio.get_event_loop()
self._server = cygrpc.AioServer(self._loop, thread_pool,
generic_handlers, interceptors, options,
maximum_concurrent_rpcs, compression)
self._server = cygrpc.AioServer(
self._loop, thread_pool, generic_handlers, interceptors,
_augment_channel_arguments(options, compression),
maximum_concurrent_rpcs)
def add_generic_rpc_handlers(
self,
generic_rpc_handlers,
# generic_rpc_handlers: Iterable[grpc.GenericRpcHandlers]
) -> None:
generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]) -> None:
"""Registers GenericRpcHandlers with this Server.
This method is only safe to call before the server is started.
Args:
generic_rpc_handlers: An iterable of GenericRpcHandlers that will be
generic_rpc_handlers: A sequence of GenericRpcHandlers that will be
used to service RPCs.
"""
self._server.add_generic_rpc_handlers(generic_rpc_handlers)
@ -141,12 +154,12 @@ class Server:
self._loop.create_task(self._server.shutdown(None))
def server(migration_thread_pool=None,
handlers=None,
interceptors=None,
options=None,
maximum_concurrent_rpcs=None,
compression=None):
def server(migration_thread_pool: Optional[Executor] = None,
handlers: Optional[Sequence[grpc.GenericRpcHandler]] = None,
interceptors: Optional[Sequence[Any]] = None,
options: Optional[ChannelArgumentType] = None,
maximum_concurrent_rpcs: Optional[int] = None,
compression: Optional[grpc.Compression] = None):
"""Creates a Server with which RPCs can be serviced.
Args:

@ -12,6 +12,7 @@
"unit.channel_test.TestChannel",
"unit.close_channel_test.TestCloseChannel",
"unit.close_channel_test.TestOngoingCalls",
"unit.compression_test.TestCompression",
"unit.connectivity_test.TestConnectivityState",
"unit.done_callback_test.TestDoneCallback",
"unit.init_test.TestInsecureChannel",

@ -0,0 +1,174 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior around the compression mechanism."""
import asyncio
import logging
import platform
import random
import unittest
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit import _common
_GZIP_CHANNEL_ARGUMENT = ('grpc.default_compression_algorithm', 2)
_GZIP_DISABLED_CHANNEL_ARGUMENT = ('grpc.compression_enabled_algorithms_bitset',
3)
_DEFLATE_DISABLED_CHANNEL_ARGUMENT = (
'grpc.compression_enabled_algorithms_bitset', 5)
_TEST_UNARY_UNARY = '/test/TestUnaryUnary'
_TEST_SET_COMPRESSION = '/test/TestSetCompression'
_TEST_DISABLE_COMPRESSION_UNARY = '/test/TestDisableCompressionUnary'
_TEST_DISABLE_COMPRESSION_STREAM = '/test/TestDisableCompressionStream'
_REQUEST = b'\x01' * 100
_RESPONSE = b'\x02' * 100
async def _test_unary_unary(unused_request, unused_context):
return _RESPONSE
async def _test_set_compression(unused_request_iterator, context):
assert _REQUEST == await context.read()
context.set_compression(grpc.Compression.Deflate)
await context.write(_RESPONSE)
try:
context.set_compression(grpc.Compression.Deflate)
except RuntimeError:
pass
else:
raise ValueError(
'Expecting exceptions if set_compression is not effective')
async def _test_disable_compression_unary(request, context):
assert _REQUEST == request
context.set_compression(grpc.Compression.Deflate)
context.disable_next_message_compression()
return _RESPONSE
async def _test_disable_compression_stream(unused_request_iterator, context):
assert _REQUEST == await context.read()
context.set_compression(grpc.Compression.Deflate)
await context.write(_RESPONSE)
context.disable_next_message_compression()
await context.write(_RESPONSE)
await context.write(_RESPONSE)
_ROUTING_TABLE = {
_TEST_UNARY_UNARY:
grpc.unary_unary_rpc_method_handler(_test_unary_unary),
_TEST_SET_COMPRESSION:
grpc.stream_stream_rpc_method_handler(_test_set_compression),
_TEST_DISABLE_COMPRESSION_UNARY:
grpc.unary_unary_rpc_method_handler(_test_disable_compression_unary),
_TEST_DISABLE_COMPRESSION_STREAM:
grpc.stream_stream_rpc_method_handler(_test_disable_compression_stream),
}
class _GenericHandler(grpc.GenericRpcHandler):
def service(self, handler_call_details):
return _ROUTING_TABLE.get(handler_call_details.method)
async def _start_test_server(options=None):
server = aio.server(options=options)
port = server.add_insecure_port('[::]:0')
server.add_generic_rpc_handlers((_GenericHandler(),))
await server.start()
return f'localhost:{port}', server
class TestCompression(AioTestBase):
async def setUp(self):
server_options = (_GZIP_DISABLED_CHANNEL_ARGUMENT,)
self._address, self._server = await _start_test_server(server_options)
self._channel = aio.insecure_channel(self._address)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_channel_level_compression(self):
# GZIP is disabled, this call should fail
async with aio.insecure_channel(
self._address, compression=grpc.Compression.Gzip) as channel:
multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
call = multicallable(_REQUEST)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
# Deflate is allowed, this call should succeed
async with aio.insecure_channel(
self._address, compression=grpc.Compression.Deflate) as channel:
multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
call = multicallable(_REQUEST)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_client_call_level_compression(self):
multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)
# GZIP is disabled, this call should fail
call = multicallable(_REQUEST, compression=grpc.Compression.Gzip)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
# Deflate is allowed, this call should succeed
call = multicallable(_REQUEST, compression=grpc.Compression.Deflate)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_server_call_level_compression(self):
multicallable = self._channel.stream_stream(_TEST_SET_COMPRESSION)
call = multicallable()
await call.write(_REQUEST)
await call.done_writing()
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_server_disable_compression_unary(self):
multicallable = self._channel.unary_unary(
_TEST_DISABLE_COMPRESSION_UNARY)
call = multicallable(_REQUEST)
self.assertEqual(_RESPONSE, await call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_server_disable_compression_stream(self):
multicallable = self._channel.stream_stream(
_TEST_DISABLE_COMPRESSION_STREAM)
call = multicallable()
await call.write(_REQUEST)
await call.done_writing()
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(grpc.StatusCode.OK, await call.code())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save