Merge pull request #22343 from lidizheng/async-unary-unary-credentials-tests

[Aio] Extend unit tests for async credentials calls
pull/22426/head
Lidi Zheng 5 years ago committed by GitHub
commit 39c4fd7972
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
  3. 5
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  4. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pxd.pxi
  5. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pyx.pxi
  6. 11
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/iomgr.pyx.pxi
  7. 10
      src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
  8. 5
      src/python/grpcio/grpc/experimental/aio/__init__.py
  9. 2
      src/python/grpcio_tests/tests_aio/interop/BUILD.bazel
  10. 6
      src/python/grpcio_tests/tests_aio/tests.json
  11. 2
      src/python/grpcio_tests/tests_aio/unit/BUILD.bazel
  12. 3
      src/python/grpcio_tests/tests_aio/unit/_test_base.py
  13. 6
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  14. 43
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  15. 25
      src/python/grpcio_tests/tests_aio/unit/init_test.py
  16. 130
      src/python/grpcio_tests/tests_aio/unit/secure_call_test.py

@ -125,7 +125,7 @@ cdef class _AioCall(GrpcCallWrapper):
if credentials is not None: if credentials is not None:
set_credentials_error = grpc_call_set_credentials(self.call, credentials.c()) set_credentials_error = grpc_call_set_credentials(self.call, credentials.c())
if set_credentials_error != GRPC_CALL_OK: if set_credentials_error != GRPC_CALL_OK:
raise Exception("Credentials couldn't have been set") raise InternalError("Credentials couldn't have been set: {0}".format(set_credentials_error))
grpc_slice_unref(method_slice) grpc_slice_unref(method_slice)

@ -24,3 +24,4 @@ cdef class AioChannel:
object loop object loop
bytes _target bytes _target
AioChannelStatus _status AioChannelStatus _status
bint _is_secure

@ -36,11 +36,13 @@ cdef class AioChannel:
self._status = AIO_CHANNEL_STATUS_READY self._status = AIO_CHANNEL_STATUS_READY
if credentials is None: if credentials is None:
self._is_secure = False
self.channel = grpc_insecure_channel_create( self.channel = grpc_insecure_channel_create(
<char *>target, <char *>target,
channel_args.c_args(), channel_args.c_args(),
NULL) NULL)
else: else:
self._is_secure = True
self.channel = grpc_secure_channel_create( self.channel = grpc_secure_channel_create(
<grpc_channel_credentials *> credentials.c(), <grpc_channel_credentials *> credentials.c(),
<char *>target, <char *>target,
@ -122,6 +124,9 @@ cdef class AioChannel:
cdef CallCredentials cython_call_credentials cdef CallCredentials cython_call_credentials
if python_call_credentials is not None: if python_call_credentials is not None:
if not self._is_secure:
raise UsageError("Call credentials are only valid on secure channels")
cython_call_credentials = python_call_credentials._credentials cython_call_credentials = python_call_credentials._credentials
else: else:
cython_call_credentials = None cython_call_credentials = None

@ -23,10 +23,10 @@ cdef class _AioState:
cdef grpc_completion_queue *global_completion_queue() cdef grpc_completion_queue *global_completion_queue()
cdef init_grpc_aio() cpdef init_grpc_aio()
cdef shutdown_grpc_aio() cpdef shutdown_grpc_aio()
cdef extern from "src/core/lib/iomgr/timer_manager.h": cdef extern from "src/core/lib/iomgr/timer_manager.h":

@ -114,8 +114,8 @@ cdef _actual_aio_shutdown():
raise ValueError('Unsupported engine type [%s]' % _global_aio_state.engine) raise ValueError('Unsupported engine type [%s]' % _global_aio_state.engine)
cdef init_grpc_aio(): cpdef init_grpc_aio():
"""Initialis the gRPC AsyncIO module. """Initializes the gRPC AsyncIO module.
Expected to be invoked on critical class constructors. Expected to be invoked on critical class constructors.
E.g., AioChannel, AioServer. E.g., AioChannel, AioServer.
@ -126,7 +126,7 @@ cdef init_grpc_aio():
_actual_aio_initialization() _actual_aio_initialization()
cdef shutdown_grpc_aio(): cpdef shutdown_grpc_aio():
"""Shuts down the gRPC AsyncIO module. """Shuts down the gRPC AsyncIO module.
Expected to be invoked on critical class destructors. Expected to be invoked on critical class destructors.

@ -212,7 +212,18 @@ cdef void asyncio_run_loop(size_t timeout_ms) with gil:
pass pass
def _auth_plugin_callback_wrapper(object cb,
str service_url,
str method_name,
object callback):
asyncio.get_event_loop().call_soon(cb, service_url, method_name, callback)
def install_asyncio_iomgr(): def install_asyncio_iomgr():
# Auth plugins invoke user provided logic in another thread by default. We
# need to override that behavior by registering the call to the event loop.
set_async_callback_func(_auth_plugin_callback_wrapper)
asyncio_resolver_vtable.resolve = asyncio_resolve asyncio_resolver_vtable.resolve = asyncio_resolve
asyncio_resolver_vtable.resolve_async = asyncio_resolve_async asyncio_resolver_vtable.resolve_async = asyncio_resolve_async

@ -34,11 +34,13 @@ cdef class CallCredentials:
raise NotImplementedError() raise NotImplementedError()
cdef int _get_metadata( cdef int _get_metadata(void *state,
void *state, grpc_auth_metadata_context context, grpc_auth_metadata_context context,
grpc_credentials_plugin_metadata_cb cb, void *user_data, grpc_credentials_plugin_metadata_cb cb,
void *user_data,
grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX], grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX],
size_t *num_creds_md, grpc_status_code *status, size_t *num_creds_md,
grpc_status_code *status,
const char **error_details) except * with gil: const char **error_details) except * with gil:
cdef size_t metadata_count cdef size_t metadata_count
cdef grpc_metadata *c_metadata cdef grpc_metadata *c_metadata

@ -20,7 +20,8 @@ created. AsyncIO doesn't provide thread safety for most of its APIs.
from typing import Any, Optional, Sequence, Tuple from typing import Any, Optional, Sequence, Tuple
import grpc import grpc
from grpc._cython.cygrpc import (EOF, AbortError, BaseError, InternalError, from grpc._cython.cygrpc import (init_grpc_aio, shutdown_grpc_aio, EOF,
AbortError, BaseError, InternalError,
UsageError) UsageError)
from ._base_call import (Call, RpcContext, StreamStreamCall, StreamUnaryCall, from ._base_call import (Call, RpcContext, StreamStreamCall, StreamUnaryCall,
@ -39,6 +40,8 @@ from ._channel import insecure_channel, secure_channel
################################### __all__ ################################# ################################### __all__ #################################
__all__ = ( __all__ = (
'init_grpc_aio',
'shutdown_grpc_aio',
'AioRpcError', 'AioRpcError',
'RpcContext', 'RpcContext',
'Call', 'Call',

@ -56,6 +56,7 @@ py_binary(
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
"//src/python/grpcio/grpc:grpcio", "//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_tests/tests/interop:resources",
"//src/python/grpcio_tests/tests/interop:server", "//src/python/grpcio_tests/tests/interop:server",
"//src/python/grpcio_tests/tests_aio/unit:_test_server", "//src/python/grpcio_tests/tests_aio/unit:_test_server",
], ],
@ -70,5 +71,6 @@ py_binary(
":methods", ":methods",
"//src/python/grpcio/grpc:grpcio", "//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_tests/tests/interop:client", "//src/python/grpcio_tests/tests/interop:client",
"//src/python/grpcio_tests/tests/interop:resources",
], ],
) )

@ -19,9 +19,11 @@
"unit.compression_test.TestCompression", "unit.compression_test.TestCompression",
"unit.connectivity_test.TestConnectivityState", "unit.connectivity_test.TestConnectivityState",
"unit.done_callback_test.TestDoneCallback", "unit.done_callback_test.TestDoneCallback",
"unit.init_test.TestInsecureChannel", "unit.init_test.TestChannel",
"unit.init_test.TestSecureChannel",
"unit.metadata_test.TestMetadata", "unit.metadata_test.TestMetadata",
"unit.secure_call_test.TestStreamStreamSecureCall",
"unit.secure_call_test.TestUnaryStreamSecureCall",
"unit.secure_call_test.TestUnaryUnarySecureCall",
"unit.server_interceptor_test.TestServerInterceptor", "unit.server_interceptor_test.TestServerInterceptor",
"unit.server_test.TestServer", "unit.server_test.TestServer",
"unit.timeout_test.TestTimeout", "unit.timeout_test.TestTimeout",

@ -41,6 +41,7 @@ py_library(
"//src/proto/grpc/testing:py_messages_proto", "//src/proto/grpc/testing:py_messages_proto",
"//src/proto/grpc/testing:test_py_pb2_grpc", "//src/proto/grpc/testing:test_py_pb2_grpc",
"//src/python/grpcio/grpc:grpcio", "//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_tests/tests/unit:resources",
], ],
) )
@ -76,6 +77,7 @@ _FLAKY_TESTS = [
"//src/proto/grpc/testing:benchmark_service_py_pb2_grpc", "//src/proto/grpc/testing:benchmark_service_py_pb2_grpc",
"//src/proto/grpc/testing:py_messages_proto", "//src/proto/grpc/testing:py_messages_proto",
"//src/python/grpcio/grpc:grpcio", "//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_tests/tests/unit:resources",
"//src/python/grpcio_tests/tests/unit/framework/common", "//src/python/grpcio_tests/tests/unit/framework/common",
"@six", "@six",
], ],

@ -64,3 +64,6 @@ class AioTestBase(unittest.TestCase):
return _async_to_sync_decorator(attr, self._TEST_LOOP) return _async_to_sync_decorator(attr, self._TEST_LOOP)
# For other attributes, let them pass. # For other attributes, let them pass.
return attr return attr
aio.init_grpc_aio()

@ -17,6 +17,7 @@ import datetime
import grpc import grpc
from grpc.experimental import aio from grpc.experimental import aio
from tests.unit import resources
from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc
from tests_aio.unit import _constants from tests_aio.unit import _constants
@ -129,8 +130,9 @@ async def start_test_server(port=0,
if secure: if secure:
if server_credentials is None: if server_credentials is None:
server_credentials = grpc.local_server_credentials( server_credentials = grpc.ssl_server_credentials([
grpc.LocalConnectionType.LOCAL_TCP) (resources.private_key(), resources.certificate_chain())
])
port = server.add_secure_port('[::]:%d' % port, server_credentials) port = server.add_secure_port('[::]:%d' % port, server_credentials)
else: else:
port = server.add_insecure_port('[::]:%d' % port) port = server.add_insecure_port('[::]:%d' % port)

@ -14,7 +14,6 @@
"""Tests behavior of the Call classes.""" """Tests behavior of the Call classes."""
import asyncio import asyncio
import datetime
import logging import logging
import unittest import unittest
@ -24,6 +23,8 @@ from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from tests.unit import resources
from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_server import start_test_server
_NUM_STREAM_RESPONSES = 5 _NUM_STREAM_RESPONSES = 5
@ -55,7 +56,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
self.assertTrue(str(call) is not None) self.assertTrue(str(call) is not None)
self.assertTrue(repr(call) is not None) self.assertTrue(repr(call) is not None)
response = await call await call
self.assertTrue(str(call) is not None) self.assertTrue(str(call) is not None)
self.assertTrue(repr(call) is not None) self.assertTrue(repr(call) is not None)
@ -202,6 +203,17 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
with self.assertRaises(asyncio.CancelledError): with self.assertRaises(asyncio.CancelledError):
await task await task
async def test_passing_credentials_fails_over_insecure_channel(self):
call_credentials = grpc.composite_call_credentials(
grpc.access_token_call_credentials("abc"),
grpc.access_token_call_credentials("def"),
)
with self.assertRaisesRegex(
grpc._cygrpc.UsageError,
"Call credentials are only valid on secure channels"):
self._stub.UnaryCall(messages_pb2.SimpleRequest(),
credentials=call_credentials)
class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase): class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
@ -410,33 +422,6 @@ class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
with self.assertRaises(asyncio.CancelledError): with self.assertRaises(asyncio.CancelledError):
await task await task
def test_call_credentials(self):
class DummyAuth(grpc.AuthMetadataPlugin):
def __call__(self, context, callback):
signature = context.method_name[::-1]
callback((("test", signature),), None)
async def coro():
server_target, _ = await start_test_server(secure=False) # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
hi = channel.unary_unary('/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.
SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.
SimpleResponse.FromString)
call_credentials = grpc.metadata_call_credentials(DummyAuth())
call = hi(messages_pb2.SimpleRequest(),
credentials=call_credentials)
response = await call
self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.loop.run_until_complete(coro())
async def test_time_remaining(self): async def test_time_remaining(self):
request = messages_pb2.StreamingOutputCallRequest() request = messages_pb2.StreamingOutputCallRequest()
# First message comes back immediately # First message comes back immediately

@ -20,8 +20,14 @@ from grpc.experimental import aio
from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from tests.unit import resources
class TestInsecureChannel(AioTestBase): _PRIVATE_KEY = resources.private_key()
_CERTIFICATE_CHAIN = resources.certificate_chain()
_TEST_ROOT_CERTIFICATES = resources.test_root_certificates()
class TestChannel(AioTestBase):
async def test_insecure_channel(self): async def test_insecure_channel(self):
server_target, _ = await start_test_server() # pylint: disable=unused-variable server_target, _ = await start_test_server() # pylint: disable=unused-variable
@ -29,22 +35,17 @@ class TestInsecureChannel(AioTestBase):
channel = aio.insecure_channel(server_target) channel = aio.insecure_channel(server_target)
self.assertIsInstance(channel, aio.Channel) self.assertIsInstance(channel, aio.Channel)
async def tests_secure_channel(self):
class TestSecureChannel(AioTestBase):
"""Test a secure channel connected to a secure server"""
def test_secure_channel(self):
async def coro():
server_target, _ = await start_test_server(secure=True) # pylint: disable=unused-variable server_target, _ = await start_test_server(secure=True) # pylint: disable=unused-variable
credentials = grpc.local_channel_credentials( credentials = grpc.ssl_channel_credentials(
grpc.LocalConnectionType.LOCAL_TCP) root_certificates=_TEST_ROOT_CERTIFICATES,
private_key=_PRIVATE_KEY,
certificate_chain=_CERTIFICATE_CHAIN,
)
secure_channel = aio.secure_channel(server_target, credentials) secure_channel = aio.secure_channel(server_target, credentials)
self.assertIsInstance(secure_channel, aio.Channel) self.assertIsInstance(secure_channel, aio.Channel)
self.loop.run_until_complete(coro())
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig() logging.basicConfig()

@ -0,0 +1,130 @@
# Copyright 2020 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests the behaviour of the Call classes under a secure channel."""
import unittest
import logging
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
from tests.unit import resources
_SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
class _SecureCallMixin:
"""A Mixin to run the call tests over a secure channel."""
async def setUp(self):
server_credentials = grpc.ssl_server_credentials([
(resources.private_key(), resources.certificate_chain())
])
channel_credentials = grpc.ssl_channel_credentials(
resources.test_root_certificates())
self._server_address, self._server = await start_test_server(
secure=True, server_credentials=server_credentials)
channel_options = ((
'grpc.ssl_target_name_override',
_SERVER_HOST_OVERRIDE,
),)
self._channel = aio.secure_channel(self._server_address,
channel_credentials, channel_options)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
class TestUnaryUnarySecureCall(_SecureCallMixin, AioTestBase):
"""unary_unary Calls made over a secure channel."""
async def test_call_ok_over_secure_channel(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
response = await call
self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_call_with_credentials(self):
call_credentials = grpc.composite_call_credentials(
grpc.access_token_call_credentials("abc"),
grpc.access_token_call_credentials("def"),
)
call = self._stub.UnaryCall(messages_pb2.SimpleRequest(),
credentials=call_credentials)
response = await call
self.assertIsInstance(response, messages_pb2.SimpleResponse)
class TestUnaryStreamSecureCall(_SecureCallMixin, AioTestBase):
"""unary_stream calls over a secure channel"""
async def test_unary_stream_async_generator_secure(self):
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.extend(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)
for _ in range(_NUM_STREAM_RESPONSES))
call_credentials = grpc.composite_call_credentials(
grpc.access_token_call_credentials("abc"),
grpc.access_token_call_credentials("def"),
)
call = self._stub.StreamingOutputCall(request,
credentials=call_credentials)
async for response in call:
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(len(response.payload.body), _RESPONSE_PAYLOAD_SIZE)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# Prepares the request that stream in a ping-pong manner.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
class TestStreamStreamSecureCall(_SecureCallMixin, AioTestBase):
_STREAM_ITERATIONS = 2
async def test_async_generator_secure_channel(self):
async def request_generator():
for _ in range(self._STREAM_ITERATIONS):
yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
call_credentials = grpc.composite_call_credentials(
grpc.access_token_call_credentials("abc"),
grpc.access_token_call_credentials("def"),
)
call = self._stub.FullDuplexCall(request_generator(),
credentials=call_credentials)
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save