|
|
|
@ -18,11 +18,18 @@ import unittest |
|
|
|
|
import grpc |
|
|
|
|
|
|
|
|
|
from grpc.experimental import aio |
|
|
|
|
from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE |
|
|
|
|
from tests_aio.unit._test_server import start_test_server, _INITIAL_METADATA_KEY, _TRAILING_METADATA_KEY |
|
|
|
|
from tests_aio.unit import _constants |
|
|
|
|
from tests_aio.unit import _common |
|
|
|
|
from tests_aio.unit._test_base import AioTestBase |
|
|
|
|
from src.proto.grpc.testing import messages_pb2 |
|
|
|
|
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' |
|
|
|
|
_INITIAL_METADATA_TO_INJECT = ( |
|
|
|
|
(_INITIAL_METADATA_KEY, 'extra info'), |
|
|
|
|
(_TRAILING_METADATA_KEY, b'\x13\x37'), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestUnaryUnaryClientInterceptor(AioTestBase): |
|
|
|
@ -124,7 +131,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): |
|
|
|
|
client_call_details, request): |
|
|
|
|
new_client_call_details = aio.ClientCallDetails( |
|
|
|
|
method=client_call_details.method, |
|
|
|
|
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2, |
|
|
|
|
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2, |
|
|
|
|
metadata=client_call_details.metadata, |
|
|
|
|
credentials=client_call_details.credentials) |
|
|
|
|
return await continuation(new_client_call_details, request) |
|
|
|
@ -165,7 +172,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): |
|
|
|
|
|
|
|
|
|
new_client_call_details = aio.ClientCallDetails( |
|
|
|
|
method=client_call_details.method, |
|
|
|
|
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2, |
|
|
|
|
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2, |
|
|
|
|
metadata=client_call_details.metadata, |
|
|
|
|
credentials=client_call_details.credentials) |
|
|
|
|
|
|
|
|
@ -342,8 +349,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): |
|
|
|
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
|
|
|
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
|
|
|
|
|
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest(), |
|
|
|
|
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) |
|
|
|
|
call = multicallable( |
|
|
|
|
messages_pb2.SimpleRequest(), |
|
|
|
|
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2) |
|
|
|
|
|
|
|
|
|
with self.assertRaises(aio.AioRpcError) as exception_context: |
|
|
|
|
await call |
|
|
|
@ -375,8 +383,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): |
|
|
|
|
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
|
|
|
|
response_deserializer=messages_pb2.SimpleResponse.FromString) |
|
|
|
|
|
|
|
|
|
call = multicallable(messages_pb2.SimpleRequest(), |
|
|
|
|
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) |
|
|
|
|
call = multicallable( |
|
|
|
|
messages_pb2.SimpleRequest(), |
|
|
|
|
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2) |
|
|
|
|
|
|
|
|
|
with self.assertRaises(aio.AioRpcError) as exception_context: |
|
|
|
|
await call |
|
|
|
@ -532,6 +541,42 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): |
|
|
|
|
self.assertEqual(await call.initial_metadata(), tuple()) |
|
|
|
|
self.assertEqual(await call.trailing_metadata(), None) |
|
|
|
|
|
|
|
|
|
async def test_initial_metadata_modification(self): |
|
|
|
|
|
|
|
|
|
class Interceptor(aio.UnaryUnaryClientInterceptor): |
|
|
|
|
|
|
|
|
|
async def intercept_unary_unary(self, continuation, |
|
|
|
|
client_call_details, request): |
|
|
|
|
if client_call_details.metadata is not None: |
|
|
|
|
new_metadata = client_call_details.metadata + _INITIAL_METADATA_TO_INJECT |
|
|
|
|
else: |
|
|
|
|
new_metadata = _INITIAL_METADATA_TO_INJECT |
|
|
|
|
new_details = aio.ClientCallDetails( |
|
|
|
|
method=client_call_details.method, |
|
|
|
|
timeout=client_call_details.timeout, |
|
|
|
|
metadata=new_metadata, |
|
|
|
|
credentials=client_call_details.credentials, |
|
|
|
|
) |
|
|
|
|
return await continuation(new_details, request) |
|
|
|
|
|
|
|
|
|
async with aio.insecure_channel(self._server_target, |
|
|
|
|
interceptors=[Interceptor() |
|
|
|
|
]) as channel: |
|
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
call = stub.UnaryCall(messages_pb2.SimpleRequest()) |
|
|
|
|
|
|
|
|
|
# Expected to see the echoed initial metadata |
|
|
|
|
self.assertTrue( |
|
|
|
|
_common.seen_metadata(_INITIAL_METADATA_TO_INJECT[0], await |
|
|
|
|
call.initial_metadata())) |
|
|
|
|
|
|
|
|
|
# Expected to see the echoed trailing metadata |
|
|
|
|
self.assertTrue( |
|
|
|
|
_common.seen_metadata(_INITIAL_METADATA_TO_INJECT[1], await |
|
|
|
|
call.trailing_metadata())) |
|
|
|
|
|
|
|
|
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
logging.basicConfig() |
|
|
|
|