From 11a29eb95a2b1e7e1cb583b782fb1883cd40f775 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Mon, 22 Jun 2020 10:22:19 -0700 Subject: [PATCH] Implement methods to access auth context and peer info --- .../grpc/_cython/_cygrpc/aio/server.pyx.pxi | 49 +++++ .../grpc/experimental/aio/_base_server.py | 43 +++- src/python/grpcio_tests/tests_aio/tests.json | 2 + .../tests_aio/unit/auth_context_test.py | 194 ++++++++++++++++++ .../tests_aio/unit/context_peer_test.py | 65 ++++++ 5 files changed, 352 insertions(+), 1 deletion(-) create mode 100644 src/python/grpcio_tests/tests_aio/unit/auth_context_test.py create mode 100644 src/python/grpcio_tests/tests_aio/unit/context_peer_test.py diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index b842ec6f2ba..63dbfdd75c0 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -213,6 +213,43 @@ cdef class _ServicerContext: def disable_next_message_compression(self): self._rpc_state.disable_next_compression = True + def peer(self): + cdef char *c_peer = NULL + c_peer = grpc_call_get_peer(self._rpc_state.call) + peer = (c_peer).decode('utf8') + gpr_free(c_peer) + return peer + + def peer_identities(self): + cdef Call query_call = Call() + query_call.c_call = self._rpc_state.call + identities = peer_identities(query_call) + query_call.c_call = NULL + return identities + + def peer_identity_key(self): + cdef Call query_call = Call() + query_call.c_call = self._rpc_state.call + identity_key = peer_identity_key(query_call) + query_call.c_call = NULL + if identity_key: + return identity_key.decode('utf8') + else: + return None + + def auth_context(self): + cdef Call query_call = Call() + query_call.c_call = self._rpc_state.call + bytes_ctx = auth_context(query_call) + query_call.c_call = NULL + if bytes_ctx: + ctx = {} + for key in bytes_ctx: + ctx[key.decode('utf8')] = bytes_ctx[key] + return ctx + else: + return {} + cdef class _SyncServicerContext: """Sync servicer context for sync handler compatibility.""" @@ -260,6 +297,18 @@ cdef class _SyncServicerContext: def add_callback(self, object callback): self._callbacks.append(callback) + def peer(self): + return self._context.peer() + + def peer_identities(self): + return self._context.peer_identities() + + def peer_identity_key(self): + return self._context.peer_identity_key() + + def auth_context(self): + return self._context.auth_context() + async def _run_interceptor(object interceptors, object query_handler, object handler_call_details): diff --git a/src/python/grpcio/grpc/experimental/aio/_base_server.py b/src/python/grpcio/grpc/experimental/aio/_base_server.py index 86c15fc86b0..926c8651714 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_server.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_server.py @@ -14,7 +14,7 @@ """Abstract base classes for server-side classes.""" import abc -from typing import Generic, Optional, Sequence +from typing import Generic, Mapping, Optional, Iterable, Sequence import grpc @@ -251,3 +251,44 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): This method will override any compression configuration set during server creation or set on the call. """ + + @abc.abstractmethod + def peer(self) -> str: + """Identifies the peer that invoked the RPC being serviced. + + Returns: + A string identifying the peer that invoked the RPC being serviced. + The string format is determined by gRPC runtime. + """ + + @abc.abstractmethod + def peer_identities(self) -> Optional[Iterable[bytes]]: + """Gets one or more peer identity(s). + + Equivalent to + servicer_context.auth_context().get(servicer_context.peer_identity_key()) + + Returns: + An iterable of the identities, or None if the call is not + authenticated. Each identity is returned as a raw bytes type. + """ + + @abc.abstractmethod + def peer_identity_key(self) -> Optional[str]: + """The auth property used to identify the peer. + + For example, "x509_common_name" or "x509_subject_alternative_name" are + used to identify an SSL peer. + + Returns: + The auth property (string) that indicates the + peer identity, or None if the call is not authenticated. + """ + + @abc.abstractmethod + def auth_context(self) -> Mapping[str, Iterable[bytes]]: + """Gets the auth context for the call. + + Returns: + A map of strings to an iterable of bytes for each auth property. + """ diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index f01d7d0570d..68ff13e5ccc 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -9,6 +9,7 @@ "unit._metadata_test.TestTypeMetadata", "unit.abort_test.TestAbort", "unit.aio_rpc_error_test.TestAioRpcError", + "unit.auth_context_test.TestAuthContext", "unit.call_test.TestStreamStreamCall", "unit.call_test.TestStreamUnaryCall", "unit.call_test.TestUnaryStreamCall", @@ -16,6 +17,7 @@ "unit.channel_argument_test.TestChannelArgument", "unit.channel_ready_test.TestChannelReady", "unit.channel_test.TestChannel", + "unit.context_peer.TestContextPeer", "unit.client_stream_stream_interceptor_test.TestStreamStreamClientInterceptor", "unit.client_stream_unary_interceptor_test.TestStreamUnaryClientInterceptor", "unit.client_unary_stream_interceptor_test.TestUnaryStreamClientInterceptor", diff --git a/src/python/grpcio_tests/tests_aio/unit/auth_context_test.py b/src/python/grpcio_tests/tests_aio/unit/auth_context_test.py new file mode 100644 index 00000000000..fb303714682 --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/auth_context_test.py @@ -0,0 +1,194 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Porting auth context tests from sync stack.""" + +import pickle +import unittest +import logging + +import grpc +from grpc.experimental import aio +from grpc.experimental import session_cache +import six + +from tests.unit import resources +from tests_aio.unit._test_base import AioTestBase + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_UNARY_UNARY = '/test/UnaryUnary' + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_CLIENT_IDS = ( + b'*.test.google.fr', + b'waterzooi.test.google.be', + b'*.test.youtube.com', + b'192.168.1.3', +) +_ID = 'id' +_ID_KEY = 'id_key' +_AUTH_CTX = 'auth_ctx' + +_PRIVATE_KEY = resources.private_key() +_CERTIFICATE_CHAIN = resources.certificate_chain() +_TEST_ROOT_CERTIFICATES = resources.test_root_certificates() +_SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) +_PROPERTY_OPTIONS = (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, +),) + + +async def handle_unary_unary(unused_request: bytes, + servicer_context: aio.ServicerContext): + return pickle.dumps({ + _ID: servicer_context.peer_identities(), + _ID_KEY: servicer_context.peer_identity_key(), + _AUTH_CTX: servicer_context.auth_context() + }) + + +class TestAuthContext(AioTestBase): + + async def test_insecure(self): + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = aio.server() + server.add_generic_rpc_handlers((handler,)) + port = server.add_insecure_port('[::]:0') + await server.start() + + async with aio.insecure_channel('localhost:%d' % port) as channel: + response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) + await server.stop(None) + + auth_data = pickle.loads(response) + self.assertIsNone(auth_data[_ID]) + self.assertIsNone(auth_data[_ID_KEY]) + self.assertDictEqual({}, auth_data[_AUTH_CTX]) + + async def test_secure_no_cert(self): + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = aio.server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) + port = server.add_secure_port('[::]:0', server_cred) + await server.start() + + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES) + channel = aio.secure_channel('localhost:{}'.format(port), + channel_creds, + options=_PROPERTY_OPTIONS) + response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) + await channel.close() + await server.stop(None) + + auth_data = pickle.loads(response) + self.assertIsNone(auth_data[_ID]) + self.assertIsNone(auth_data[_ID_KEY]) + self.assertDictEqual( + { + 'security_level': [b'TSI_PRIVACY_AND_INTEGRITY'], + 'transport_security_type': [b'ssl'], + 'ssl_session_reused': [b'false'], + }, auth_data[_AUTH_CTX]) + + async def test_secure_client_cert(self): + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = aio.server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials( + _SERVER_CERTS, + root_certificates=_TEST_ROOT_CERTIFICATES, + require_client_auth=True) + port = server.add_secure_port('[::]:0', server_cred) + await server.start() + + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES, + private_key=_PRIVATE_KEY, + certificate_chain=_CERTIFICATE_CHAIN) + channel = aio.secure_channel('localhost:{}'.format(port), + channel_creds, + options=_PROPERTY_OPTIONS) + + response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) + await channel.close() + await server.stop(None) + + auth_data = pickle.loads(response) + auth_ctx = auth_data[_AUTH_CTX] + self.assertCountEqual(_CLIENT_IDS, auth_data[_ID]) + self.assertEqual('x509_subject_alternative_name', auth_data[_ID_KEY]) + self.assertSequenceEqual([b'ssl'], auth_ctx['transport_security_type']) + self.assertSequenceEqual([b'*.test.google.com'], + auth_ctx['x509_common_name']) + + async def _do_one_shot_client_rpc(self, channel_creds, channel_options, + port, expect_ssl_session_reused): + channel = aio.secure_channel('localhost:{}'.format(port), + channel_creds, + options=channel_options) + response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) + auth_data = pickle.loads(response) + self.assertEqual(expect_ssl_session_reused, + auth_data[_AUTH_CTX]['ssl_session_reused']) + await channel.close() + + async def test_session_resumption(self): + # Set up a secure server + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = aio.server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) + port = server.add_secure_port('[::]:0', server_cred) + await server.start() + + # Create a cache for TLS session tickets + cache = session_cache.ssl_session_cache_lru(1) + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES) + channel_options = _PROPERTY_OPTIONS + ( + ('grpc.ssl_session_cache', cache),) + + # Initial connection has no session to resume + await self._do_one_shot_client_rpc(channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b'false']) + + # Subsequent connections resume sessions + await self._do_one_shot_client_rpc(channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b'true']) + await server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/src/python/grpcio_tests/tests_aio/unit/context_peer_test.py b/src/python/grpcio_tests/tests_aio/unit/context_peer_test.py new file mode 100644 index 00000000000..ea5f4621afb --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/context_peer_test.py @@ -0,0 +1,65 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing the server context ability to access peer info.""" + +import asyncio +import logging +import os +import unittest +from typing import Callable, Iterable, Sequence, Tuple + +import grpc +from grpc.experimental import aio + +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests.unit.framework.common import test_constants +from tests_aio.unit import _common +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import TestServiceServicer, start_test_server + +_REQUEST = b'\x03\x07' +_TEST_METHOD = '/test/UnaryUnary' + + +class TestContextPeer(AioTestBase): + + async def test_peer(self): + + @grpc.unary_unary_rpc_method_handler + async def check_peer_unary_unary(request: bytes, + context: aio.ServicerContext): + self.assertEqual(_REQUEST, request) + # The peer address could be ipv4 or ipv6 + self.assertIn('ip', context.peer()) + return request + + # Creates a server + server = aio.server() + handlers = grpc.method_handlers_generic_handler( + 'test', {'UnaryUnary': check_peer_unary_unary}) + server.add_generic_rpc_handlers((handlers,)) + port = server.add_insecure_port('[::]:0') + await server.start() + + # Creates a channel + async with aio.insecure_channel('localhost:%d' % port) as channel: + response = await channel.unary_unary(_TEST_METHOD)(_REQUEST) + self.assertEqual(_REQUEST, response) + + await server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2)