Merge pull request #22343 from lidizheng/async-unary-unary-credentials-tests

[Aio] Extend unit tests for async credentials calls
pull/22426/head
Lidi Zheng 5 years ago committed by GitHub
commit 39c4fd7972
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
  3. 5
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  4. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pxd.pxi
  5. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pyx.pxi
  6. 11
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/iomgr.pyx.pxi
  7. 10
      src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
  8. 5
      src/python/grpcio/grpc/experimental/aio/__init__.py
  9. 2
      src/python/grpcio_tests/tests_aio/interop/BUILD.bazel
  10. 6
      src/python/grpcio_tests/tests_aio/tests.json
  11. 2
      src/python/grpcio_tests/tests_aio/unit/BUILD.bazel
  12. 3
      src/python/grpcio_tests/tests_aio/unit/_test_base.py
  13. 6
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  14. 43
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  15. 25
      src/python/grpcio_tests/tests_aio/unit/init_test.py
  16. 130
      src/python/grpcio_tests/tests_aio/unit/secure_call_test.py

@ -125,7 +125,7 @@ cdef class _AioCall(GrpcCallWrapper):
if credentials is not None:
set_credentials_error = grpc_call_set_credentials(self.call, credentials.c())
if set_credentials_error != GRPC_CALL_OK:
raise Exception("Credentials couldn't have been set")
raise InternalError("Credentials couldn't have been set: {0}".format(set_credentials_error))
grpc_slice_unref(method_slice)

@ -24,3 +24,4 @@ cdef class AioChannel:
object loop
bytes _target
AioChannelStatus _status
bint _is_secure

@ -36,11 +36,13 @@ cdef class AioChannel:
self._status = AIO_CHANNEL_STATUS_READY
if credentials is None:
self._is_secure = False
self.channel = grpc_insecure_channel_create(
<char *>target,
channel_args.c_args(),
NULL)
else:
self._is_secure = True
self.channel = grpc_secure_channel_create(
<grpc_channel_credentials *> credentials.c(),
<char *>target,
@ -122,6 +124,9 @@ cdef class AioChannel:
cdef CallCredentials cython_call_credentials
if python_call_credentials is not None:
if not self._is_secure:
raise UsageError("Call credentials are only valid on secure channels")
cython_call_credentials = python_call_credentials._credentials
else:
cython_call_credentials = None

@ -23,10 +23,10 @@ cdef class _AioState:
cdef grpc_completion_queue *global_completion_queue()
cdef init_grpc_aio()
cpdef init_grpc_aio()
cdef shutdown_grpc_aio()
cpdef shutdown_grpc_aio()
cdef extern from "src/core/lib/iomgr/timer_manager.h":

@ -114,8 +114,8 @@ cdef _actual_aio_shutdown():
raise ValueError('Unsupported engine type [%s]' % _global_aio_state.engine)
cdef init_grpc_aio():
"""Initialis the gRPC AsyncIO module.
cpdef init_grpc_aio():
"""Initializes the gRPC AsyncIO module.
Expected to be invoked on critical class constructors.
E.g., AioChannel, AioServer.
@ -126,7 +126,7 @@ cdef init_grpc_aio():
_actual_aio_initialization()
cdef shutdown_grpc_aio():
cpdef shutdown_grpc_aio():
"""Shuts down the gRPC AsyncIO module.
Expected to be invoked on critical class destructors.

@ -212,7 +212,18 @@ cdef void asyncio_run_loop(size_t timeout_ms) with gil:
pass
def _auth_plugin_callback_wrapper(object cb,
str service_url,
str method_name,
object callback):
asyncio.get_event_loop().call_soon(cb, service_url, method_name, callback)
def install_asyncio_iomgr():
# Auth plugins invoke user provided logic in another thread by default. We
# need to override that behavior by registering the call to the event loop.
set_async_callback_func(_auth_plugin_callback_wrapper)
asyncio_resolver_vtable.resolve = asyncio_resolve
asyncio_resolver_vtable.resolve_async = asyncio_resolve_async

@ -34,11 +34,13 @@ cdef class CallCredentials:
raise NotImplementedError()
cdef int _get_metadata(
void *state, grpc_auth_metadata_context context,
grpc_credentials_plugin_metadata_cb cb, void *user_data,
cdef int _get_metadata(void *state,
grpc_auth_metadata_context context,
grpc_credentials_plugin_metadata_cb cb,
void *user_data,
grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX],
size_t *num_creds_md, grpc_status_code *status,
size_t *num_creds_md,
grpc_status_code *status,
const char **error_details) except * with gil:
cdef size_t metadata_count
cdef grpc_metadata *c_metadata

@ -20,7 +20,8 @@ created. AsyncIO doesn't provide thread safety for most of its APIs.
from typing import Any, Optional, Sequence, Tuple
import grpc
from grpc._cython.cygrpc import (EOF, AbortError, BaseError, InternalError,
from grpc._cython.cygrpc import (init_grpc_aio, shutdown_grpc_aio, EOF,
AbortError, BaseError, InternalError,
UsageError)
from ._base_call import (Call, RpcContext, StreamStreamCall, StreamUnaryCall,
@ -39,6 +40,8 @@ from ._channel import insecure_channel, secure_channel
################################### __all__ #################################
__all__ = (
'init_grpc_aio',
'shutdown_grpc_aio',
'AioRpcError',
'RpcContext',
'Call',

@ -56,6 +56,7 @@ py_binary(
python_version = "PY3",
deps = [
"//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_tests/tests/interop:resources",
"//src/python/grpcio_tests/tests/interop:server",
"//src/python/grpcio_tests/tests_aio/unit:_test_server",
],
@ -70,5 +71,6 @@ py_binary(
":methods",
"//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_tests/tests/interop:client",
"//src/python/grpcio_tests/tests/interop:resources",
],
)

@ -19,9 +19,11 @@
"unit.compression_test.TestCompression",
"unit.connectivity_test.TestConnectivityState",
"unit.done_callback_test.TestDoneCallback",
"unit.init_test.TestInsecureChannel",
"unit.init_test.TestSecureChannel",
"unit.init_test.TestChannel",
"unit.metadata_test.TestMetadata",
"unit.secure_call_test.TestStreamStreamSecureCall",
"unit.secure_call_test.TestUnaryStreamSecureCall",
"unit.secure_call_test.TestUnaryUnarySecureCall",
"unit.server_interceptor_test.TestServerInterceptor",
"unit.server_test.TestServer",
"unit.timeout_test.TestTimeout",

@ -41,6 +41,7 @@ py_library(
"//src/proto/grpc/testing:py_messages_proto",
"//src/proto/grpc/testing:test_py_pb2_grpc",
"//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_tests/tests/unit:resources",
],
)
@ -76,6 +77,7 @@ _FLAKY_TESTS = [
"//src/proto/grpc/testing:benchmark_service_py_pb2_grpc",
"//src/proto/grpc/testing:py_messages_proto",
"//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_tests/tests/unit:resources",
"//src/python/grpcio_tests/tests/unit/framework/common",
"@six",
],

@ -64,3 +64,6 @@ class AioTestBase(unittest.TestCase):
return _async_to_sync_decorator(attr, self._TEST_LOOP)
# For other attributes, let them pass.
return attr
aio.init_grpc_aio()

@ -17,6 +17,7 @@ import datetime
import grpc
from grpc.experimental import aio
from tests.unit import resources
from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc
from tests_aio.unit import _constants
@ -129,8 +130,9 @@ async def start_test_server(port=0,
if secure:
if server_credentials is None:
server_credentials = grpc.local_server_credentials(
grpc.LocalConnectionType.LOCAL_TCP)
server_credentials = grpc.ssl_server_credentials([
(resources.private_key(), resources.certificate_chain())
])
port = server.add_secure_port('[::]:%d' % port, server_credentials)
else:
port = server.add_insecure_port('[::]:%d' % port)

@ -14,7 +14,6 @@
"""Tests behavior of the Call classes."""
import asyncio
import datetime
import logging
import unittest
@ -24,6 +23,8 @@ from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_base import AioTestBase
from tests.unit import resources
from tests_aio.unit._test_server import start_test_server
_NUM_STREAM_RESPONSES = 5
@ -55,7 +56,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
self.assertTrue(str(call) is not None)
self.assertTrue(repr(call) is not None)
response = await call
await call
self.assertTrue(str(call) is not None)
self.assertTrue(repr(call) is not None)
@ -202,6 +203,17 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
with self.assertRaises(asyncio.CancelledError):
await task
async def test_passing_credentials_fails_over_insecure_channel(self):
call_credentials = grpc.composite_call_credentials(
grpc.access_token_call_credentials("abc"),
grpc.access_token_call_credentials("def"),
)
with self.assertRaisesRegex(
grpc._cygrpc.UsageError,
"Call credentials are only valid on secure channels"):
self._stub.UnaryCall(messages_pb2.SimpleRequest(),
credentials=call_credentials)
class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
@ -410,33 +422,6 @@ class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
with self.assertRaises(asyncio.CancelledError):
await task
def test_call_credentials(self):
class DummyAuth(grpc.AuthMetadataPlugin):
def __call__(self, context, callback):
signature = context.method_name[::-1]
callback((("test", signature),), None)
async def coro():
server_target, _ = await start_test_server(secure=False) # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
hi = channel.unary_unary('/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.
SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.
SimpleResponse.FromString)
call_credentials = grpc.metadata_call_credentials(DummyAuth())
call = hi(messages_pb2.SimpleRequest(),
credentials=call_credentials)
response = await call
self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.loop.run_until_complete(coro())
async def test_time_remaining(self):
request = messages_pb2.StreamingOutputCallRequest()
# First message comes back immediately

@ -20,8 +20,14 @@ from grpc.experimental import aio
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from tests.unit import resources
class TestInsecureChannel(AioTestBase):
_PRIVATE_KEY = resources.private_key()
_CERTIFICATE_CHAIN = resources.certificate_chain()
_TEST_ROOT_CERTIFICATES = resources.test_root_certificates()
class TestChannel(AioTestBase):
async def test_insecure_channel(self):
server_target, _ = await start_test_server() # pylint: disable=unused-variable
@ -29,22 +35,17 @@ class TestInsecureChannel(AioTestBase):
channel = aio.insecure_channel(server_target)
self.assertIsInstance(channel, aio.Channel)
class TestSecureChannel(AioTestBase):
"""Test a secure channel connected to a secure server"""
def test_secure_channel(self):
async def coro():
async def tests_secure_channel(self):
server_target, _ = await start_test_server(secure=True) # pylint: disable=unused-variable
credentials = grpc.local_channel_credentials(
grpc.LocalConnectionType.LOCAL_TCP)
credentials = grpc.ssl_channel_credentials(
root_certificates=_TEST_ROOT_CERTIFICATES,
private_key=_PRIVATE_KEY,
certificate_chain=_CERTIFICATE_CHAIN,
)
secure_channel = aio.secure_channel(server_target, credentials)
self.assertIsInstance(secure_channel, aio.Channel)
self.loop.run_until_complete(coro())
if __name__ == '__main__':
logging.basicConfig()

@ -0,0 +1,130 @@
# Copyright 2020 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests the behaviour of the Call classes under a secure channel."""
import unittest
import logging
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
from tests.unit import resources
_SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
class _SecureCallMixin:
"""A Mixin to run the call tests over a secure channel."""
async def setUp(self):
server_credentials = grpc.ssl_server_credentials([
(resources.private_key(), resources.certificate_chain())
])
channel_credentials = grpc.ssl_channel_credentials(
resources.test_root_certificates())
self._server_address, self._server = await start_test_server(
secure=True, server_credentials=server_credentials)
channel_options = ((
'grpc.ssl_target_name_override',
_SERVER_HOST_OVERRIDE,
),)
self._channel = aio.secure_channel(self._server_address,
channel_credentials, channel_options)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
class TestUnaryUnarySecureCall(_SecureCallMixin, AioTestBase):
"""unary_unary Calls made over a secure channel."""
async def test_call_ok_over_secure_channel(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
response = await call
self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_call_with_credentials(self):
call_credentials = grpc.composite_call_credentials(
grpc.access_token_call_credentials("abc"),
grpc.access_token_call_credentials("def"),
)
call = self._stub.UnaryCall(messages_pb2.SimpleRequest(),
credentials=call_credentials)
response = await call
self.assertIsInstance(response, messages_pb2.SimpleResponse)
class TestUnaryStreamSecureCall(_SecureCallMixin, AioTestBase):
"""unary_stream calls over a secure channel"""
async def test_unary_stream_async_generator_secure(self):
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.extend(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)
for _ in range(_NUM_STREAM_RESPONSES))
call_credentials = grpc.composite_call_credentials(
grpc.access_token_call_credentials("abc"),
grpc.access_token_call_credentials("def"),
)
call = self._stub.StreamingOutputCall(request,
credentials=call_credentials)
async for response in call:
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(len(response.payload.body), _RESPONSE_PAYLOAD_SIZE)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# Prepares the request that stream in a ping-pong manner.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
class TestStreamStreamSecureCall(_SecureCallMixin, AioTestBase):
_STREAM_ITERATIONS = 2
async def test_async_generator_secure_channel(self):
async def request_generator():
for _ in range(self._STREAM_ITERATIONS):
yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
call_credentials = grpc.composite_call_credentials(
grpc.access_token_call_credentials("abc"),
grpc.access_token_call_credentials("def"),
)
call = self._stub.FullDuplexCall(request_generator(),
credentials=call_credentials)
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save