Support compression for both client and server

pull/21809/head
Lidi Zheng 5 years ago
parent ffb41a2368
commit cd76b79e7f
  1. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 15
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  3. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  4. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  5. 63
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  6. 51
      src/python/grpcio/grpc/experimental/aio/_channel.py
  7. 47
      src/python/grpcio/grpc/experimental/aio/_server.py
  8. 1
      src/python/grpcio_tests/tests_aio/tests.json
  9. 174
      src/python/grpcio_tests/tests_aio/unit/compression_test.py

@ -367,7 +367,8 @@ cdef class _AioCall(GrpcCallWrapper):
"""Sends one single raw message in bytes.""" """Sends one single raw message in bytes."""
await _send_message(self, await _send_message(self,
message, message,
True, None,
False,
self._loop) self._loop)
async def send_receive_close(self): async def send_receive_close(self):

@ -153,12 +153,13 @@ async def _receive_message(GrpcCallWrapper grpc_call_wrapper,
async def _send_message(GrpcCallWrapper grpc_call_wrapper, async def _send_message(GrpcCallWrapper grpc_call_wrapper,
bytes message, bytes message,
bint metadata_sent, Operation send_initial_metadata_op,
int write_flag,
object loop): object loop):
cdef SendMessageOperation op = SendMessageOperation(message, _EMPTY_FLAG) cdef SendMessageOperation op = SendMessageOperation(message, write_flag)
cdef tuple ops = (op,) cdef tuple ops = (op,)
if not metadata_sent: if send_initial_metadata_op is not None:
ops = prepend_send_initial_metadata_op(ops, None) ops = (send_initial_metadata_op,) + ops
await execute_batch(grpc_call_wrapper, ops, loop) await execute_batch(grpc_call_wrapper, ops, loop)
@ -184,7 +185,7 @@ async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper,
grpc_status_code code, grpc_status_code code,
str details, str details,
tuple trailing_metadata, tuple trailing_metadata,
bint metadata_sent, Operation send_initial_metadata_op,
object loop): object loop):
assert code != StatusCode.ok, 'Expecting non-ok status code.' assert code != StatusCode.ok, 'Expecting non-ok status code.'
cdef SendStatusFromServerOperation op = SendStatusFromServerOperation( cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
@ -194,6 +195,6 @@ async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper,
_EMPTY_FLAGS, _EMPTY_FLAGS,
) )
cdef tuple ops = (op,) cdef tuple ops = (op,)
if not metadata_sent: if send_initial_metadata_op is not None:
ops = prepend_send_initial_metadata_op(ops, None) ops = (send_initial_metadata_op,) + ops
await execute_batch(grpc_call_wrapper, ops, loop) await execute_batch(grpc_call_wrapper, ops, loop)

@ -67,3 +67,9 @@ class _EOF:
EOF = _EOF() EOF = _EOF()
_COMPRESSION_METADATA_STRING_MAPPING = {
CompressionAlgorithm.none: 'identity',
CompressionAlgorithm.deflate: 'deflate',
CompressionAlgorithm.gzip: 'gzip',
}

@ -31,10 +31,14 @@ cdef class RPCState(GrpcCallWrapper):
cdef grpc_status_code status_code cdef grpc_status_code status_code
cdef str status_details cdef str status_details
cdef tuple trailing_metadata cdef tuple trailing_metadata
cdef object compression_algorithm
cdef bint disable_next_compression
cdef bytes method(self) cdef bytes method(self)
cdef tuple invocation_metadata(self) cdef tuple invocation_metadata(self)
cdef void raise_for_termination(self) except * cdef void raise_for_termination(self) except *
cdef int get_write_flag(self)
cdef Operation create_send_initial_metadata_op_if_not_sent(self)
cdef enum AioServerStatus: cdef enum AioServerStatus:

@ -21,6 +21,16 @@ cdef int _EMPTY_FLAG = 0
cdef str _RPC_FINISHED_DETAILS = 'RPC already finished.' cdef str _RPC_FINISHED_DETAILS = 'RPC already finished.'
cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.' cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.'
cdef _augment_metadata(tuple metadata, object compression):
if compression is None:
return metadata
else:
return ((
GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY,
_COMPRESSION_METADATA_STRING_MAPPING[compression]
),) + metadata
cdef class _HandlerCallDetails: cdef class _HandlerCallDetails:
def __cinit__(self, str method, tuple invocation_metadata): def __cinit__(self, str method, tuple invocation_metadata):
self.method = method self.method = method
@ -45,6 +55,8 @@ cdef class RPCState:
self.status_code = StatusCode.ok self.status_code = StatusCode.ok
self.status_details = '' self.status_details = ''
self.trailing_metadata = _IMMUTABLE_EMPTY_METADATA self.trailing_metadata = _IMMUTABLE_EMPTY_METADATA
self.compression_algorithm = None
self.disable_next_compression = False
cdef bytes method(self): cdef bytes method(self):
return _slice_bytes(self.details.method) return _slice_bytes(self.details.method)
@ -69,6 +81,24 @@ cdef class RPCState:
if self.server._status == AIO_SERVER_STATUS_STOPPED: if self.server._status == AIO_SERVER_STATUS_STOPPED:
raise _ServerStoppedError(_SERVER_STOPPED_DETAILS) raise _ServerStoppedError(_SERVER_STOPPED_DETAILS)
cdef int get_write_flag(self):
if self.disable_next_compression:
self.disable_next_compression = False
return WriteFlag.no_compress
else:
return _EMPTY_FLAG
cdef Operation create_send_initial_metadata_op_if_not_sent(self):
if self.metadata_sent:
return None
cdef SendInitialMetadataOperation op = SendInitialMetadataOperation(
_augment_metadata(_IMMUTABLE_EMPTY_METADATA, self.compression_algorithm),
_EMPTY_FLAG
)
self.metadata_sent = True
return op
def __dealloc__(self): def __dealloc__(self):
"""Cleans the Core objects.""" """Cleans the Core objects."""
grpc_call_details_destroy(&self.details) grpc_call_details_destroy(&self.details)
@ -116,10 +146,9 @@ cdef class _ServicerContext:
await _send_message(self._rpc_state, await _send_message(self._rpc_state,
serialize(self._response_serializer, message), serialize(self._response_serializer, message),
self._rpc_state.metadata_sent, self._rpc_state.create_send_initial_metadata_op_if_not_sent(),
self._rpc_state.get_write_flag(),
self._loop) self._loop)
if not self._rpc_state.metadata_sent:
self._rpc_state.metadata_sent = True
async def send_initial_metadata(self, tuple metadata): async def send_initial_metadata(self, tuple metadata):
self._rpc_state.raise_for_termination() self._rpc_state.raise_for_termination()
@ -127,7 +156,12 @@ cdef class _ServicerContext:
if self._rpc_state.metadata_sent: if self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent') raise RuntimeError('Send initial metadata failed: already sent')
else: else:
await _send_initial_metadata(self._rpc_state, metadata, _EMPTY_FLAG, self._loop) await _send_initial_metadata(
self._rpc_state,
_augment_metadata(metadata, self._rpc_state.compression_algorithm),
_EMPTY_FLAG,
self._loop
)
self._rpc_state.metadata_sent = True self._rpc_state.metadata_sent = True
async def abort(self, async def abort(self,
@ -156,7 +190,7 @@ cdef class _ServicerContext:
actual_code, actual_code,
details, details,
trailing_metadata, trailing_metadata,
self._rpc_state.metadata_sent, self._rpc_state.create_send_initial_metadata_op_if_not_sent(),
self._loop self._loop
) )
@ -174,6 +208,15 @@ cdef class _ServicerContext:
def set_details(self, str details): def set_details(self, str details):
self._rpc_state.status_details = details self._rpc_state.status_details = details
def set_compression(self, object compression):
if self._rpc_state.metadata_sent:
raise RuntimeError('Compression setting must be specified before sending initial metadata')
else:
self._rpc_state.compression_algorithm = compression
def disable_next_message_compression(self):
self._rpc_state.disable_next_compression = True
cdef _find_method_handler(str method, tuple metadata, list generic_handlers): cdef _find_method_handler(str method, tuple metadata, list generic_handlers):
cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method, cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method,
@ -217,7 +260,7 @@ async def _finish_handler_with_unary_response(RPCState rpc_state,
# Assembles the batch operations # Assembles the batch operations
cdef tuple finish_ops cdef tuple finish_ops
finish_ops = ( finish_ops = (
SendMessageOperation(response_raw, _EMPTY_FLAGS), SendMessageOperation(response_raw, rpc_state.get_write_flag()),
SendStatusFromServerOperation( SendStatusFromServerOperation(
rpc_state.trailing_metadata, rpc_state.trailing_metadata,
rpc_state.status_code, rpc_state.status_code,
@ -446,7 +489,7 @@ async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
status_code, status_code,
'Unexpected %s: %s' % (type(e), e), 'Unexpected %s: %s' % (type(e), e),
rpc_state.trailing_metadata, rpc_state.trailing_metadata,
rpc_state.metadata_sent, rpc_state.create_send_initial_metadata_op_if_not_sent(),
loop loop
) )
@ -492,7 +535,7 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
StatusCode.unimplemented, StatusCode.unimplemented,
'Method not found!', 'Method not found!',
_IMMUTABLE_EMPTY_METADATA, _IMMUTABLE_EMPTY_METADATA,
rpc_state.metadata_sent, rpc_state.create_send_initial_metadata_op_if_not_sent(),
loop loop
) )
return return
@ -541,7 +584,7 @@ cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHan
cdef class AioServer: cdef class AioServer:
def __init__(self, loop, thread_pool, generic_handlers, interceptors, def __init__(self, loop, thread_pool, generic_handlers, interceptors,
options, maximum_concurrent_rpcs, compression): options, maximum_concurrent_rpcs):
# NOTE(lidiz) Core objects won't be deallocated automatically. # NOTE(lidiz) Core objects won't be deallocated automatically.
# If AioServer.shutdown is not called, those objects will leak. # If AioServer.shutdown is not called, those objects will leak.
self._server = Server(options) self._server = Server(options)
@ -570,8 +613,6 @@ cdef class AioServer:
raise NotImplementedError() raise NotImplementedError()
if maximum_concurrent_rpcs: if maximum_concurrent_rpcs:
raise NotImplementedError() raise NotImplementedError()
if compression:
raise NotImplementedError()
if thread_pool: if thread_pool:
raise NotImplementedError() raise NotImplementedError()

@ -20,6 +20,8 @@ import logging
import grpc import grpc
from grpc import _common from grpc import _common
from grpc._cython import cygrpc from grpc._cython import cygrpc
from grpc import _compression
from grpc import _grpcio_metadata
from . import _base_call from . import _base_call
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
@ -31,6 +33,19 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
from ._utils import _timeout_to_deadline from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple() _IMMUTABLE_EMPTY_TUPLE = tuple()
_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)
def _augment_channel_arguments(base_options: ChannelArgumentType,
compression: Optional[grpc.Compression]):
compression_channel_argument = _compression.create_channel_option(
compression)
user_agent_channel_argument = ((
cygrpc.ChannelArgKey.primary_user_agent_string,
_USER_AGENT,
),)
return tuple(base_options
) + compression_channel_argument + user_agent_channel_argument
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -135,12 +150,12 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details. metadata, status code, and details.
""" """
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
if metadata is None: if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE metadata = _IMMUTABLE_EMPTY_TUPLE
if compression:
metadata = _compression.augment_metadata(metadata, compression)
if not self._interceptors: if not self._interceptors:
call = UnaryUnaryCall(request, _timeout_to_deadline(timeout), call = UnaryUnaryCall(request, _timeout_to_deadline(timeout),
metadata, credentials, wait_for_ready, metadata, credentials, wait_for_ready,
@ -188,12 +203,13 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
Returns: Returns:
A Call object instance which is an awaitable object. A Call object instance which is an awaitable object.
""" """
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") metadata = _compression.augment_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = UnaryStreamCall(request, deadline, metadata, credentials, call = UnaryStreamCall(request, deadline, metadata, credentials,
wait_for_ready, self._channel, self._method, wait_for_ready, self._channel, self._method,
@ -237,12 +253,13 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details. metadata, status code, and details.
""" """
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") metadata = _compression.augment_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = StreamUnaryCall(request_async_iterator, deadline, metadata, call = StreamUnaryCall(request_async_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel, credentials, wait_for_ready, self._channel,
@ -286,12 +303,13 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
raised RpcError will also be a Call for the RPC affording the RPC's raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details. metadata, status code, and details.
""" """
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") metadata = _compression.augment_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = StreamStreamCall(request_async_iterator, deadline, metadata, call = StreamStreamCall(request_async_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel, credentials, wait_for_ready, self._channel,
@ -311,7 +329,7 @@ class Channel:
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_ongoing_calls: _OngoingCalls _ongoing_calls: _OngoingCalls
def __init__(self, target: Text, options: Optional[ChannelArgumentType], def __init__(self, target: Text, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials], credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression], compression: Optional[grpc.Compression],
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]): interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]):
@ -326,10 +344,6 @@ class Channel:
interceptors: An optional list of interceptors that would be used for interceptors: An optional list of interceptors that would be used for
intercepting any RPC executed with that channel. intercepting any RPC executed with that channel.
""" """
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
if interceptors is None: if interceptors is None:
self._unary_unary_interceptors = None self._unary_unary_interceptors = None
else: else:
@ -349,7 +363,8 @@ class Channel:
.format(invalid_interceptors)) .format(invalid_interceptors))
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._channel = cygrpc.AioChannel(_common.encode(target), options, self._channel = cygrpc.AioChannel(_common.encode(target),
_augment_channel_arguments(options, compression),
credentials, self._loop) credentials, self._loop)
self._ongoing_calls = _OngoingCalls() self._ongoing_calls = _OngoingCalls()

@ -13,34 +13,47 @@
# limitations under the License. # limitations under the License.
"""Server-side implementation of gRPC Asyncio Python.""" """Server-side implementation of gRPC Asyncio Python."""
from typing import Text, Optional
import asyncio import asyncio
from concurrent.futures import Executor
from typing import Any, Optional, Sequence, Text
import grpc import grpc
from grpc import _common from grpc import _common, _compression
from grpc._cython import cygrpc from grpc._cython import cygrpc
from ._typing import ChannelArgumentType
def _augment_channel_arguments(base_options: ChannelArgumentType,
compression: Optional[grpc.Compression]):
compression_option = _compression.create_channel_option(compression)
return tuple(base_options) + compression_option
class Server: class Server:
"""Serves RPCs.""" """Serves RPCs."""
def __init__(self, thread_pool, generic_handlers, interceptors, options, def __init__(self, thread_pool: Optional[Executor],
maximum_concurrent_rpcs, compression): generic_handlers: Optional[Sequence[grpc.GenericRpcHandler]],
interceptors: Optional[Sequence[Any]],
options: ChannelArgumentType,
maximum_concurrent_rpcs: Optional[int],
compression: Optional[grpc.Compression]):
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._server = cygrpc.AioServer(self._loop, thread_pool, self._server = cygrpc.AioServer(
generic_handlers, interceptors, options, self._loop, thread_pool, generic_handlers, interceptors,
maximum_concurrent_rpcs, compression) _augment_channel_arguments(options, compression),
maximum_concurrent_rpcs)
def add_generic_rpc_handlers( def add_generic_rpc_handlers(
self, self,
generic_rpc_handlers, generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]) -> None:
# generic_rpc_handlers: Iterable[grpc.GenericRpcHandlers]
) -> None:
"""Registers GenericRpcHandlers with this Server. """Registers GenericRpcHandlers with this Server.
This method is only safe to call before the server is started. This method is only safe to call before the server is started.
Args: Args:
generic_rpc_handlers: An iterable of GenericRpcHandlers that will be generic_rpc_handlers: A sequence of GenericRpcHandlers that will be
used to service RPCs. used to service RPCs.
""" """
self._server.add_generic_rpc_handlers(generic_rpc_handlers) self._server.add_generic_rpc_handlers(generic_rpc_handlers)
@ -141,12 +154,12 @@ class Server:
self._loop.create_task(self._server.shutdown(None)) self._loop.create_task(self._server.shutdown(None))
def server(migration_thread_pool=None, def server(migration_thread_pool: Optional[Executor] = None,
handlers=None, handlers: Optional[Sequence[grpc.GenericRpcHandler]] = None,
interceptors=None, interceptors: Optional[Sequence[Any]] = None,
options=None, options: Optional[ChannelArgumentType] = None,
maximum_concurrent_rpcs=None, maximum_concurrent_rpcs: Optional[int] = None,
compression=None): compression: Optional[grpc.Compression] = None):
"""Creates a Server with which RPCs can be serviced. """Creates a Server with which RPCs can be serviced.
Args: Args:

@ -12,6 +12,7 @@
"unit.channel_test.TestChannel", "unit.channel_test.TestChannel",
"unit.close_channel_test.TestCloseChannel", "unit.close_channel_test.TestCloseChannel",
"unit.close_channel_test.TestOngoingCalls", "unit.close_channel_test.TestOngoingCalls",
"unit.compression_test.TestCompression",
"unit.connectivity_test.TestConnectivityState", "unit.connectivity_test.TestConnectivityState",
"unit.done_callback_test.TestDoneCallback", "unit.done_callback_test.TestDoneCallback",
"unit.init_test.TestInsecureChannel", "unit.init_test.TestInsecureChannel",

@ -0,0 +1,174 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior around the compression mechanism."""
import asyncio
import logging
import platform
import random
import unittest
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit import _common
_GZIP_CHANNEL_ARGUMENT = ('grpc.default_compression_algorithm', 2)
_GZIP_DISABLED_CHANNEL_ARGUMENT = ('grpc.compression_enabled_algorithms_bitset',
3)
_DEFLATE_DISABLED_CHANNEL_ARGUMENT = (
'grpc.compression_enabled_algorithms_bitset', 5)
_TEST_UNARY_UNARY = '/test/TestUnaryUnary'
_TEST_SET_COMPRESSION = '/test/TestSetCompression'
_TEST_DISABLE_COMPRESSION_UNARY = '/test/TestDisableCompressionUnary'
_TEST_DISABLE_COMPRESSION_STREAM = '/test/TestDisableCompressionStream'
_REQUEST = b'\x01' * 100
_RESPONSE = b'\x02' * 100
async def _test_unary_unary(unused_request, unused_context):
return _RESPONSE
async def _test_set_compression(unused_request_iterator, context):
assert _REQUEST == await context.read()
context.set_compression(grpc.Compression.Deflate)
await context.write(_RESPONSE)
try:
context.set_compression(grpc.Compression.Deflate)
except RuntimeError:
pass
else:
raise ValueError(
'Expecting exceptions if set_compression is not effective')
async def _test_disable_compression_unary(request, context):
assert _REQUEST == request
context.set_compression(grpc.Compression.Deflate)
context.disable_next_message_compression()
return _RESPONSE
async def _test_disable_compression_stream(unused_request_iterator, context):
assert _REQUEST == await context.read()
context.set_compression(grpc.Compression.Deflate)
await context.write(_RESPONSE)
context.disable_next_message_compression()
await context.write(_RESPONSE)
await context.write(_RESPONSE)
_ROUTING_TABLE = {
_TEST_UNARY_UNARY:
grpc.unary_unary_rpc_method_handler(_test_unary_unary),
_TEST_SET_COMPRESSION:
grpc.stream_stream_rpc_method_handler(_test_set_compression),
_TEST_DISABLE_COMPRESSION_UNARY:
grpc.unary_unary_rpc_method_handler(_test_disable_compression_unary),
_TEST_DISABLE_COMPRESSION_STREAM:
grpc.stream_stream_rpc_method_handler(_test_disable_compression_stream),
}
class _GenericHandler(grpc.GenericRpcHandler):
def service(self, handler_call_details):
return _ROUTING_TABLE.get(handler_call_details.method)
async def _start_test_server(options=None):
server = aio.server(options=options)
port = server.add_insecure_port('[::]:0')
server.add_generic_rpc_handlers((_GenericHandler(),))
await server.start()
return f'localhost:{port}', server
class TestCompression(AioTestBase):
async def setUp(self):
server_options = (_GZIP_DISABLED_CHANNEL_ARGUMENT,)
self._address, self._server = await _start_test_server(server_options)
self._channel = aio.insecure_channel(self._address)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_channel_level_compression(self):
# GZIP is disabled, this call should fail
async with aio.insecure_channel(
self._address, compression=grpc.Compression.Gzip) as channel:
multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
call = multicallable(_REQUEST)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
# Deflate is allowed, this call should succeed
async with aio.insecure_channel(
self._address, compression=grpc.Compression.Deflate) as channel:
multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
call = multicallable(_REQUEST)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_client_call_level_compression(self):
multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)
# GZIP is disabled, this call should fail
call = multicallable(_REQUEST, compression=grpc.Compression.Gzip)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
# Deflate is allowed, this call should succeed
call = multicallable(_REQUEST, compression=grpc.Compression.Deflate)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_server_call_level_compression(self):
multicallable = self._channel.stream_stream(_TEST_SET_COMPRESSION)
call = multicallable()
await call.write(_REQUEST)
await call.done_writing()
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_server_disable_compression_unary(self):
multicallable = self._channel.unary_unary(
_TEST_DISABLE_COMPRESSION_UNARY)
call = multicallable(_REQUEST)
self.assertEqual(_RESPONSE, await call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_server_disable_compression_stream(self):
multicallable = self._channel.stream_stream(
_TEST_DISABLE_COMPRESSION_STREAM)
call = multicallable()
await call.write(_REQUEST)
await call.done_writing()
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(grpc.StatusCode.OK, await call.code())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save