Implement methods to access auth context and peer info

pull/23265/head
Lidi Zheng 4 years ago
parent dbb39a525d
commit 11a29eb95a
  1. 49
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  2. 43
      src/python/grpcio/grpc/experimental/aio/_base_server.py
  3. 2
      src/python/grpcio_tests/tests_aio/tests.json
  4. 194
      src/python/grpcio_tests/tests_aio/unit/auth_context_test.py
  5. 65
      src/python/grpcio_tests/tests_aio/unit/context_peer_test.py

@ -213,6 +213,43 @@ cdef class _ServicerContext:
def disable_next_message_compression(self):
self._rpc_state.disable_next_compression = True
def peer(self):
cdef char *c_peer = NULL
c_peer = grpc_call_get_peer(self._rpc_state.call)
peer = (<bytes>c_peer).decode('utf8')
gpr_free(c_peer)
return peer
def peer_identities(self):
cdef Call query_call = Call()
query_call.c_call = self._rpc_state.call
identities = peer_identities(query_call)
query_call.c_call = NULL
return identities
def peer_identity_key(self):
cdef Call query_call = Call()
query_call.c_call = self._rpc_state.call
identity_key = peer_identity_key(query_call)
query_call.c_call = NULL
if identity_key:
return identity_key.decode('utf8')
else:
return None
def auth_context(self):
cdef Call query_call = Call()
query_call.c_call = self._rpc_state.call
bytes_ctx = auth_context(query_call)
query_call.c_call = NULL
if bytes_ctx:
ctx = {}
for key in bytes_ctx:
ctx[key.decode('utf8')] = bytes_ctx[key]
return ctx
else:
return {}
cdef class _SyncServicerContext:
"""Sync servicer context for sync handler compatibility."""
@ -260,6 +297,18 @@ cdef class _SyncServicerContext:
def add_callback(self, object callback):
self._callbacks.append(callback)
def peer(self):
return self._context.peer()
def peer_identities(self):
return self._context.peer_identities()
def peer_identity_key(self):
return self._context.peer_identity_key()
def auth_context(self):
return self._context.auth_context()
async def _run_interceptor(object interceptors, object query_handler,
object handler_call_details):

@ -14,7 +14,7 @@
"""Abstract base classes for server-side classes."""
import abc
from typing import Generic, Optional, Sequence
from typing import Generic, Mapping, Optional, Iterable, Sequence
import grpc
@ -251,3 +251,44 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC):
This method will override any compression configuration set during
server creation or set on the call.
"""
@abc.abstractmethod
def peer(self) -> str:
"""Identifies the peer that invoked the RPC being serviced.
Returns:
A string identifying the peer that invoked the RPC being serviced.
The string format is determined by gRPC runtime.
"""
@abc.abstractmethod
def peer_identities(self) -> Optional[Iterable[bytes]]:
"""Gets one or more peer identity(s).
Equivalent to
servicer_context.auth_context().get(servicer_context.peer_identity_key())
Returns:
An iterable of the identities, or None if the call is not
authenticated. Each identity is returned as a raw bytes type.
"""
@abc.abstractmethod
def peer_identity_key(self) -> Optional[str]:
"""The auth property used to identify the peer.
For example, "x509_common_name" or "x509_subject_alternative_name" are
used to identify an SSL peer.
Returns:
The auth property (string) that indicates the
peer identity, or None if the call is not authenticated.
"""
@abc.abstractmethod
def auth_context(self) -> Mapping[str, Iterable[bytes]]:
"""Gets the auth context for the call.
Returns:
A map of strings to an iterable of bytes for each auth property.
"""

@ -9,6 +9,7 @@
"unit._metadata_test.TestTypeMetadata",
"unit.abort_test.TestAbort",
"unit.aio_rpc_error_test.TestAioRpcError",
"unit.auth_context_test.TestAuthContext",
"unit.call_test.TestStreamStreamCall",
"unit.call_test.TestStreamUnaryCall",
"unit.call_test.TestUnaryStreamCall",
@ -16,6 +17,7 @@
"unit.channel_argument_test.TestChannelArgument",
"unit.channel_ready_test.TestChannelReady",
"unit.channel_test.TestChannel",
"unit.context_peer.TestContextPeer",
"unit.client_stream_stream_interceptor_test.TestStreamStreamClientInterceptor",
"unit.client_stream_unary_interceptor_test.TestStreamUnaryClientInterceptor",
"unit.client_unary_stream_interceptor_test.TestUnaryStreamClientInterceptor",

@ -0,0 +1,194 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Porting auth context tests from sync stack."""
import pickle
import unittest
import logging
import grpc
from grpc.experimental import aio
from grpc.experimental import session_cache
import six
from tests.unit import resources
from tests_aio.unit._test_base import AioTestBase
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'
_UNARY_UNARY = '/test/UnaryUnary'
_SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
_CLIENT_IDS = (
b'*.test.google.fr',
b'waterzooi.test.google.be',
b'*.test.youtube.com',
b'192.168.1.3',
)
_ID = 'id'
_ID_KEY = 'id_key'
_AUTH_CTX = 'auth_ctx'
_PRIVATE_KEY = resources.private_key()
_CERTIFICATE_CHAIN = resources.certificate_chain()
_TEST_ROOT_CERTIFICATES = resources.test_root_certificates()
_SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),)
_PROPERTY_OPTIONS = ((
'grpc.ssl_target_name_override',
_SERVER_HOST_OVERRIDE,
),)
async def handle_unary_unary(unused_request: bytes,
servicer_context: aio.ServicerContext):
return pickle.dumps({
_ID: servicer_context.peer_identities(),
_ID_KEY: servicer_context.peer_identity_key(),
_AUTH_CTX: servicer_context.auth_context()
})
class TestAuthContext(AioTestBase):
async def test_insecure(self):
handler = grpc.method_handlers_generic_handler('test', {
'UnaryUnary':
grpc.unary_unary_rpc_method_handler(handle_unary_unary)
})
server = aio.server()
server.add_generic_rpc_handlers((handler,))
port = server.add_insecure_port('[::]:0')
await server.start()
async with aio.insecure_channel('localhost:%d' % port) as channel:
response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST)
await server.stop(None)
auth_data = pickle.loads(response)
self.assertIsNone(auth_data[_ID])
self.assertIsNone(auth_data[_ID_KEY])
self.assertDictEqual({}, auth_data[_AUTH_CTX])
async def test_secure_no_cert(self):
handler = grpc.method_handlers_generic_handler('test', {
'UnaryUnary':
grpc.unary_unary_rpc_method_handler(handle_unary_unary)
})
server = aio.server()
server.add_generic_rpc_handlers((handler,))
server_cred = grpc.ssl_server_credentials(_SERVER_CERTS)
port = server.add_secure_port('[::]:0', server_cred)
await server.start()
channel_creds = grpc.ssl_channel_credentials(
root_certificates=_TEST_ROOT_CERTIFICATES)
channel = aio.secure_channel('localhost:{}'.format(port),
channel_creds,
options=_PROPERTY_OPTIONS)
response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST)
await channel.close()
await server.stop(None)
auth_data = pickle.loads(response)
self.assertIsNone(auth_data[_ID])
self.assertIsNone(auth_data[_ID_KEY])
self.assertDictEqual(
{
'security_level': [b'TSI_PRIVACY_AND_INTEGRITY'],
'transport_security_type': [b'ssl'],
'ssl_session_reused': [b'false'],
}, auth_data[_AUTH_CTX])
async def test_secure_client_cert(self):
handler = grpc.method_handlers_generic_handler('test', {
'UnaryUnary':
grpc.unary_unary_rpc_method_handler(handle_unary_unary)
})
server = aio.server()
server.add_generic_rpc_handlers((handler,))
server_cred = grpc.ssl_server_credentials(
_SERVER_CERTS,
root_certificates=_TEST_ROOT_CERTIFICATES,
require_client_auth=True)
port = server.add_secure_port('[::]:0', server_cred)
await server.start()
channel_creds = grpc.ssl_channel_credentials(
root_certificates=_TEST_ROOT_CERTIFICATES,
private_key=_PRIVATE_KEY,
certificate_chain=_CERTIFICATE_CHAIN)
channel = aio.secure_channel('localhost:{}'.format(port),
channel_creds,
options=_PROPERTY_OPTIONS)
response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST)
await channel.close()
await server.stop(None)
auth_data = pickle.loads(response)
auth_ctx = auth_data[_AUTH_CTX]
self.assertCountEqual(_CLIENT_IDS, auth_data[_ID])
self.assertEqual('x509_subject_alternative_name', auth_data[_ID_KEY])
self.assertSequenceEqual([b'ssl'], auth_ctx['transport_security_type'])
self.assertSequenceEqual([b'*.test.google.com'],
auth_ctx['x509_common_name'])
async def _do_one_shot_client_rpc(self, channel_creds, channel_options,
port, expect_ssl_session_reused):
channel = aio.secure_channel('localhost:{}'.format(port),
channel_creds,
options=channel_options)
response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST)
auth_data = pickle.loads(response)
self.assertEqual(expect_ssl_session_reused,
auth_data[_AUTH_CTX]['ssl_session_reused'])
await channel.close()
async def test_session_resumption(self):
# Set up a secure server
handler = grpc.method_handlers_generic_handler('test', {
'UnaryUnary':
grpc.unary_unary_rpc_method_handler(handle_unary_unary)
})
server = aio.server()
server.add_generic_rpc_handlers((handler,))
server_cred = grpc.ssl_server_credentials(_SERVER_CERTS)
port = server.add_secure_port('[::]:0', server_cred)
await server.start()
# Create a cache for TLS session tickets
cache = session_cache.ssl_session_cache_lru(1)
channel_creds = grpc.ssl_channel_credentials(
root_certificates=_TEST_ROOT_CERTIFICATES)
channel_options = _PROPERTY_OPTIONS + (
('grpc.ssl_session_cache', cache),)
# Initial connection has no session to resume
await self._do_one_shot_client_rpc(channel_creds,
channel_options,
port,
expect_ssl_session_reused=[b'false'])
# Subsequent connections resume sessions
await self._do_one_shot_client_rpc(channel_creds,
channel_options,
port,
expect_ssl_session_reused=[b'true'])
await server.stop(None)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main()

@ -0,0 +1,65 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing the server context ability to access peer info."""
import asyncio
import logging
import os
import unittest
from typing import Callable, Iterable, Sequence, Tuple
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit import _common
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import TestServiceServicer, start_test_server
_REQUEST = b'\x03\x07'
_TEST_METHOD = '/test/UnaryUnary'
class TestContextPeer(AioTestBase):
async def test_peer(self):
@grpc.unary_unary_rpc_method_handler
async def check_peer_unary_unary(request: bytes,
context: aio.ServicerContext):
self.assertEqual(_REQUEST, request)
# The peer address could be ipv4 or ipv6
self.assertIn('ip', context.peer())
return request
# Creates a server
server = aio.server()
handlers = grpc.method_handlers_generic_handler(
'test', {'UnaryUnary': check_peer_unary_unary})
server.add_generic_rpc_handlers((handlers,))
port = server.add_insecure_port('[::]:0')
await server.start()
# Creates a channel
async with aio.insecure_channel('localhost:%d' % port) as channel:
response = await channel.unary_unary(_TEST_METHOD)(_REQUEST)
self.assertEqual(_REQUEST, response)
await server.stop(None)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save