Fix the server credentials & improve socket implementation

pull/21855/head
Lidi Zheng 5 years ago
parent 343e77ab9e
commit 47246c86bb
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi
  2. 65
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi
  3. 5
      src/python/grpcio/grpc/experimental/aio/__init__.py
  4. 28
      src/python/grpcio_tests/tests_aio/interop/local_interop_test.py
  5. 1
      src/python/grpcio_tests/tests_aio/unit/abort_test.py

@ -18,9 +18,11 @@ cdef class _AsyncioSocket:
# Common attributes
grpc_custom_socket * _grpc_socket
grpc_custom_read_callback _grpc_read_cb
grpc_custom_write_callback _grpc_write_cb
object _reader
object _writer
object _task_read
object _task_write
object _task_connect
char * _read_buffer
# Caches the picked event loop, so we can avoid the 30ns overhead each

@ -25,10 +25,12 @@ cdef class _AsyncioSocket:
self._grpc_socket = NULL
self._grpc_connect_cb = NULL
self._grpc_read_cb = NULL
self._grpc_write_cb = NULL
self._reader = None
self._writer = None
self._task_connect = None
self._task_read = None
self._task_write = None
self._read_buffer = NULL
self._server = None
self._py_socket = None
@ -82,33 +84,26 @@ cdef class _AsyncioSocket:
<grpc_error*>0
)
def _read_cb(self, future):
error = False
async def _async_read(self, size_t length):
self._task_read = None
try:
buffer_ = future.result()
except Exception as e:
error = True
error_msg = "%s: %s" % (type(e), str(e))
_LOGGER.debug(e)
finally:
self._task_read = None
if not error:
string.memcpy(
<void*>self._read_buffer,
<char*>buffer_,
len(buffer_)
)
inbound_buffer = await self._reader.read(n=length)
except ConnectionError as e:
self._grpc_read_cb(
<grpc_custom_socket*>self._grpc_socket,
len(buffer_),
<grpc_error*>0
-1,
grpc_socket_error("Read failed: {}".format(e).encode())
)
else:
string.memcpy(
<void*>self._read_buffer,
<char*>inbound_buffer,
len(inbound_buffer)
)
self._grpc_read_cb(
<grpc_custom_socket*>self._grpc_socket,
-1,
grpc_socket_error("Read failed: {}".format(error_msg).encode())
len(inbound_buffer),
<grpc_error*>0
)
cdef void connect(self,
@ -127,13 +122,25 @@ cdef class _AsyncioSocket:
cdef void read(self, char * buffer_, size_t length, grpc_custom_read_callback grpc_read_cb):
assert not self._task_read
self._task_read = self._loop.create_task(
self._reader.read(n=length)
)
self._grpc_read_cb = grpc_read_cb
self._task_read.add_done_callback(self._read_cb)
self._read_buffer = buffer_
self._task_read = self._loop.create_task(self._async_read(length))
async def _async_write(self, bytearray outbound_buffer):
self._writer.write(outbound_buffer)
self._task_write = None
try:
await self._writer.drain()
self._grpc_write_cb(
<grpc_custom_socket*>self._grpc_socket,
<grpc_error*>0
)
except ConnectionError as connection_error:
self._grpc_write_cb(
<grpc_custom_socket*>self._grpc_socket,
grpc_socket_error("Socket write failed: {}".format(connection_error).encode()),
)
cdef void write(self, grpc_slice_buffer * g_slice_buffer, grpc_custom_write_callback grpc_write_cb):
"""Performs write to network socket in AsyncIO.
@ -141,6 +148,7 @@ cdef class _AsyncioSocket:
When the write is finished, we need to call grpc_write_cb to notify
Core that the work is done.
"""
assert not self._task_write
cdef char* start
cdef bytearray outbound_buffer = bytearray()
for i in range(g_slice_buffer.count):
@ -148,11 +156,8 @@ cdef class _AsyncioSocket:
length = grpc_slice_buffer_length(g_slice_buffer, i)
outbound_buffer.extend(<bytes>start[:length])
self._writer.write(outbound_buffer)
grpc_write_cb(
<grpc_custom_socket*>self._grpc_socket,
<grpc_error*>0
)
self._grpc_write_cb = grpc_write_cb
self._task_write = self._loop.create_task(self._async_write(outbound_buffer))
cdef bint is_connected(self):
return self._reader and not self._reader._transport.is_closing()

@ -30,11 +30,12 @@ from ._channel import Channel, UnaryUnaryMultiCallable
from ._interceptor import (ClientCallDetails, InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor)
from ._server import Server, server
from ._typing import ChannelArgumentType
def insecure_channel(
target: Text,
options: Optional[Sequence[Tuple[Text, Any]]] = None,
options: Optional[ChannelArgumentType] = None,
compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None):
"""Creates an insecure asynchronous Channel to a server.
@ -58,7 +59,7 @@ def insecure_channel(
def secure_channel(
target: Text,
credentials: grpc.ChannelCredentials,
options: Optional[list] = None,
options: Optional[ChannelArgumentType] = None,
compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None):
"""Creates a secure asynchronous Channel to a server.

@ -25,6 +25,8 @@ from tests_aio.interop import methods
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
_SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
class InteropTestCaseMixin:
"""Unit test methods.
@ -104,6 +106,30 @@ class InsecureLocalInteropTest(InteropTestCaseMixin, AioTestBase):
await self._server.stop(None)
class SecureLocalInteropTest(InteropTestCaseMixin, AioTestBase):
async def setUp(self):
server_credentials = grpc.ssl_server_credentials([
(resources.private_key(), resources.certificate_chain())
])
channel_credentials = grpc.ssl_channel_credentials(
resources.test_root_certificates())
channel_options = ((
'grpc.ssl_target_name_override',
_SERVER_HOST_OVERRIDE,
),)
address, self._server = await start_test_server(
secure=True, server_credentials=server_credentials)
self._channel = aio.secure_channel(address, channel_credentials,
channel_options)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)
unittest.main(verbosity=2)

@ -136,6 +136,7 @@ class TestAbort(AioTestBase):
with self.assertRaises(aio.AioRpcError) as exception_context:
await call.read()
await call.read()
rpc_error = exception_context.exception
self.assertEqual(_ABORT_CODE, rpc_error.code())

Loading…
Cancel
Save