Add thread-safe channel cache. Test that it actually caches

pull/21954/head
Richard Belleville 5 years ago
parent 0c4af58acc
commit 5a8a6e3ad3
  1. 38
      src/python/grpcio/grpc/_simple_stubs.py
  2. 113
      src/python/grpcio_tests/tests/unit/_simple_stubs_test.py

@ -1,14 +1,22 @@
# TODO: Flowerbox.
import threading
import grpc
from typing import Any, Callable, Optional, Sequence, Text, Tuple, Union
_CHANNEL_CACHE = None
_CHANNEL_CACHE_LOCK = threading.RLock()
def _get_cached_channel(target: Text,
options: Sequence[Tuple[Text, Text]],
channel_credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]) -> grpc.Channel:
# TODO: Actually cache.
# TODO: Evict channels.
# Eviction policy based on both channel count and time since use. Perhaps
# OrderedDict instead?
def _create_channel(target: Text,
options: Sequence[Tuple[Text, Text]],
channel_credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]) -> grpc.Channel:
if channel_credentials is None:
return grpc.insecure_channel(target,
options=options,
@ -19,6 +27,26 @@ def _get_cached_channel(target: Text,
options=options,
compression=compression)
def _get_cached_channel(target: Text,
options: Sequence[Tuple[Text, Text]],
channel_credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]) -> grpc.Channel:
global _CHANNEL_CACHE
global _CHANNEL_CACHE_LOCK
key = (target, options, channel_credentials, compression)
with _CHANNEL_CACHE_LOCK:
if _CHANNEL_CACHE is None:
_CHANNEL_CACHE = {}
channel = _CHANNEL_CACHE.get(key, None)
if channel is not None:
return channel
else:
channel = _create_channel(target, options, channel_credentials, compression)
_CHANNEL_CACHE[key] = channel
return channel
def unary_unary(request: Any,
target: Text,
method: Text,

@ -13,14 +13,24 @@
# limitations under the License.
"""Tests for Simple Stubs."""
import contextlib
import datetime
import inspect
import unittest
import sys
import time
import logging
import grpc
import test_common
# TODO: Figure out how to get this test to run only for Python 3.
from typing import Callable, Optional
_CACHE_EPOCHS = 8
_CACHE_TRIALS = 6
_UNARY_UNARY = "/test/UnaryUnary"
@ -37,26 +47,93 @@ class _GenericHandler(grpc.GenericRpcHandler):
raise NotImplementedError()
def _time_invocation(to_time: Callable[[], None]) -> datetime.timedelta:
start = datetime.datetime.now()
to_time()
return datetime.datetime.now() - start
@contextlib.contextmanager
def _server(credentials: Optional[grpc.ServerCredentials]):
try:
server = test_common.test_server()
target = '[::]:0'
if credentials is None:
port = server.add_insecure_port(target)
else:
port = server.add_secure_port(target, credentials)
server.add_generic_rpc_handlers((_GenericHandler(),))
server.start()
yield server, port
finally:
server.stop(None)
@unittest.skipIf(sys.version_info[0] < 3, "Unsupported on Python 2.")
class SimpleStubsTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(SimpleStubsTest, cls).setUpClass()
cls._server = test_common.test_server()
cls._port = cls._server.add_insecure_port('[::]:0')
cls._server.add_generic_rpc_handlers((_GenericHandler(),))
cls._server.start()
@classmethod
def tearDownClass(cls):
cls._server.stop(None)
super(SimpleStubsTest, cls).tearDownClass()
def test_unary_unary(self):
target = f'localhost:{self._port}'
request = b'0000'
response = grpc.unary_unary(request, target, _UNARY_UNARY)
self.assertEqual(request, response)
def assert_cached(self, to_check: Callable[[str], None]) -> None:
"""Asserts that a function caches intermediate data/state.
To be specific, given a function whose caching behavior is
deterministic in the value of a supplied string, this function asserts
that, on average, subsequent invocations of the function for a specific
string are faster than first invocations with that same string.
Args:
to_check: A function returning nothing, that caches values based on
an arbitrary supplied Text object.
"""
initial_runs = []
cached_runs = []
for epoch in range(_CACHE_EPOCHS):
runs = []
text = str(epoch)
for trial in range(_CACHE_TRIALS):
runs.append(_time_invocation(lambda: to_check(text)))
initial_runs.append(runs[0])
cached_runs.extend(runs[1:])
average_cold = sum((run for run in initial_runs), datetime.timedelta()) / len(initial_runs)
average_warm = sum((run for run in cached_runs), datetime.timedelta()) / len(cached_runs)
self.assertLess(average_warm, average_cold)
def test_unary_unary_insecure(self):
with _server(None) as (_, port):
target = f'localhost:{port}'
request = b'0000'
response = grpc.unary_unary(request, target, _UNARY_UNARY)
self.assertEqual(request, response)
def test_unary_unary_secure(self):
with _server(grpc.local_server_credentials()) as (_, port):
target = f'localhost:{port}'
request = b'0000'
response = grpc.unary_unary(request,
target,
_UNARY_UNARY,
channel_credentials=grpc.local_channel_credentials())
self.assertEqual(request, response)
def test_channels_cached(self):
with _server(grpc.local_server_credentials()) as (_, port):
target = f'localhost:{port}'
request = b'0000'
test_name = inspect.stack()[0][3]
args = (request, target, _UNARY_UNARY)
kwargs = {"channel_credentials": grpc.local_channel_credentials()}
def _invoke(seed: Text):
run_kwargs = dict(kwargs)
run_kwargs["options"] = ((test_name + seed, ""),)
grpc.unary_unary(*args, **run_kwargs)
self.assert_cached(_invoke)
# TODO: Test request_serializer
# TODO: Test request_deserializer
# TODO: Test channel_credentials
# TODO: Test call_credentials
# TODO: Test compression
# TODO: Test wait_for_ready
# TODO: Test metadata
if __name__ == "__main__":
logging.basicConfig()

Loading…
Cancel
Save