Merge pull request #21351 from Skyscanner/async-unary-unary-credentials

[Aio] Support credentials for unary calls
pull/20316/head
Lidi Zheng 5 years ago committed by GitHub
commit 3965762ab0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 21
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  4. 23
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  5. 31
      src/python/grpcio/grpc/experimental/aio/__init__.py
  6. 37
      src/python/grpcio/grpc/experimental/aio/_call.py
  7. 16
      src/python/grpcio/grpc/experimental/aio/_channel.py
  8. 23
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  9. 1
      src/python/grpcio_tests/tests_aio/tests.json
  10. 11
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  11. 27
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  12. 3
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  13. 18
      src/python/grpcio_tests/tests_aio/unit/init_test.py

@ -28,4 +28,4 @@ cdef class _AioCall(GrpcCallWrapper):
# because Core is holding a pointer for the callback handler.
bint _is_locally_cancelled
cdef grpc_call* _create_grpc_call(self, object timeout, bytes method) except *
cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except *

@ -28,12 +28,13 @@ cdef class _AioCall:
def __cinit__(self,
AioChannel channel,
object deadline,
bytes method):
bytes method,
CallCredentials credentials):
self.call = NULL
self._channel = channel
self._references = []
self._loop = asyncio.get_event_loop()
self._create_grpc_call(deadline, method)
self._create_grpc_call(deadline, method, credentials)
self._is_locally_cancelled = False
def __dealloc__(self):
@ -45,12 +46,13 @@ cdef class _AioCall:
id_ = id(self)
return f"<{class_name} {id_}>"
cdef grpc_call* _create_grpc_call(self,
object deadline,
bytes method) except *:
cdef void _create_grpc_call(self,
object deadline,
bytes method,
CallCredentials credentials) except *:
"""Creates the corresponding Core object for this RPC.
For unary calls, the grpc_call lives shortly and can be destroied after
For unary calls, the grpc_call lives shortly and can be destroyed after
invoke start_batch. However, if either side is streaming, the grpc_call
life span will be longer than one function. So, it would better save it
as an instance variable than a stack variable, which reflects its
@ -58,6 +60,7 @@ cdef class _AioCall:
"""
cdef grpc_slice method_slice
cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
cdef grpc_call_error set_credentials_error
method_slice = grpc_slice_from_copied_buffer(
<const char *> method,
@ -73,6 +76,12 @@ cdef class _AioCall:
c_deadline,
NULL
)
if credentials is not None:
set_credentials_error = grpc_call_set_credentials(self.call, credentials.c())
if set_credentials_error != GRPC_CALL_OK:
raise Exception("Credentials couldn't have been set")
grpc_slice_unref(method_slice)
def cancel(self, AioRpcStatus status):

@ -14,7 +14,7 @@
cdef class CallbackFailureHandler:
def __cinit__(self,
str core_function_name,
object error_details,
@ -78,7 +78,7 @@ cdef class CallbackCompletionQueue:
cdef grpc_completion_queue* c_ptr(self):
return self._cq
async def shutdown(self):
grpc_completion_queue_shutdown(self._cq)
await self._shutdown_completed

@ -12,14 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
cdef class AioChannel:
def __cinit__(self, bytes target, tuple options):
def __cinit__(self, bytes target, tuple options, ChannelCredentials credentials):
if options is None:
options = ()
cdef _ChannelArgs channel_args = _ChannelArgs(options)
self.channel = grpc_insecure_channel_create(<char *>target, channel_args.c_args(), NULL)
self.cq = CallbackCompletionQueue()
self._target = target
self.cq = CallbackCompletionQueue()
if credentials is None:
self.channel = grpc_insecure_channel_create(
<char *>target,
channel_args.c_args(),
NULL)
else:
self.channel = grpc_secure_channel_create(
<grpc_channel_credentials *> credentials.c(),
<char *> target,
channel_args.c_args(),
NULL)
def __repr__(self):
class_name = self.__class__.__name__
@ -31,11 +43,12 @@ cdef class AioChannel:
def call(self,
bytes method,
object deadline):
object deadline,
CallCredentials credentials):
"""Assembles a Cython Call object.
Returns:
The _AioCall object.
"""
cdef _AioCall call = _AioCall(self, deadline, method)
cdef _AioCall call = _AioCall(self, deadline, method, credentials)
return call

@ -52,10 +52,33 @@ def insecure_channel(
Returns:
A Channel.
"""
return Channel(target, () if options is None else options, None,
compression, interceptors)
def secure_channel(
target: Text,
credentials: grpc.ChannelCredentials,
options: Optional[list] = None,
compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None):
"""Creates a secure asynchronous Channel to a server.
Args:
target: The server address.
credentials: A ChannelCredentials instance.
options: An optional list of key-value pairs (channel args
in gRPC Core runtime) to configure the channel.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel. This is an EXPERIMENTAL option.
interceptors: An optional sequence of interceptors that will be executed for
any call executed with this channel.
Returns:
An aio.Channel.
"""
return Channel(target, () if options is None else options,
None,
compression,
interceptors=interceptors)
credentials._credentials, compression, interceptors)
################################### __all__ #################################
@ -64,4 +87,4 @@ __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall',
'UnaryStreamCall', 'init_grpc_aio', 'Channel',
'UnaryUnaryMultiCallable', 'ClientCallDetails',
'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall',
'insecure_channel', 'server', 'Server')
'insecure_channel', 'secure_channel', 'server')

@ -260,16 +260,24 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
_call: asyncio.Task
_cython_call: cygrpc._AioCall
def __init__(self, request: RequestType, deadline: Optional[float],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
def __init__( # pylint: disable=R0913
self, request: RequestType, deadline: Optional[float],
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__()
self._request = request
self._channel = channel
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._cython_call = self._channel.call(method, deadline)
if credentials is not None:
grpc_credentials = credentials._credentials
else:
grpc_credentials = None
self._cython_call = self._channel.call(method, deadline,
grpc_credentials)
self._call = self._loop.create_task(self._invoke())
def __del__(self) -> None:
@ -345,10 +353,12 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
_send_unary_request_task: asyncio.Task
_message_aiter: AsyncIterable[ResponseType]
def __init__(self, request: RequestType, deadline: Optional[float],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
def __init__( # pylint: disable=R0913
self, request: RequestType, deadline: Optional[float],
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
super().__init__()
self._request = request
self._channel = channel
@ -357,7 +367,14 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self._send_unary_request_task = self._loop.create_task(
self._send_unary_request())
self._message_aiter = self._fetch_stream_responses()
self._cython_call = self._channel.call(method, deadline)
if credentials is not None:
grpc_credentials = credentials._credentials
else:
grpc_credentials = None
self._cython_call = self._channel.call(method, deadline,
grpc_credentials)
def __del__(self) -> None:
if not self._status.done():

@ -85,13 +85,9 @@ class UnaryUnaryMultiCallable:
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if credentials:
raise NotImplementedError("TODO: credentials not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
@ -99,6 +95,7 @@ class UnaryUnaryMultiCallable:
return UnaryUnaryCall(
request,
_timeout_to_deadline(timeout),
credentials,
self._channel,
self._method,
self._request_serializer,
@ -109,6 +106,7 @@ class UnaryUnaryMultiCallable:
self._interceptors,
request,
timeout,
credentials,
self._channel,
self._method,
self._request_serializer,
@ -158,9 +156,6 @@ class UnaryStreamMultiCallable:
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if credentials:
raise NotImplementedError("TODO: credentials not implemented yet")
if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")
@ -173,6 +168,7 @@ class UnaryStreamMultiCallable:
return UnaryStreamCall(
request,
deadline,
credentials,
self._channel,
self._method,
self._request_serializer,
@ -204,9 +200,6 @@ class Channel:
intercepting any RPC executed with that channel.
"""
if credentials:
raise NotImplementedError("TODO: credentials not implemented yet")
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
@ -228,7 +221,8 @@ class Channel:
"UnaryUnaryClientInterceptors, the following are invalid: {}"\
.format(invalid_interceptors))
self._channel = cygrpc.AioChannel(_common.encode(target), options)
self._channel = cygrpc.AioChannel(_common.encode(target), options,
credentials)
def unary_unary(
self,

@ -106,24 +106,25 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
def __init__( # pylint: disable=R0913
self, interceptors: Sequence[UnaryUnaryClientInterceptor],
request: RequestType, timeout: Optional[float],
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
self._channel = channel
self._loop = asyncio.get_event_loop()
self._interceptors_task = asyncio.ensure_future(
self._invoke(interceptors, method, timeout, request,
self._invoke(interceptors, method, timeout, credentials, request,
request_serializer, response_deserializer))
def __del__(self):
self.cancel()
async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction
) -> UnaryUnaryCall:
async def _invoke( # pylint: disable=R0913
self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
credentials: Optional[grpc.CallCredentials], request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> UnaryUnaryCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
@ -147,10 +148,12 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
else:
return UnaryUnaryCall(
request, _timeout_to_deadline(client_call_details.timeout),
self._channel, client_call_details.method,
request_serializer, response_deserializer)
client_call_details.credentials, self._channel,
client_call_details.method, request_serializer,
response_deserializer)
client_call_details = ClientCallDetails(method, timeout, None, None)
client_call_details = ClientCallDetails(method, timeout, None,
credentials)
return await _run_interceptor(iter(interceptors), client_call_details,
request)

@ -6,6 +6,7 @@
"unit.channel_argument_test.TestChannelArgument",
"unit.channel_test.TestChannel",
"unit.init_test.TestInsecureChannel",
"unit.init_test.TestSecureChannel",
"unit.interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.server_test.TestServer"

@ -17,6 +17,7 @@ import logging
import datetime
import grpc
from grpc.experimental import aio
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2
@ -51,7 +52,7 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
return messages_pb2.SimpleResponse()
async def start_test_server():
async def start_test_server(secure=False):
server = aio.server(options=(('grpc.so_reuseport', 0),))
servicer = _TestServiceServicer()
test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)
@ -70,7 +71,13 @@ async def start_test_server():
'grpc.testing.TestService', rpc_method_handlers)
server.add_generic_rpc_handlers((extra_handler,))
port = server.add_insecure_port('[::]:0')
if secure:
server_credentials = grpc.local_server_credentials(
grpc.LocalConnectionType.LOCAL_TCP)
port = server.add_secure_port('[::]:0', server_credentials)
else:
port = server.add_insecure_port('[::]:0')
await server.start()
# NOTE(lidizheng) returning the server to prevent it from deallocation
return 'localhost:%d' % port, server

@ -398,6 +398,33 @@ class TestUnaryStreamCall(AioTestBase):
with self.assertRaises(asyncio.CancelledError):
await task
def test_call_credentials(self):
class DummyAuth(grpc.AuthMetadataPlugin):
def __call__(self, context, callback):
signature = context.method_name[::-1]
callback((("test", signature),), None)
async def coro():
server_target, _ = await start_test_server(secure=False) # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
hi = channel.unary_unary('/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.
SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.
SimpleResponse.FromString)
call_credentials = grpc.metadata_call_credentials(DummyAuth())
call = hi(messages_pb2.SimpleRequest(),
credentials=call_credentials)
response = await call
self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.loop.run_until_complete(coro())
if __name__ == '__main__':
logging.basicConfig()

@ -14,6 +14,7 @@
"""Tests behavior of the grpc.aio.Channel class."""
import logging
import os
import threading
import unittest
@ -82,6 +83,8 @@ class TestChannel(AioTestBase):
self.assertIsNotNone(
exception_context.exception.trailing_metadata())
@unittest.skipIf(os.name == 'nt',
'TODO: https://github.com/grpc/grpc/issues/21658')
async def test_unary_call_does_not_times_out(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(

@ -14,6 +14,8 @@
import logging
import unittest
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
@ -28,6 +30,22 @@ class TestInsecureChannel(AioTestBase):
self.assertIsInstance(channel, aio.Channel)
class TestSecureChannel(AioTestBase):
"""Test a secure channel connected to a secure server"""
def test_secure_channel(self):
async def coro():
server_target, _ = await start_test_server(secure=True) # pylint: disable=unused-variable
credentials = grpc.local_channel_credentials(
grpc.LocalConnectionType.LOCAL_TCP)
secure_channel = aio.secure_channel(server_target, credentials)
self.assertIsInstance(secure_channel, aio.Channel)
self.loop.run_until_complete(coro())
if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)

Loading…
Cancel
Save