|
|
|
@ -13,14 +13,24 @@ |
|
|
|
|
# limitations under the License. |
|
|
|
|
"""Tests for Simple Stubs.""" |
|
|
|
|
|
|
|
|
|
import contextlib |
|
|
|
|
import datetime |
|
|
|
|
import inspect |
|
|
|
|
import unittest |
|
|
|
|
import sys |
|
|
|
|
import time |
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
|
|
|
|
|
|
import grpc |
|
|
|
|
import test_common |
|
|
|
|
|
|
|
|
|
# TODO: Figure out how to get this test to run only for Python 3. |
|
|
|
|
from typing import Callable, Optional |
|
|
|
|
|
|
|
|
|
_CACHE_EPOCHS = 8 |
|
|
|
|
_CACHE_TRIALS = 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_UNARY_UNARY = "/test/UnaryUnary" |
|
|
|
|
|
|
|
|
@ -37,26 +47,93 @@ class _GenericHandler(grpc.GenericRpcHandler): |
|
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _time_invocation(to_time: Callable[[], None]) -> datetime.timedelta: |
|
|
|
|
start = datetime.datetime.now() |
|
|
|
|
to_time() |
|
|
|
|
return datetime.datetime.now() - start |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager |
|
|
|
|
def _server(credentials: Optional[grpc.ServerCredentials]): |
|
|
|
|
try: |
|
|
|
|
server = test_common.test_server() |
|
|
|
|
target = '[::]:0' |
|
|
|
|
if credentials is None: |
|
|
|
|
port = server.add_insecure_port(target) |
|
|
|
|
else: |
|
|
|
|
port = server.add_secure_port(target, credentials) |
|
|
|
|
server.add_generic_rpc_handlers((_GenericHandler(),)) |
|
|
|
|
server.start() |
|
|
|
|
yield server, port |
|
|
|
|
finally: |
|
|
|
|
server.stop(None) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf(sys.version_info[0] < 3, "Unsupported on Python 2.") |
|
|
|
|
class SimpleStubsTest(unittest.TestCase): |
|
|
|
|
@classmethod |
|
|
|
|
def setUpClass(cls): |
|
|
|
|
super(SimpleStubsTest, cls).setUpClass() |
|
|
|
|
cls._server = test_common.test_server() |
|
|
|
|
cls._port = cls._server.add_insecure_port('[::]:0') |
|
|
|
|
cls._server.add_generic_rpc_handlers((_GenericHandler(),)) |
|
|
|
|
cls._server.start() |
|
|
|
|
|
|
|
|
|
@classmethod |
|
|
|
|
def tearDownClass(cls): |
|
|
|
|
cls._server.stop(None) |
|
|
|
|
super(SimpleStubsTest, cls).tearDownClass() |
|
|
|
|
|
|
|
|
|
def test_unary_unary(self): |
|
|
|
|
target = f'localhost:{self._port}' |
|
|
|
|
request = b'0000' |
|
|
|
|
response = grpc.unary_unary(request, target, _UNARY_UNARY) |
|
|
|
|
self.assertEqual(request, response) |
|
|
|
|
|
|
|
|
|
def assert_cached(self, to_check: Callable[[str], None]) -> None: |
|
|
|
|
"""Asserts that a function caches intermediate data/state. |
|
|
|
|
|
|
|
|
|
To be specific, given a function whose caching behavior is |
|
|
|
|
deterministic in the value of a supplied string, this function asserts |
|
|
|
|
that, on average, subsequent invocations of the function for a specific |
|
|
|
|
string are faster than first invocations with that same string. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
to_check: A function returning nothing, that caches values based on |
|
|
|
|
an arbitrary supplied Text object. |
|
|
|
|
""" |
|
|
|
|
initial_runs = [] |
|
|
|
|
cached_runs = [] |
|
|
|
|
for epoch in range(_CACHE_EPOCHS): |
|
|
|
|
runs = [] |
|
|
|
|
text = str(epoch) |
|
|
|
|
for trial in range(_CACHE_TRIALS): |
|
|
|
|
runs.append(_time_invocation(lambda: to_check(text))) |
|
|
|
|
initial_runs.append(runs[0]) |
|
|
|
|
cached_runs.extend(runs[1:]) |
|
|
|
|
average_cold = sum((run for run in initial_runs), datetime.timedelta()) / len(initial_runs) |
|
|
|
|
average_warm = sum((run for run in cached_runs), datetime.timedelta()) / len(cached_runs) |
|
|
|
|
self.assertLess(average_warm, average_cold) |
|
|
|
|
|
|
|
|
|
def test_unary_unary_insecure(self): |
|
|
|
|
with _server(None) as (_, port): |
|
|
|
|
target = f'localhost:{port}' |
|
|
|
|
request = b'0000' |
|
|
|
|
response = grpc.unary_unary(request, target, _UNARY_UNARY) |
|
|
|
|
self.assertEqual(request, response) |
|
|
|
|
|
|
|
|
|
def test_unary_unary_secure(self): |
|
|
|
|
with _server(grpc.local_server_credentials()) as (_, port): |
|
|
|
|
target = f'localhost:{port}' |
|
|
|
|
request = b'0000' |
|
|
|
|
response = grpc.unary_unary(request, |
|
|
|
|
target, |
|
|
|
|
_UNARY_UNARY, |
|
|
|
|
channel_credentials=grpc.local_channel_credentials()) |
|
|
|
|
self.assertEqual(request, response) |
|
|
|
|
|
|
|
|
|
def test_channels_cached(self): |
|
|
|
|
with _server(grpc.local_server_credentials()) as (_, port): |
|
|
|
|
target = f'localhost:{port}' |
|
|
|
|
request = b'0000' |
|
|
|
|
test_name = inspect.stack()[0][3] |
|
|
|
|
args = (request, target, _UNARY_UNARY) |
|
|
|
|
kwargs = {"channel_credentials": grpc.local_channel_credentials()} |
|
|
|
|
def _invoke(seed: Text): |
|
|
|
|
run_kwargs = dict(kwargs) |
|
|
|
|
run_kwargs["options"] = ((test_name + seed, ""),) |
|
|
|
|
grpc.unary_unary(*args, **run_kwargs) |
|
|
|
|
self.assert_cached(_invoke) |
|
|
|
|
|
|
|
|
|
# TODO: Test request_serializer |
|
|
|
|
# TODO: Test request_deserializer |
|
|
|
|
# TODO: Test channel_credentials |
|
|
|
|
# TODO: Test call_credentials |
|
|
|
|
# TODO: Test compression |
|
|
|
|
# TODO: Test wait_for_ready |
|
|
|
|
# TODO: Test metadata |
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
logging.basicConfig() |
|
|
|
|