diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi index fef0e1ae1cb..cd425d2e941 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi @@ -18,9 +18,11 @@ cdef class _AsyncioSocket: # Common attributes grpc_custom_socket * _grpc_socket grpc_custom_read_callback _grpc_read_cb + grpc_custom_write_callback _grpc_write_cb object _reader object _writer object _task_read + object _task_write object _task_connect char * _read_buffer # Caches the picked event loop, so we can avoid the 30ns overhead each diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi index 94ed40fc708..1664ef7e35a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi @@ -25,10 +25,12 @@ cdef class _AsyncioSocket: self._grpc_socket = NULL self._grpc_connect_cb = NULL self._grpc_read_cb = NULL + self._grpc_write_cb = NULL self._reader = None self._writer = None self._task_connect = None self._task_read = None + self._task_write = None self._read_buffer = NULL self._server = None self._py_socket = None @@ -82,33 +84,26 @@ cdef class _AsyncioSocket: 0 ) - def _read_cb(self, future): - error = False + async def _async_read(self, size_t length): + self._task_read = None try: - buffer_ = future.result() - except Exception as e: - error = True - error_msg = "%s: %s" % (type(e), str(e)) - _LOGGER.debug(e) - finally: - self._task_read = None - - if not error: - string.memcpy( - self._read_buffer, - buffer_, - len(buffer_) - ) + inbound_buffer = await self._reader.read(n=length) + except ConnectionError as e: self._grpc_read_cb( self._grpc_socket, - len(buffer_), - 0 + -1, + grpc_socket_error("Read failed: {}".format(e).encode()) ) else: + string.memcpy( + self._read_buffer, + inbound_buffer, + len(inbound_buffer) + ) self._grpc_read_cb( self._grpc_socket, - -1, - grpc_socket_error("Read failed: {}".format(error_msg).encode()) + len(inbound_buffer), + 0 ) cdef void connect(self, @@ -127,13 +122,25 @@ cdef class _AsyncioSocket: cdef void read(self, char * buffer_, size_t length, grpc_custom_read_callback grpc_read_cb): assert not self._task_read - self._task_read = self._loop.create_task( - self._reader.read(n=length) - ) self._grpc_read_cb = grpc_read_cb - self._task_read.add_done_callback(self._read_cb) self._read_buffer = buffer_ - + self._task_read = self._loop.create_task(self._async_read(length)) + + async def _async_write(self, bytearray outbound_buffer): + self._writer.write(outbound_buffer) + self._task_write = None + try: + await self._writer.drain() + self._grpc_write_cb( + self._grpc_socket, + 0 + ) + except ConnectionError as connection_error: + self._grpc_write_cb( + self._grpc_socket, + grpc_socket_error("Socket write failed: {}".format(connection_error).encode()), + ) + cdef void write(self, grpc_slice_buffer * g_slice_buffer, grpc_custom_write_callback grpc_write_cb): """Performs write to network socket in AsyncIO. @@ -141,6 +148,7 @@ cdef class _AsyncioSocket: When the write is finished, we need to call grpc_write_cb to notify Core that the work is done. """ + assert not self._task_write cdef char* start cdef bytearray outbound_buffer = bytearray() for i in range(g_slice_buffer.count): @@ -148,11 +156,8 @@ cdef class _AsyncioSocket: length = grpc_slice_buffer_length(g_slice_buffer, i) outbound_buffer.extend(start[:length]) - self._writer.write(outbound_buffer) - grpc_write_cb( - self._grpc_socket, - 0 - ) + self._grpc_write_cb = grpc_write_cb + self._task_write = self._loop.create_task(self._async_write(outbound_buffer)) cdef bint is_connected(self): return self._reader and not self._reader._transport.is_closing() diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index 8dc52b8b842..0839c79010d 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -30,11 +30,12 @@ from ._channel import Channel, UnaryUnaryMultiCallable from ._interceptor import (ClientCallDetails, InterceptedUnaryUnaryCall, UnaryUnaryClientInterceptor) from ._server import Server, server +from ._typing import ChannelArgumentType def insecure_channel( target: Text, - options: Optional[Sequence[Tuple[Text, Any]]] = None, + options: Optional[ChannelArgumentType] = None, compression: Optional[grpc.Compression] = None, interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None): """Creates an insecure asynchronous Channel to a server. @@ -58,7 +59,7 @@ def insecure_channel( def secure_channel( target: Text, credentials: grpc.ChannelCredentials, - options: Optional[list] = None, + options: Optional[ChannelArgumentType] = None, compression: Optional[grpc.Compression] = None, interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None): """Creates a secure asynchronous Channel to a server. diff --git a/src/python/grpcio_tests/tests_aio/interop/local_interop_test.py b/src/python/grpcio_tests/tests_aio/interop/local_interop_test.py index e6dea065a32..c8b6083ae39 100644 --- a/src/python/grpcio_tests/tests_aio/interop/local_interop_test.py +++ b/src/python/grpcio_tests/tests_aio/interop/local_interop_test.py @@ -25,6 +25,8 @@ from tests_aio.interop import methods from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_server import start_test_server +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' + class InteropTestCaseMixin: """Unit test methods. @@ -104,6 +106,30 @@ class InsecureLocalInteropTest(InteropTestCaseMixin, AioTestBase): await self._server.stop(None) +class SecureLocalInteropTest(InteropTestCaseMixin, AioTestBase): + + async def setUp(self): + server_credentials = grpc.ssl_server_credentials([ + (resources.private_key(), resources.certificate_chain()) + ]) + channel_credentials = grpc.ssl_channel_credentials( + resources.test_root_certificates()) + channel_options = (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, + ),) + + address, self._server = await start_test_server( + secure=True, server_credentials=server_credentials) + self._channel = aio.secure_channel(address, channel_credentials, + channel_options) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + if __name__ == '__main__': - logging.basicConfig(level=logging.DEBUG) + logging.basicConfig(level=logging.INFO) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/abort_test.py b/src/python/grpcio_tests/tests_aio/unit/abort_test.py index b5d504e419f..828b6884dfa 100644 --- a/src/python/grpcio_tests/tests_aio/unit/abort_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/abort_test.py @@ -136,6 +136,7 @@ class TestAbort(AioTestBase): with self.assertRaises(aio.AioRpcError) as exception_context: await call.read() + await call.read() rpc_error = exception_context.exception self.assertEqual(_ABORT_CODE, rpc_error.code())