From ac845a1cd0df9213409b6a41c838306d7693d2ab Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Fri, 27 Dec 2019 11:48:49 -0800 Subject: [PATCH 01/18] Fix log statement --- src/core/lib/iomgr/executor.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/lib/iomgr/executor.cc b/src/core/lib/iomgr/executor.cc index 63d946c2479..9f92c9fae21 100644 --- a/src/core/lib/iomgr/executor.cc +++ b/src/core/lib/iomgr/executor.cc @@ -143,7 +143,7 @@ void Executor::SetThreading(bool threading) { if (threading) { if (curr_num_threads > 0) { - EXECUTOR_TRACE("(%s) SetThreading(true). curr_num_threads == 0", name_); + EXECUTOR_TRACE("(%s) SetThreading(true). curr_num_threads > 0", name_); return; } From f1b29deea62afeb62237d740421415cdb43a57c0 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Tue, 17 Dec 2019 17:44:52 -0800 Subject: [PATCH 02/18] Improve cancellation mechanism: * Remove the weird cancellation_future; * Convert all CancelledError into RpcError with CANCELLED; * Move part of call logic from Cython to Python layer; * Make unary-stream call based on reader API instead of async generator. --- .../grpc/_cython/_cygrpc/aio/call.pxd.pxi | 13 +- .../grpc/_cython/_cygrpc/aio/call.pyx.pxi | 107 +++++--------- .../grpc/_cython/_cygrpc/aio/channel.pyx.pxi | 37 +---- .../grpcio/grpc/experimental/aio/_call.py | 132 ++++++++++-------- .../grpcio_tests/tests_aio/unit/call_test.py | 15 +- 5 files changed, 129 insertions(+), 175 deletions(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi index 3844797c50e..b800cee6028 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi @@ -22,13 +22,12 @@ cdef class _AioCall: # time we need access to the event loop. object _loop - # Streaming call only attributes: - # - # A asyncio.Event that indicates if the status is received on the client side. - object _status_received - # A tuple of key value pairs representing the initial metadata sent by peer. - tuple _initial_metadata + # Flag indicates whether cancel being called or not. Cancellation from + # Core or peer works perfectly fine with normal procedure. However, we + # need this flag to clean up resources for cancellation from the + # application layer. Directly cancelling tasks might cause segfault + # because Core is holding a pointer for the callback handler. + bint _is_locally_cancelled cdef grpc_call* _create_grpc_call(self, object timeout, bytes method) except * cdef void _destroy_grpc_call(self) - cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index b98809a12e0..0b5c9a7589a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -33,8 +33,7 @@ cdef class _AioCall: self._grpc_call_wrapper = GrpcCallWrapper() self._loop = asyncio.get_event_loop() self._create_grpc_call(deadline, method) - - self._status_received = asyncio.Event(loop=self._loop) + self._is_locally_cancelled = False def __dealloc__(self): self._destroy_grpc_call() @@ -78,17 +77,21 @@ cdef class _AioCall: """Destroys the corresponding Core object for this RPC.""" grpc_call_unref(self._grpc_call_wrapper.call) - cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future): - """Cancels the RPC in Core, and return the final RPC status.""" - cdef AioRpcStatus status + def cancel(self, AioRpcStatus status): + """Cancels the RPC in Core with given RPC status. + + Above abstractions must invoke this method to set Core objects into + proper state. + """ + self._is_locally_cancelled = True + cdef object details cdef char *c_details cdef grpc_call_error error # Try to fetch application layer cancellation details in the future. # * If cancellation details present, cancel with status; # * If details not present, cancel with unknown reason. - if cancellation_future.done(): - status = cancellation_future.result() + if status is not None: details = str_to_bytes(status.details()) self._references.append(details) c_details = details @@ -100,7 +103,6 @@ cdef class _AioCall: NULL, ) assert error == GRPC_CALL_OK - return status else: # By implementation, grpc_call_cancel always return OK error = grpc_call_cancel(self._grpc_call_wrapper.call, NULL) @@ -111,12 +113,9 @@ cdef class _AioCall: None, None, ) - cancellation_future.set_result(status) - return status async def unary_unary(self, bytes request, - object cancellation_future, object initial_metadata_observer, object status_observer): """Performs a unary unary RPC. @@ -145,19 +144,10 @@ cdef class _AioCall: receive_initial_metadata_op, receive_message_op, receive_status_on_client_op) - try: - await execute_batch(self._grpc_call_wrapper, - ops, - self._loop) - except asyncio.CancelledError: - status = self._cancel_and_create_status(cancellation_future) - initial_metadata_observer(None) - status_observer(status) - raise - else: - initial_metadata_observer( - receive_initial_metadata_op.initial_metadata() - ) + # Executes all operations in one batch. + await execute_batch(self._grpc_call_wrapper, + ops, + self._loop) status = AioRpcStatus( receive_status_on_client_op.code(), @@ -179,6 +169,11 @@ cdef class _AioCall: cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS) cdef tuple ops = (op,) await execute_batch(self._grpc_call_wrapper, ops, self._loop) + + # Halts if the RPC is locally cancelled + if self._is_locally_cancelled: + return + cdef AioRpcStatus status = AioRpcStatus( op.code(), op.details(), @@ -186,52 +181,30 @@ cdef class _AioCall: op.error_string(), ) status_observer(status) - self._status_received.set() - - def _handle_cancellation_from_application(self, - object cancellation_future, - object status_observer): - def _cancellation_action(finished_future): - if not self._status_received.set(): - status = self._cancel_and_create_status(finished_future) - status_observer(status) - self._status_received.set() - - cancellation_future.add_done_callback(_cancellation_action) - async def _message_async_generator(self): + async def receive_serialized_message(self): + """Receives one single raw message in bytes.""" cdef bytes received_message - # Infinitely receiving messages, until: + # Receives a message. Returns None when failed: # * EOF, no more messages to read; - # * The client application cancells; + # * The client application cancels; # * The server sends final status. - while True: - if self._status_received.is_set(): - return - - received_message = await _receive_message( - self._grpc_call_wrapper, - self._loop - ) - if received_message is None: - # The read operation failed, Core should explain why it fails - await self._status_received.wait() - return - else: - yield received_message + received_message = await _receive_message( + self._grpc_call_wrapper, + self._loop + ) + return received_message async def unary_stream(self, bytes request, - object cancellation_future, object initial_metadata_observer, object status_observer): - """Actual implementation of the complete unary-stream call. - - Needs to pay extra attention to the raise mechanism. If we want to - propagate the final status exception, then we have to raise it. - Othersize, it would end normally and raise `StopAsyncIteration()`. - """ + """Implementation of the start of a unary-stream call.""" + # Peer may prematurely end this RPC at any point. We need a corutine + # that watches if the server sends the final status. + self._loop.create_task(self._handle_status_once_received(status_observer)) + cdef tuple outbound_ops cdef Operation initial_metadata_op = SendInitialMetadataOperation( _EMPTY_METADATA, @@ -248,21 +221,13 @@ cdef class _AioCall: send_close_op, ) - # Actually sends out the request message. + # Sends out the request message. await execute_batch(self._grpc_call_wrapper, - outbound_ops, - self._loop) - - # Peer may prematurely end this RPC at any point. We need a mechanism - # that handles both the normal case and the error case. - self._loop.create_task(self._handle_status_once_received(status_observer)) - self._handle_cancellation_from_application(cancellation_future, - status_observer) + outbound_ops, + self._loop) # Receives initial metadata. initial_metadata_observer( await _receive_initial_metadata(self._grpc_call_wrapper, self._loop), ) - - return self._message_async_generator() diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index 81b6c208619..4b6dd8c3b35 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -26,38 +26,13 @@ cdef class AioChannel: def close(self): grpc_channel_destroy(self.channel) - async def unary_unary(self, - bytes method, - bytes request, - object deadline, - object cancellation_future, - object initial_metadata_observer, - object status_observer): - """Assembles a unary-unary RPC. + def call(self, + bytes method, + object deadline): + """Assembles a Cython Call object. Returns: - The response message in bytes. + The _AioCall object. """ cdef _AioCall call = _AioCall(self, deadline, method) - return await call.unary_unary(request, - cancellation_future, - initial_metadata_observer, - status_observer) - - def unary_stream(self, - bytes method, - bytes request, - object deadline, - object cancellation_future, - object initial_metadata_observer, - object status_observer): - """Assembles a unary-stream RPC. - - Returns: - An async generator that yields raw responses. - """ - cdef _AioCall call = _AioCall(self, deadline, method) - return call.unary_stream(request, - cancellation_future, - initial_metadata_observer, - status_observer) + return call diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index be7a48157d0..9e32cbd51f0 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -15,6 +15,7 @@ import asyncio from typing import AsyncIterable, Awaitable, Dict, Optional +import logging import grpc from grpc import _common @@ -167,8 +168,7 @@ class Call(_base_call.Call): raise NotImplementedError() def cancelled(self) -> bool: - return self._cancellation.done( - ) or self._code == grpc.StatusCode.CANCELLED + return self._code == grpc.StatusCode.CANCELLED def done(self) -> bool: return self._status.done() @@ -205,14 +205,17 @@ class Call(_base_call.Call): cancellation (by application) and Core receiving status from peer. We make no promise here which one will win. """ - if self._status.done(): - return - else: - self._status.set_result(status) - self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[ - status.code()] + logging.debug('Call._set_status, %s, %s', self._status.done(), status) + # In case of the RPC finished without receiving metadata. + if not self._initial_metadata.done(): + self._initial_metadata.set_result(None) + + # Sets final status + self._status.set_result(status) + self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()] async def _raise_rpc_error_if_not_ok(self) -> None: + await self._status if self._code != grpc.StatusCode.OK: raise _create_rpc_error(await self.initial_metadata(), self._status.result()) @@ -245,12 +248,11 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): Returned when an instance of `UnaryUnaryMultiCallable` object is called. """ _request: RequestType - _deadline: Optional[float] _channel: cygrpc.AioChannel - _method: bytes _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction _call: asyncio.Task + _cython_call: cygrpc._AioCall def __init__(self, request: RequestType, deadline: Optional[float], channel: cygrpc.AioChannel, method: bytes, @@ -258,11 +260,10 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): response_deserializer: DeserializingFunction) -> None: super().__init__() self._request = request - self._deadline = deadline self._channel = channel - self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer + self._cython_call = self._channel.call(method, deadline) self._call = self._loop.create_task(self._invoke()) def __del__(self) -> None: @@ -275,19 +276,20 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): serialized_request = _common.serialize(self._request, self._request_serializer) - # NOTE(lidiz) asyncio.CancelledError is not a good transport for - # status, since the Task class do not cache the exact - # asyncio.CancelledError object. So, the solution is catching the error - # in Cython layer, then cancel the RPC and update the status, finally - # re-raise the CancelledError. - serialized_response = await self._channel.unary_unary( - self._method, - serialized_request, - self._deadline, - self._cancellation, - self._set_initial_metadata, - self._set_status, - ) + # NOTE(lidiz) asyncio.CancelledError is not a good transport for status, + # because the asyncio.Task class do not cache the exception object. + # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 + try: + serialized_response = await self._cython_call.unary_unary( + serialized_request, + self._set_initial_metadata, + self._set_status, + ) + except asyncio.CancelledError: + # Only this class can inject the CancelledError into the RPC + # coroutine, so we are certain that this exception is due to local + # cancellation. + assert self._code == grpc.StatusCode.CANCELLED await self._raise_rpc_error_if_not_ok() return _common.deserialize(serialized_response, @@ -295,8 +297,8 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: """Forwards the application cancellation reasoning.""" - if not self._status.done() and not self._cancellation.done(): - self._cancellation.set_result(status) + if not self._status.done(): + self._set_status(status) self._call.cancel() return True else: @@ -328,13 +330,11 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): Returned when an instance of `UnaryStreamMultiCallable` object is called. """ _request: RequestType - _deadline: Optional[float] _channel: cygrpc.AioChannel - _method: bytes _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction - _call: asyncio.Task - _bytes_aiter: AsyncIterable[bytes] + _cython_call: cygrpc._AioCall + _send_unary_request_task: asyncio.Task _message_aiter: AsyncIterable[ResponseType] def __init__(self, request: RequestType, deadline: Optional[float], @@ -343,13 +343,13 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): response_deserializer: DeserializingFunction) -> None: super().__init__() self._request = request - self._deadline = deadline self._channel = channel - self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer - self._call = self._loop.create_task(self._invoke()) - self._message_aiter = self._process() + self._send_unary_request_task = self._loop.create_task( + self._send_unary_request()) + self._message_aiter = self._fetch_stream_responses() + self._cython_call = self._channel.call(method, deadline) def __del__(self) -> None: if not self._status.done(): @@ -357,32 +357,18 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, _GC_CANCELLATION_DETAILS, None, None)) - async def _invoke(self) -> ResponseType: + async def _send_unary_request(self) -> ResponseType: serialized_request = _common.serialize(self._request, self._request_serializer) + await self._cython_call.unary_stream( + serialized_request, self._set_initial_metadata, self._set_status) - self._bytes_aiter = await self._channel.unary_stream( - self._method, - serialized_request, - self._deadline, - self._cancellation, - self._set_initial_metadata, - self._set_status, - ) - - async def _process(self) -> ResponseType: - await self._call - async for serialized_response in self._bytes_aiter: - if self._cancellation.done(): - await self._status - if self._status.done(): - # Raises pre-maturely if final status received here. Generates - # more helpful stack trace for end users. - await self._raise_rpc_error_if_not_ok() - yield _common.deserialize(serialized_response, - self._response_deserializer) - - await self._raise_rpc_error_if_not_ok() + async def _fetch_stream_responses(self) -> ResponseType: + await self._send_unary_request_task + message = await self._read() + while message: + yield message + message = await self._read() def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: """Forwards the application cancellation reasoning. @@ -395,8 +381,15 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): and the client calling "cancel" at the same time, this method respects the winner in Core. """ - if not self._status.done() and not self._cancellation.done(): - self._cancellation.set_result(status) + if not self._status.done(): + self._set_status(status) + self._cython_call.cancel(status) + + if not self._send_unary_request_task.done(): + # Injects CancelledError to the Task. The exception will + # propagate to _fetch_stream_responses as well, if the sending + # is not done. + self._send_unary_request_task.cancel() return True else: return False @@ -409,8 +402,25 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): def __aiter__(self) -> AsyncIterable[ResponseType]: return self._message_aiter + async def _read(self) -> ResponseType: + serialized_response = await self._cython_call.receive_serialized_message( + ) + if serialized_response is None: + return None + else: + return _common.deserialize(serialized_response, + self._response_deserializer) + async def read(self) -> ResponseType: if self._status.done(): await self._raise_rpc_error_if_not_ok() raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) - return await self._message_aiter.__anext__() + + response_message = await self._read() + if response_message is None: + # If the read operation failed, Core should explain why. + await self._raise_rpc_error_if_not_ok() + # If everything is okay, there is something wrong internally. + assert False, 'Read operation failed with StatusCode.OK' + else: + return response_message diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 78e6dac21ae..849b5f471c8 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -126,18 +126,23 @@ class TestUnaryUnaryCall(AioTestBase): self.assertTrue(call.cancel()) self.assertFalse(call.cancel()) - with self.assertRaises(asyncio.CancelledError) as exception_context: + with self.assertRaises(grpc.RpcError) as exception_context: await call + # The info in the RpcError should match the info in Call object. + rpc_error = exception_context.exception + self.assertEqual(rpc_error.code(), await call.code()) + self.assertEqual(rpc_error.details(), await call.details()) + self.assertEqual(rpc_error.trailing_metadata(), await + call.trailing_metadata()) + self.assertEqual(rpc_error.debug_error_string(), await + call.debug_error_string()) + self.assertTrue(call.cancelled()) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) self.assertEqual(await call.details(), 'Locally cancelled by application!') - # NOTE(lidiz) The CancelledError is almost always re-created, - # so we might not want to use it to transmit data. - # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 - class TestUnaryStreamCall(AioTestBase): From 65e4f17a2c5afe4c0df12867b55270a345fbc742 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Tue, 17 Dec 2019 19:29:21 -0800 Subject: [PATCH 03/18] Remove unused code --- src/python/grpcio/grpc/experimental/aio/_call.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 9e32cbd51f0..1b0ace52376 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -149,14 +149,12 @@ class Call(_base_call.Call): _code: grpc.StatusCode _status: Awaitable[cygrpc.AioRpcStatus] _initial_metadata: Awaitable[MetadataType] - _cancellation: asyncio.Future def __init__(self) -> None: self._loop = asyncio.get_event_loop() self._code = None self._status = self._loop.create_future() self._initial_metadata = self._loop.create_future() - self._cancellation = self._loop.create_future() def cancel(self) -> bool: """Placeholder cancellation method. @@ -205,7 +203,6 @@ class Call(_base_call.Call): cancellation (by application) and Core receiving status from peer. We make no promise here which one will win. """ - logging.debug('Call._set_status, %s, %s', self._status.done(), status) # In case of the RPC finished without receiving metadata. if not self._initial_metadata.done(): self._initial_metadata.set_result(None) From e8283e4818edd1e31709447e612c54a4f38ce43e Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Tue, 17 Dec 2019 19:32:06 -0800 Subject: [PATCH 04/18] Reword the comment --- src/python/grpcio/grpc/experimental/aio/_call.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 1b0ace52376..bd720c159ac 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -417,7 +417,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): if response_message is None: # If the read operation failed, Core should explain why. await self._raise_rpc_error_if_not_ok() - # If everything is okay, there is something wrong internally. + # If no exception raised, there is something wrong internally. assert False, 'Read operation failed with StatusCode.OK' else: return response_message From d49b0849f05bf6d98897cb8a26ffee0fc4fef877 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Wed, 18 Dec 2019 13:35:25 -0800 Subject: [PATCH 05/18] Adding more catch clauses for CancelledError --- .../grpc/_cython/_cygrpc/aio/call.pyx.pxi | 6 ++ .../grpcio/grpc/experimental/aio/_call.py | 43 ++++++--- .../grpcio_tests/tests_aio/unit/call_test.py | 92 +++++++++++++++++++ 3 files changed, 129 insertions(+), 12 deletions(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 0b5c9a7589a..a726084ec4d 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -77,6 +77,11 @@ cdef class _AioCall: """Destroys the corresponding Core object for this RPC.""" grpc_call_unref(self._grpc_call_wrapper.call) + @property + def locally_cancelled(self): + """Grant Python layer access of the cancelled flag.""" + return self._is_locally_cancelled + def cancel(self, AioRpcStatus status): """Cancels the RPC in Core with given RPC status. @@ -145,6 +150,7 @@ cdef class _AioCall: receive_status_on_client_op) # Executes all operations in one batch. + # Might raise CancelledError, handling it in Python UnaryUnaryCall. await execute_batch(self._grpc_call_wrapper, ops, self._loop) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index bd720c159ac..c8fb5d18568 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -15,7 +15,6 @@ import asyncio from typing import AsyncIterable, Awaitable, Dict, Optional -import logging import grpc from grpc import _common @@ -42,6 +41,8 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' '\tdebug_error_string = "{}"\n' '>') +_EMPTY_METADATA = tuple() + class AioRpcError(grpc.RpcError): """An implementation of RpcError to be used by the asynchronous API. @@ -205,7 +206,7 @@ class Call(_base_call.Call): """ # In case of the RPC finished without receiving metadata. if not self._initial_metadata.done(): - self._initial_metadata.set_result(None) + self._initial_metadata.set_result(_EMPTY_METADATA) # Sets final status self._status.set_result(status) @@ -283,10 +284,10 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): self._set_status, ) except asyncio.CancelledError: - # Only this class can inject the CancelledError into the RPC - # coroutine, so we are certain that this exception is due to local - # cancellation. - assert self._code == grpc.StatusCode.CANCELLED + if self._code != grpc.StatusCode.CANCELLED: + self.cancel() + + # Raises RpcError here if RPC failed or cancelled await self._raise_rpc_error_if_not_ok() return _common.deserialize(serialized_response, @@ -357,8 +358,16 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): async def _send_unary_request(self) -> ResponseType: serialized_request = _common.serialize(self._request, self._request_serializer) - await self._cython_call.unary_stream( - serialized_request, self._set_initial_metadata, self._set_status) + try: + await self._cython_call.unary_stream( + serialized_request, + self._set_initial_metadata, + self._set_status + ) + except asyncio.CancelledError: + if self._code != grpc.StatusCode.CANCELLED: + self.cancel() + await self._raise_rpc_error_if_not_ok() async def _fetch_stream_responses(self) -> ResponseType: await self._send_unary_request_task @@ -400,12 +409,21 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): return self._message_aiter async def _read(self) -> ResponseType: - serialized_response = await self._cython_call.receive_serialized_message( - ) - if serialized_response is None: + # Wait for the request being sent + await self._send_unary_request_task + + # Reads response message from Core + try: + raw_response = await self._cython_call.receive_serialized_message() + except asyncio.CancelledError: + if self._code != grpc.StatusCode.CANCELLED: + self.cancel() + await self._raise_rpc_error_if_not_ok() + + if raw_response is None: return None else: - return _common.deserialize(serialized_response, + return _common.deserialize(raw_response, self._response_deserializer) async def read(self) -> ResponseType: @@ -414,6 +432,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) response_message = await self._read() + if response_message is None: # If the read operation failed, Core should explain why. await self._raise_rpc_error_if_not_ok() diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 849b5f471c8..bbca4dd3996 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -33,6 +33,8 @@ _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' _RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000 _UNREACHABLE_TARGET = '0.1:1111' +_INFINITE_INTERVAL_US = 2**31-1 + class TestUnaryUnaryCall(AioTestBase): @@ -143,6 +145,29 @@ class TestUnaryUnaryCall(AioTestBase): self.assertEqual(await call.details(), 'Locally cancelled by application!') + async def test_cancel_unary_unary_in_task(self): + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + coro_started = asyncio.Event() + call = stub.EmptyCall(messages_pb2.SimpleRequest()) + + async def another_coro(): + coro_started.set() + await call + + task = self.loop.create_task(another_coro()) + await coro_started.wait() + + self.assertFalse(task.done()) + task.cancel() + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + with self.assertRaises(grpc.RpcError) as exception_context: + await task + self.assertEqual(grpc.StatusCode.CANCELLED, + exception_context.exception.code()) + class TestUnaryStreamCall(AioTestBase): @@ -328,6 +353,73 @@ class TestUnaryStreamCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.OK) + async def test_cancel_unary_stream_in_task_using_read(self): + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + coro_started = asyncio.Event() + + # Configs the server method to block forever + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_INFINITE_INTERVAL_US, + )) + + # Invokes the actual RPC + call = stub.StreamingOutputCall(request) + + async def another_coro(): + coro_started.set() + await call.read() + + task = self.loop.create_task(another_coro()) + await coro_started.wait() + + self.assertFalse(task.done()) + task.cancel() + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + with self.assertRaises(grpc.RpcError) as exception_context: + await task + self.assertEqual(grpc.StatusCode.CANCELLED, + exception_context.exception.code()) + + async def test_cancel_unary_stream_in_task_using_async_for(self): + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + coro_started = asyncio.Event() + + # Configs the server method to block forever + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_INFINITE_INTERVAL_US, + )) + + # Invokes the actual RPC + call = stub.StreamingOutputCall(request) + + async def another_coro(): + coro_started.set() + async for _ in call: + pass + + task = self.loop.create_task(another_coro()) + await coro_started.wait() + + self.assertFalse(task.done()) + task.cancel() + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + with self.assertRaises(grpc.RpcError) as exception_context: + await task + self.assertEqual(grpc.StatusCode.CANCELLED, + exception_context.exception.code()) + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG) From 413d29218e06c1407fe6226cc3c41f50cd386ce7 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Wed, 18 Dec 2019 13:57:09 -0800 Subject: [PATCH 06/18] Make YAPF happy --- src/python/grpcio/grpc/experimental/aio/_call.py | 8 +++----- src/python/grpcio_tests/tests_aio/unit/call_test.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index c8fb5d18568..96415fa9521 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -359,11 +359,9 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): serialized_request = _common.serialize(self._request, self._request_serializer) try: - await self._cython_call.unary_stream( - serialized_request, - self._set_initial_metadata, - self._set_status - ) + await self._cython_call.unary_stream(serialized_request, + self._set_initial_metadata, + self._set_status) except asyncio.CancelledError: if self._code != grpc.StatusCode.CANCELLED: self.cancel() diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index bbca4dd3996..c0a7fa17017 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -33,7 +33,7 @@ _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' _RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000 _UNREACHABLE_TARGET = '0.1:1111' -_INFINITE_INTERVAL_US = 2**31-1 +_INFINITE_INTERVAL_US = 2**31 - 1 class TestUnaryUnaryCall(AioTestBase): From 6f0ffef2e94f24475b6197ad920ab287ce8050e8 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Wed, 18 Dec 2019 16:15:37 -0800 Subject: [PATCH 07/18] Resolve a TODO and handle one more cancellation corner case --- .../grpcio/grpc/experimental/aio/_call.py | 21 ++++++++++--------- .../grpcio_tests/tests_aio/unit/call_test.py | 4 ---- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 96415fa9521..7acb5494b4b 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -308,16 +308,17 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): _LOCAL_CANCELLATION_DETAILS, None, None)) def __await__(self) -> ResponseType: - """Wait till the ongoing RPC request finishes. - - Returns: - Response of the RPC call. - - Raises: - RpcError: Indicating that the RPC terminated with non-OK status. - asyncio.CancelledError: Indicating that the RPC was canceled. - """ - response = yield from self._call + """Wait till the ongoing RPC request finishes.""" + try: + response = yield from self._call + except asyncio.CancelledError: + # Even if we converted all other CancelledError, there is still + # this corner case. If the application cancels immediately after + # the Call object is created, we will observe this + # `CancelledError`. + if not self.cancelled(): + self.cancel() + raise _create_rpc_error(_EMPTY_METADATA, self._status.result()) return response diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index c0a7fa17017..cecce1c79d5 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -121,10 +121,6 @@ class TestUnaryUnaryCall(AioTestBase): self.assertFalse(call.cancelled()) - # TODO(https://github.com/grpc/grpc/issues/20869) remove sleep. - # Force the loop to execute the RPC task. - await asyncio.sleep(0) - self.assertTrue(call.cancel()) self.assertFalse(call.cancel()) From a3d7733dd02054654c53e93e93781c96f8f371e1 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Thu, 19 Dec 2019 13:58:03 -0800 Subject: [PATCH 08/18] Passing cancel signal to Core for Unary Call as well --- .../grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi | 11 ----------- src/python/grpcio/grpc/experimental/aio/_call.py | 1 + 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index a726084ec4d..c10d79cb7d3 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -77,11 +77,6 @@ cdef class _AioCall: """Destroys the corresponding Core object for this RPC.""" grpc_call_unref(self._grpc_call_wrapper.call) - @property - def locally_cancelled(self): - """Grant Python layer access of the cancelled flag.""" - return self._is_locally_cancelled - def cancel(self, AioRpcStatus status): """Cancels the RPC in Core with given RPC status. @@ -112,12 +107,6 @@ cdef class _AioCall: # By implementation, grpc_call_cancel always return OK error = grpc_call_cancel(self._grpc_call_wrapper.call, NULL) assert error == GRPC_CALL_OK - status = AioRpcStatus( - StatusCode.cancelled, - _UNKNOWN_CANCELLATION_DETAILS, - None, - None, - ) async def unary_unary(self, bytes request, diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 7acb5494b4b..a8969b06edb 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -297,6 +297,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): """Forwards the application cancellation reasoning.""" if not self._status.done(): self._set_status(status) + self._cython_call.cancel(status) self._call.cancel() return True else: From 4e3d980f7038f229a09121cf885e9fcf39221d1d Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Thu, 2 Jan 2020 10:23:06 -0800 Subject: [PATCH 09/18] Convert local cancellation exception into CancelledError --- .../grpcio/grpc/experimental/aio/_call.py | 24 +++++++++----- .../grpcio_tests/tests_aio/unit/call_test.py | 33 ++++--------------- 2 files changed, 23 insertions(+), 34 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index a8969b06edb..1557b678d66 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -150,12 +150,14 @@ class Call(_base_call.Call): _code: grpc.StatusCode _status: Awaitable[cygrpc.AioRpcStatus] _initial_metadata: Awaitable[MetadataType] + _locally_cancelled: bool def __init__(self) -> None: self._loop = asyncio.get_event_loop() self._code = None self._status = self._loop.create_future() self._initial_metadata = self._loop.create_future() + self._locally_cancelled = False def cancel(self) -> bool: """Placeholder cancellation method. @@ -204,6 +206,10 @@ class Call(_base_call.Call): cancellation (by application) and Core receiving status from peer. We make no promise here which one will win. """ + # In case of local cancellation, flip the flag. + if status.details() is _LOCAL_CANCELLATION_DETAILS: + self._locally_cancelled = True + # In case of the RPC finished without receiving metadata. if not self._initial_metadata.done(): self._initial_metadata.set_result(_EMPTY_METADATA) @@ -212,7 +218,9 @@ class Call(_base_call.Call): self._status.set_result(status) self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()] - async def _raise_rpc_error_if_not_ok(self) -> None: + async def _raise_if_not_ok(self) -> None: + if self._locally_cancelled: + raise asyncio.CancelledError() await self._status if self._code != grpc.StatusCode.OK: raise _create_rpc_error(await self.initial_metadata(), @@ -287,8 +295,8 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): if self._code != grpc.StatusCode.CANCELLED: self.cancel() - # Raises RpcError here if RPC failed or cancelled - await self._raise_rpc_error_if_not_ok() + # Raises here if RPC failed or cancelled + await self._raise_if_not_ok() return _common.deserialize(serialized_response, self._response_deserializer) @@ -319,7 +327,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): # `CancelledError`. if not self.cancelled(): self.cancel() - raise _create_rpc_error(_EMPTY_METADATA, self._status.result()) + raise return response @@ -367,7 +375,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): except asyncio.CancelledError: if self._code != grpc.StatusCode.CANCELLED: self.cancel() - await self._raise_rpc_error_if_not_ok() + raise async def _fetch_stream_responses(self) -> ResponseType: await self._send_unary_request_task @@ -418,7 +426,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): except asyncio.CancelledError: if self._code != grpc.StatusCode.CANCELLED: self.cancel() - await self._raise_rpc_error_if_not_ok() + raise if raw_response is None: return None @@ -428,14 +436,14 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): async def read(self) -> ResponseType: if self._status.done(): - await self._raise_rpc_error_if_not_ok() + await self._raise_if_not_ok() raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) response_message = await self._read() if response_message is None: # If the read operation failed, Core should explain why. - await self._raise_rpc_error_if_not_ok() + await self._raise_if_not_ok() # If no exception raised, there is something wrong internally. assert False, 'Read operation failed with StatusCode.OK' else: diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index cecce1c79d5..bdf2fbfda6b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -124,18 +124,10 @@ class TestUnaryUnaryCall(AioTestBase): self.assertTrue(call.cancel()) self.assertFalse(call.cancel()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await call # The info in the RpcError should match the info in Call object. - rpc_error = exception_context.exception - self.assertEqual(rpc_error.code(), await call.code()) - self.assertEqual(rpc_error.details(), await call.details()) - self.assertEqual(rpc_error.trailing_metadata(), await - call.trailing_metadata()) - self.assertEqual(rpc_error.debug_error_string(), await - call.debug_error_string()) - self.assertTrue(call.cancelled()) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) self.assertEqual(await call.details(), @@ -159,10 +151,8 @@ class TestUnaryUnaryCall(AioTestBase): self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await task - self.assertEqual(grpc.StatusCode.CANCELLED, - exception_context.exception.code()) class TestUnaryStreamCall(AioTestBase): @@ -201,7 +191,7 @@ class TestUnaryStreamCall(AioTestBase): call.details()) self.assertFalse(call.cancel()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await call.read() self.assertTrue(call.cancelled()) @@ -232,7 +222,7 @@ class TestUnaryStreamCall(AioTestBase): self.assertFalse(call.cancel()) self.assertFalse(call.cancel()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await call.read() async def test_early_cancel_unary_stream(self): @@ -256,16 +246,11 @@ class TestUnaryStreamCall(AioTestBase): self.assertTrue(call.cancel()) self.assertFalse(call.cancel()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await call.read() self.assertTrue(call.cancelled()) - self.assertEqual(grpc.StatusCode.CANCELLED, - exception_context.exception.code()) - self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, - exception_context.exception.details()) - self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await call.details()) @@ -377,10 +362,8 @@ class TestUnaryStreamCall(AioTestBase): self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await task - self.assertEqual(grpc.StatusCode.CANCELLED, - exception_context.exception.code()) async def test_cancel_unary_stream_in_task_using_async_for(self): async with aio.insecure_channel(self._server_target) as channel: @@ -411,10 +394,8 @@ class TestUnaryStreamCall(AioTestBase): self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await task - self.assertEqual(grpc.StatusCode.CANCELLED, - exception_context.exception.code()) if __name__ == '__main__': From 9a3ddd8d76170ed9a0c040e3b0d343cd2e698a40 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Fri, 3 Jan 2020 10:28:40 -0800 Subject: [PATCH 10/18] Correct comment wording --- src/python/grpcio/grpc/experimental/aio/_call.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 1557b678d66..36f07ce4d92 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -321,7 +321,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): try: response = yield from self._call except asyncio.CancelledError: - # Even if we converted all other CancelledError, there is still + # Even if we caught all other CancelledError, there is still # this corner case. If the application cancels immediately after # the Call object is created, we will observe this # `CancelledError`. From 324d2e64beb31ced7682e6ddab85e7b1f740c52b Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Fri, 3 Jan 2020 12:54:21 -0800 Subject: [PATCH 11/18] Replace or with || --- src/core/lib/gpr/time_precise.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/lib/gpr/time_precise.cc b/src/core/lib/gpr/time_precise.cc index 3223a84c7ad..e40228d3d0d 100644 --- a/src/core/lib/gpr/time_precise.cc +++ b/src/core/lib/gpr/time_precise.cc @@ -31,7 +31,7 @@ #include "src/core/lib/gpr/time_precise.h" -#if GPR_CYCLE_COUNTER_RDTSC_32 or GPR_CYCLE_COUNTER_RDTSC_64 +#if GPR_CYCLE_COUNTER_RDTSC_32 || GPR_CYCLE_COUNTER_RDTSC_64 #if GPR_LINUX static bool read_freq_from_kernel(double* freq) { // Google production kernel export the frequency for us in kHz. From 2f0362ee3a370530db0c1611e4fe743e7d8c3c54 Mon Sep 17 00:00:00 2001 From: Vijay Pai Date: Sun, 5 Jan 2020 16:27:30 -0800 Subject: [PATCH 12/18] Remove unused (and defective) constructor --- include/grpcpp/impl/codegen/callback_common.h | 5 ----- 1 file changed, 5 deletions(-) diff --git a/include/grpcpp/impl/codegen/callback_common.h b/include/grpcpp/impl/codegen/callback_common.h index 6b4cbdec03f..aa3bd26e1c5 100644 --- a/include/grpcpp/impl/codegen/callback_common.h +++ b/include/grpcpp/impl/codegen/callback_common.h @@ -150,11 +150,6 @@ class CallbackWithSuccessTag CallbackWithSuccessTag() : call_(nullptr) {} - CallbackWithSuccessTag(grpc_call* call, std::function f, - CompletionQueueTag* ops, bool can_inline) { - Set(call, f, ops, can_inline); - } - CallbackWithSuccessTag(const CallbackWithSuccessTag&) = delete; CallbackWithSuccessTag& operator=(const CallbackWithSuccessTag&) = delete; From 5c4e28583068dcceeb56917e4d86a2cf43c13b03 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Mon, 6 Jan 2020 11:54:57 -0800 Subject: [PATCH 13/18] Use "raise_for_status" --- src/python/grpcio/grpc/experimental/aio/_call.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 36f07ce4d92..f04c6cdd761 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -218,7 +218,7 @@ class Call(_base_call.Call): self._status.set_result(status) self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()] - async def _raise_if_not_ok(self) -> None: + async def _raise_for_status(self) -> None: if self._locally_cancelled: raise asyncio.CancelledError() await self._status @@ -296,7 +296,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): self.cancel() # Raises here if RPC failed or cancelled - await self._raise_if_not_ok() + await self._raise_for_status() return _common.deserialize(serialized_response, self._response_deserializer) @@ -436,14 +436,14 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): async def read(self) -> ResponseType: if self._status.done(): - await self._raise_if_not_ok() + await self._raise_for_status() raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) response_message = await self._read() if response_message is None: # If the read operation failed, Core should explain why. - await self._raise_if_not_ok() + await self._raise_for_status() # If no exception raised, there is something wrong internally. assert False, 'Read operation failed with StatusCode.OK' else: From ad83e0b77af206adbecf6a6b7100520306aabef1 Mon Sep 17 00:00:00 2001 From: Richard Belleville Date: Mon, 6 Jan 2020 12:50:24 -0800 Subject: [PATCH 14/18] Clarify the set_trailing_metadata docstring --- src/python/grpcio/grpc/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 93a10644bf5..3688ac82600 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -1162,7 +1162,13 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)): @abc.abstractmethod def set_trailing_metadata(self, trailing_metadata): - """Sends the trailing metadata for the RPC. + """Sets the trailing metadata for the RPC. + + Sets the trailing metadata to be sent upon completion of the RPC. + + If this method is invoked multiple times throughout the lifetime of an + RPC, the value supplied in the final invocation will be the value sent + over the wire. This method need not be called by implementations if they have no metadata to add to what the gRPC runtime will transmit. From 4bb124f54f62436c38dac1f217e4d7f26a3e4a76 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Mon, 6 Jan 2020 14:24:40 -0800 Subject: [PATCH 15/18] Make yapf_code capable of making in-place changes --- tools/distrib/yapf_code.sh | 5 ++++- tools/run_tests/sanity/sanity_tests.yaml | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tools/distrib/yapf_code.sh b/tools/distrib/yapf_code.sh index c377e3d7f98..4c9639ba8dc 100755 --- a/tools/distrib/yapf_code.sh +++ b/tools/distrib/yapf_code.sh @@ -15,6 +15,9 @@ set -ex +ACTION=${1:---in-place} +[[ $ACTION == '--in-place' ]] || [[ $ACTION == '--diff' ]] + # change to root directory cd "$(dirname "${0}")/../.." @@ -33,4 +36,4 @@ PYTHON=${VIRTUALENV}/bin/python "$PYTHON" -m pip install --upgrade futures "$PYTHON" -m pip install yapf==0.28.0 -$PYTHON -m yapf --diff --recursive --style=setup.cfg "${DIRS[@]}" +$PYTHON -m yapf $ACTION --recursive --style=setup.cfg "${DIRS[@]}" diff --git a/tools/run_tests/sanity/sanity_tests.yaml b/tools/run_tests/sanity/sanity_tests.yaml index e5579470074..59a9c240dc5 100644 --- a/tools/run_tests/sanity/sanity_tests.yaml +++ b/tools/run_tests/sanity/sanity_tests.yaml @@ -25,7 +25,7 @@ - script: tools/distrib/clang_tidy_code.sh - script: tools/distrib/pylint_code.sh - script: tools/distrib/python/check_grpcio_tools.py -- script: tools/distrib/yapf_code.sh +- script: tools/distrib/yapf_code.sh --diff cpu_cost: 1000 - script: tools/distrib/check_protobuf_pod_version.sh - script: tools/distrib/check_shadow_boringssl_symbol_list.sh From f50c5a025cb98a313360ebc3ad842970ca8eed85 Mon Sep 17 00:00:00 2001 From: yang-g Date: Mon, 6 Jan 2020 16:33:08 -0800 Subject: [PATCH 16/18] Revert "Merge pull request #21494 from grpc/revert-21478-max_message_size" This reverts commit 2e4ebd7478c58d119210b5b68e929e7098282f9c, reversing changes made to 1bd6fcc3388ad35047d7c0b28eb2bff862f276b3. --- include/grpcpp/server_impl.h | 7 ++----- src/cpp/server/server_builder.cc | 29 ++++++++++++++--------------- src/cpp/server/server_cc.cc | 12 ++++++++---- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/include/grpcpp/server_impl.h b/include/grpcpp/server_impl.h index 5cc7f595d05..9506c419018 100644 --- a/include/grpcpp/server_impl.h +++ b/include/grpcpp/server_impl.h @@ -163,9 +163,6 @@ class Server : public grpc::ServerInterface, private grpc::GrpcLibraryCodegen { /// /// Server constructors. To be used by \a ServerBuilder only. /// - /// \param max_message_size Maximum message length that the channel can - /// receive. - /// /// \param args The channel args /// /// \param sync_server_cqs The completion queues to use if the server is a @@ -182,7 +179,7 @@ class Server : public grpc::ServerInterface, private grpc::GrpcLibraryCodegen { /// /// \param sync_cq_timeout_msec The timeout to use when calling AsyncNext() on /// server completion queues passed via sync_server_cqs param. - Server(int max_message_size, ChannelArguments* args, + Server(ChannelArguments* args, std::shared_ptr>> sync_server_cqs, int min_pollers, int max_pollers, int sync_cq_timeout_msec, @@ -306,7 +303,7 @@ class Server : public grpc::ServerInterface, private grpc::GrpcLibraryCodegen { std::unique_ptr> interceptor_creators_; - const int max_receive_message_size_; + int max_receive_message_size_; /// The following completion queues are ONLY used in case of Sync API /// i.e. if the server has any services with sync methods. The server uses diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc index c058a75dc03..8acfe536270 100644 --- a/src/cpp/server/server_builder.cc +++ b/src/cpp/server/server_builder.cc @@ -218,20 +218,9 @@ ServerBuilder& ServerBuilder::AddListeningPort( std::unique_ptr ServerBuilder::BuildAndStart() { grpc::ChannelArguments args; - for (const auto& option : options_) { - option->UpdateArguments(&args); - option->UpdatePlugins(&plugins_); - } - - for (const auto& plugin : plugins_) { - plugin->UpdateServerBuilder(this); - plugin->UpdateChannelArguments(&args); - } - if (max_receive_message_size_ >= -1) { args.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, max_receive_message_size_); } - // The default message size is -1 (max), so no need to explicitly set it for // -1. if (max_send_message_size_ >= 0) { @@ -254,6 +243,16 @@ std::unique_ptr ServerBuilder::BuildAndStart() { grpc_resource_quota_arg_vtable()); } + for (const auto& option : options_) { + option->UpdateArguments(&args); + option->UpdatePlugins(&plugins_); + } + + for (const auto& plugin : plugins_) { + plugin->UpdateServerBuilder(this); + plugin->UpdateChannelArguments(&args); + } + // == Determine if the server has any syncrhonous methods == bool has_sync_methods = false; for (const auto& value : services_) { @@ -332,10 +331,10 @@ std::unique_ptr ServerBuilder::BuildAndStart() { } std::unique_ptr server(new grpc::Server( - max_receive_message_size_, &args, sync_server_cqs, - sync_server_settings_.min_pollers, sync_server_settings_.max_pollers, - sync_server_settings_.cq_timeout_msec, std::move(acceptors_), - resource_quota_, std::move(interceptor_creators_))); + &args, sync_server_cqs, sync_server_settings_.min_pollers, + sync_server_settings_.max_pollers, sync_server_settings_.cq_timeout_msec, + std::move(acceptors_), resource_quota_, + std::move(interceptor_creators_))); grpc_impl::ServerInitializer* initializer = server->initializer(); diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index ef4245b0e57..56f189cedaa 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -15,6 +15,7 @@ * */ +#include #include #include @@ -23,6 +24,7 @@ #include #include +#include #include #include #include @@ -964,7 +966,7 @@ class Server::SyncRequestThreadManager : public grpc::ThreadManager { static grpc::internal::GrpcLibraryInitializer g_gli_initializer; Server::Server( - int max_receive_message_size, grpc::ChannelArguments* args, + grpc::ChannelArguments* args, std::shared_ptr>> sync_server_cqs, int min_pollers, int max_pollers, int sync_cq_timeout_msec, @@ -976,7 +978,7 @@ Server::Server( interceptor_creators) : acceptors_(std::move(acceptors)), interceptor_creators_(std::move(interceptor_creators)), - max_receive_message_size_(max_receive_message_size), + max_receive_message_size_(-1), sync_server_cqs_(std::move(sync_server_cqs)), started_(false), shutdown_(false), @@ -1026,10 +1028,12 @@ Server::Server( static_cast( channel_args.args[i].value.pointer.p)); } - break; + } + if (0 == + strcmp(channel_args.args[i].key, GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH)) { + max_receive_message_size_ = channel_args.args[i].value.integer; } } - server_ = grpc_server_create(&channel_args, nullptr); } From 138ce26bb54999ddf88281d9514509b4a4b2411b Mon Sep 17 00:00:00 2001 From: yang-g Date: Mon, 6 Jan 2020 16:40:08 -0800 Subject: [PATCH 17/18] Fixes --- src/cpp/server/server_builder.cc | 21 ++++++++++++++++----- src/cpp/server/server_cc.cc | 3 +-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc index 8acfe536270..71f17da0a4b 100644 --- a/src/cpp/server/server_builder.cc +++ b/src/cpp/server/server_builder.cc @@ -26,6 +26,7 @@ #include +#include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/string.h" #include "src/core/lib/gpr/useful.h" #include "src/cpp/server/external_connection_acceptor_impl.h" @@ -218,7 +219,22 @@ ServerBuilder& ServerBuilder::AddListeningPort( std::unique_ptr ServerBuilder::BuildAndStart() { grpc::ChannelArguments args; + + for (const auto& option : options_) { + option->UpdateArguments(&args); + option->UpdatePlugins(&plugins_); + } if (max_receive_message_size_ >= -1) { + grpc_channel_args c_args = args.c_channel_args(); + const grpc_arg* arg = + grpc_channel_args_find(&c_args, GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH); + // Some option has set max_receive_message_length and it is also set + // directly on the ServerBuilder. + if (arg != nullptr) { + gpr_log( + GPR_ERROR, + "gRPC ServerBuilder receives multiple max_receive_message_length"); + } args.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, max_receive_message_size_); } // The default message size is -1 (max), so no need to explicitly set it for @@ -243,11 +259,6 @@ std::unique_ptr ServerBuilder::BuildAndStart() { grpc_resource_quota_arg_vtable()); } - for (const auto& option : options_) { - option->UpdateArguments(&args); - option->UpdatePlugins(&plugins_); - } - for (const auto& plugin : plugins_) { plugin->UpdateServerBuilder(this); plugin->UpdateChannelArguments(&args); diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 56f189cedaa..5367fb25ebb 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -15,7 +15,6 @@ * */ -#include #include #include @@ -978,7 +977,7 @@ Server::Server( interceptor_creators) : acceptors_(std::move(acceptors)), interceptor_creators_(std::move(interceptor_creators)), - max_receive_message_size_(-1), + max_receive_message_size_(INT_MIN), sync_server_cqs_(std::move(sync_server_cqs)), started_(false), shutdown_(false), From 48f026d90ece794eb718d7749e0b54b83ef76feb Mon Sep 17 00:00:00 2001 From: Arjun Roy Date: Wed, 30 Oct 2019 17:40:53 -0700 Subject: [PATCH 18/18] gRPC TCP Transmit-side Zerocopy. Implements TCP Tx-side zerocopy. Must be explicitly enabled to use. For large RPCs (>= 16KiB) it reduces the amount of CPU time spent since it avoids a userspace to kernel data buffer copy. However, there is a tradeoff - the application must process a callback on the socket error queue placed by the kernel, informing the application that the data buffer can be freed since the kernel is done. The cost of processing the error queue means that we do not have an advantage for small RPCs. --- include/grpc/impl/codegen/grpc_types.h | 14 + .../lib/iomgr/socket_utils_common_posix.cc | 14 + src/core/lib/iomgr/socket_utils_posix.h | 12 + src/core/lib/iomgr/tcp_posix.cc | 663 ++++++++++++++++-- .../iomgr/tcp_server_utils_posix_common.cc | 8 + 5 files changed, 655 insertions(+), 56 deletions(-) diff --git a/include/grpc/impl/codegen/grpc_types.h b/include/grpc/impl/codegen/grpc_types.h index 836441f8948..89fd15faf1a 100644 --- a/include/grpc/impl/codegen/grpc_types.h +++ b/include/grpc/impl/codegen/grpc_types.h @@ -323,6 +323,20 @@ typedef struct { "grpc.experimental.tcp_min_read_chunk_size" #define GRPC_ARG_TCP_MAX_READ_CHUNK_SIZE \ "grpc.experimental.tcp_max_read_chunk_size" +/* TCP TX Zerocopy enable state: zero is disabled, non-zero is enabled. By + default, it is disabled. */ +#define GRPC_ARG_TCP_TX_ZEROCOPY_ENABLED \ + "grpc.experimental.tcp_tx_zerocopy_enabled" +/* TCP TX Zerocopy send threshold: only zerocopy if >= this many bytes sent. By + default, this is set to 16KB. */ +#define GRPC_ARG_TCP_TX_ZEROCOPY_SEND_BYTES_THRESHOLD \ + "grpc.experimental.tcp_tx_zerocopy_send_bytes_threshold" +/* TCP TX Zerocopy max simultaneous sends: limit for maximum number of pending + calls to tcp_write() using zerocopy. A tcp_write() is considered pending + until the kernel performs the zerocopy-done callback for all sendmsg() calls + issued by the tcp_write(). By default, this is set to 4. */ +#define GRPC_ARG_TCP_TX_ZEROCOPY_MAX_SIMULT_SENDS \ + "grpc.experimental.tcp_tx_zerocopy_max_simultaneous_sends" /* Timeout in milliseconds to use for calls to the grpclb load balancer. If 0 or unset, the balancer calls will have no deadline. */ #define GRPC_ARG_GRPCLB_CALL_TIMEOUT_MS "grpc.grpclb_call_timeout_ms" diff --git a/src/core/lib/iomgr/socket_utils_common_posix.cc b/src/core/lib/iomgr/socket_utils_common_posix.cc index f46cbd51c88..3974ae7dec2 100644 --- a/src/core/lib/iomgr/socket_utils_common_posix.cc +++ b/src/core/lib/iomgr/socket_utils_common_posix.cc @@ -50,6 +50,20 @@ #include "src/core/lib/iomgr/sockaddr.h" #include "src/core/lib/iomgr/sockaddr_utils.h" +/* set a socket to use zerocopy */ +grpc_error* grpc_set_socket_zerocopy(int fd) { +#ifdef GRPC_LINUX_ERRQUEUE + const int enable = 1; + auto err = setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &enable, sizeof(enable)); + if (err != 0) { + return GRPC_OS_ERROR(errno, "setsockopt(SO_ZEROCOPY)"); + } + return GRPC_ERROR_NONE; +#else + return GRPC_OS_ERROR(ENOSYS, "setsockopt(SO_ZEROCOPY)"); +#endif +} + /* set a socket to non blocking mode */ grpc_error* grpc_set_socket_nonblocking(int fd, int non_blocking) { int oldflags = fcntl(fd, F_GETFL, 0); diff --git a/src/core/lib/iomgr/socket_utils_posix.h b/src/core/lib/iomgr/socket_utils_posix.h index a708a7a0ed5..734d340a953 100644 --- a/src/core/lib/iomgr/socket_utils_posix.h +++ b/src/core/lib/iomgr/socket_utils_posix.h @@ -31,10 +31,22 @@ #include "src/core/lib/iomgr/socket_factory_posix.h" #include "src/core/lib/iomgr/socket_mutator.h" +#ifdef GRPC_LINUX_ERRQUEUE +#ifndef SO_ZEROCOPY +#define SO_ZEROCOPY 60 +#endif +#ifndef SO_EE_ORIGIN_ZEROCOPY +#define SO_EE_ORIGIN_ZEROCOPY 5 +#endif +#endif /* ifdef GRPC_LINUX_ERRQUEUE */ + /* a wrapper for accept or accept4 */ int grpc_accept4(int sockfd, grpc_resolved_address* resolved_addr, int nonblock, int cloexec); +/* set a socket to use zerocopy */ +grpc_error* grpc_set_socket_zerocopy(int fd); + /* set a socket to non blocking mode */ grpc_error* grpc_set_socket_nonblocking(int fd, int non_blocking); diff --git a/src/core/lib/iomgr/tcp_posix.cc b/src/core/lib/iomgr/tcp_posix.cc index 668a0c805e8..c96031183b3 100644 --- a/src/core/lib/iomgr/tcp_posix.cc +++ b/src/core/lib/iomgr/tcp_posix.cc @@ -36,6 +36,7 @@ #include #include #include +#include #include #include @@ -49,9 +50,11 @@ #include "src/core/lib/debug/trace.h" #include "src/core/lib/gpr/string.h" #include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/sync.h" #include "src/core/lib/iomgr/buffer_list.h" #include "src/core/lib/iomgr/ev_posix.h" #include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" #include "src/core/lib/profiling/timers.h" #include "src/core/lib/slice/slice_internal.h" #include "src/core/lib/slice/slice_string_helpers.h" @@ -71,6 +74,15 @@ #define SENDMSG_FLAGS 0 #endif +// TCP zero copy sendmsg flag. +// NB: We define this here as a fallback in case we're using an older set of +// library headers that has not defined MSG_ZEROCOPY. Since this constant is +// part of the kernel, we are guaranteed it will never change/disagree so +// defining it here is safe. +#ifndef MSG_ZEROCOPY +#define MSG_ZEROCOPY 0x4000000 +#endif + #ifdef GRPC_MSG_IOVLEN_TYPE typedef GRPC_MSG_IOVLEN_TYPE msg_iovlen_type; #else @@ -79,6 +91,264 @@ typedef size_t msg_iovlen_type; extern grpc_core::TraceFlag grpc_tcp_trace; +namespace grpc_core { + +class TcpZerocopySendRecord { + public: + TcpZerocopySendRecord() { grpc_slice_buffer_init(&buf_); } + + ~TcpZerocopySendRecord() { + AssertEmpty(); + grpc_slice_buffer_destroy_internal(&buf_); + } + + // Given the slices that we wish to send, and the current offset into the + // slice buffer (indicating which have already been sent), populate an iovec + // array that will be used for a zerocopy enabled sendmsg(). + msg_iovlen_type PopulateIovs(size_t* unwind_slice_idx, + size_t* unwind_byte_idx, size_t* sending_length, + iovec* iov); + + // A sendmsg() may not be able to send the bytes that we requested at this + // time, returning EAGAIN (possibly due to backpressure). In this case, + // unwind the offset into the slice buffer so we retry sending these bytes. + void UnwindIfThrottled(size_t unwind_slice_idx, size_t unwind_byte_idx) { + out_offset_.byte_idx = unwind_byte_idx; + out_offset_.slice_idx = unwind_slice_idx; + } + + // Update the offset into the slice buffer based on how much we wanted to sent + // vs. what sendmsg() actually sent (which may be lower, possibly due to + // backpressure). + void UpdateOffsetForBytesSent(size_t sending_length, size_t actually_sent); + + // Indicates whether all underlying data has been sent or not. + bool AllSlicesSent() { return out_offset_.slice_idx == buf_.count; } + + // Reset this structure for a new tcp_write() with zerocopy. + void PrepareForSends(grpc_slice_buffer* slices_to_send) { + AssertEmpty(); + out_offset_.slice_idx = 0; + out_offset_.byte_idx = 0; + grpc_slice_buffer_swap(slices_to_send, &buf_); + Ref(); + } + + // References: 1 reference per sendmsg(), and 1 for the tcp_write(). + void Ref() { ref_.FetchAdd(1, MemoryOrder::RELAXED); } + + // Unref: called when we get an error queue notification for a sendmsg(), if a + // sendmsg() failed or when tcp_write() is done. + bool Unref() { + const intptr_t prior = ref_.FetchSub(1, MemoryOrder::ACQ_REL); + GPR_DEBUG_ASSERT(prior > 0); + if (prior == 1) { + AllSendsComplete(); + return true; + } + return false; + } + + private: + struct OutgoingOffset { + size_t slice_idx = 0; + size_t byte_idx = 0; + }; + + void AssertEmpty() { + GPR_DEBUG_ASSERT(buf_.count == 0); + GPR_DEBUG_ASSERT(buf_.length == 0); + GPR_DEBUG_ASSERT(ref_.Load(MemoryOrder::RELAXED) == 0); + } + + // When all sendmsg() calls associated with this tcp_write() have been + // completed (ie. we have received the notifications for each sequence number + // for each sendmsg()) and all reference counts have been dropped, drop our + // reference to the underlying data since we no longer need it. + void AllSendsComplete() { + GPR_DEBUG_ASSERT(ref_.Load(MemoryOrder::RELAXED) == 0); + grpc_slice_buffer_reset_and_unref_internal(&buf_); + } + + grpc_slice_buffer buf_; + Atomic ref_; + OutgoingOffset out_offset_; +}; + +class TcpZerocopySendCtx { + public: + static constexpr int kDefaultMaxSends = 4; + static constexpr size_t kDefaultSendBytesThreshold = 16 * 1024; // 16KB + + TcpZerocopySendCtx(int max_sends = kDefaultMaxSends, + size_t send_bytes_threshold = kDefaultSendBytesThreshold) + : max_sends_(max_sends), + free_send_records_size_(max_sends), + threshold_bytes_(send_bytes_threshold) { + send_records_ = static_cast( + gpr_malloc(max_sends * sizeof(*send_records_))); + free_send_records_ = static_cast( + gpr_malloc(max_sends * sizeof(*free_send_records_))); + if (send_records_ == nullptr || free_send_records_ == nullptr) { + gpr_free(send_records_); + gpr_free(free_send_records_); + gpr_log(GPR_INFO, "Disabling TCP TX zerocopy due to memory pressure.\n"); + memory_limited_ = true; + } else { + for (int idx = 0; idx < max_sends_; ++idx) { + new (send_records_ + idx) TcpZerocopySendRecord(); + free_send_records_[idx] = send_records_ + idx; + } + } + } + + ~TcpZerocopySendCtx() { + if (send_records_ != nullptr) { + for (int idx = 0; idx < max_sends_; ++idx) { + send_records_[idx].~TcpZerocopySendRecord(); + } + } + gpr_free(send_records_); + gpr_free(free_send_records_); + } + + // True if we were unable to allocate the various bookkeeping structures at + // transport initialization time. If memory limited, we do not zerocopy. + bool memory_limited() const { return memory_limited_; } + + // TCP send zerocopy maintains an implicit sequence number for every + // successful sendmsg() with zerocopy enabled; the kernel later gives us an + // error queue notification with this sequence number indicating that the + // underlying data buffers that we sent can now be released. Once that + // notification is received, we can release the buffers associated with this + // zerocopy send record. Here, we associate the sequence number with the data + // buffers that were sent with the corresponding call to sendmsg(). + void NoteSend(TcpZerocopySendRecord* record) { + record->Ref(); + AssociateSeqWithSendRecord(last_send_, record); + ++last_send_; + } + + // If sendmsg() actually failed, though, we need to revert the sequence number + // that we speculatively bumped before calling sendmsg(). Note that we bump + // this sequence number and perform relevant bookkeeping (see: NoteSend()) + // *before* calling sendmsg() since, if we called it *after* sendmsg(), then + // there is a possible race with the release notification which could occur on + // another thread before we do the necessary bookkeeping. Hence, calling + // NoteSend() *before* sendmsg() and implementing an undo function is needed. + void UndoSend() { + --last_send_; + if (ReleaseSendRecord(last_send_)->Unref()) { + // We should still be holding the ref taken by tcp_write(). + GPR_DEBUG_ASSERT(0); + } + } + + // Simply associate this send record (and the underlying sent data buffers) + // with the implicit sequence number for this zerocopy sendmsg(). + void AssociateSeqWithSendRecord(uint32_t seq, TcpZerocopySendRecord* record) { + MutexLock guard(&lock_); + ctx_lookup_.emplace(seq, record); + } + + // Get a send record for a send that we wish to do with zerocopy. + TcpZerocopySendRecord* GetSendRecord() { + MutexLock guard(&lock_); + return TryGetSendRecordLocked(); + } + + // A given send record corresponds to a single tcp_write() with zerocopy + // enabled. This can result in several sendmsg() calls to flush all of the + // data to wire. Each sendmsg() takes a reference on the + // TcpZerocopySendRecord, and corresponds to a single sequence number. + // ReleaseSendRecord releases a reference on TcpZerocopySendRecord for a + // single sequence number. This is called either when we receive the relevant + // error queue notification (saying that we can discard the underlying + // buffers for this sendmsg()) is received from the kernel - or, in case + // sendmsg() was unsuccessful to begin with. + TcpZerocopySendRecord* ReleaseSendRecord(uint32_t seq) { + MutexLock guard(&lock_); + return ReleaseSendRecordLocked(seq); + } + + // After all the references to a TcpZerocopySendRecord are released, we can + // add it back to the pool (of size max_sends_). Note that we can only have + // max_sends_ tcp_write() instances with zerocopy enabled in flight at the + // same time. + void PutSendRecord(TcpZerocopySendRecord* record) { + GPR_DEBUG_ASSERT(record >= send_records_ && + record < send_records_ + max_sends_); + MutexLock guard(&lock_); + PutSendRecordLocked(record); + } + + // Indicate that we are disposing of this zerocopy context. This indicator + // will prevent new zerocopy writes from being issued. + void Shutdown() { shutdown_.Store(true, MemoryOrder::RELEASE); } + + // Indicates that there are no inflight tcp_write() instances with zerocopy + // enabled. + bool AllSendRecordsEmpty() { + MutexLock guard(&lock_); + return free_send_records_size_ == max_sends_; + } + + bool enabled() const { return enabled_; } + + void set_enabled(bool enabled) { + GPR_DEBUG_ASSERT(!enabled || !memory_limited()); + enabled_ = enabled; + } + + // Only use zerocopy if we are sending at least this many bytes. The + // additional overhead of reading the error queue for notifications means that + // zerocopy is not useful for small transfers. + size_t threshold_bytes() const { return threshold_bytes_; } + + private: + TcpZerocopySendRecord* ReleaseSendRecordLocked(uint32_t seq) { + auto iter = ctx_lookup_.find(seq); + GPR_DEBUG_ASSERT(iter != ctx_lookup_.end()); + TcpZerocopySendRecord* record = iter->second; + ctx_lookup_.erase(iter); + return record; + } + + TcpZerocopySendRecord* TryGetSendRecordLocked() { + if (shutdown_.Load(MemoryOrder::ACQUIRE)) { + return nullptr; + } + if (free_send_records_size_ == 0) { + return nullptr; + } + free_send_records_size_--; + return free_send_records_[free_send_records_size_]; + } + + void PutSendRecordLocked(TcpZerocopySendRecord* record) { + GPR_DEBUG_ASSERT(free_send_records_size_ < max_sends_); + free_send_records_[free_send_records_size_] = record; + free_send_records_size_++; + } + + TcpZerocopySendRecord* send_records_; + TcpZerocopySendRecord** free_send_records_; + int max_sends_; + int free_send_records_size_; + Mutex lock_; + uint32_t last_send_ = 0; + Atomic shutdown_; + bool enabled_ = false; + size_t threshold_bytes_ = kDefaultSendBytesThreshold; + std::unordered_map ctx_lookup_; + bool memory_limited_ = false; +}; + +} // namespace grpc_core + +using grpc_core::TcpZerocopySendCtx; +using grpc_core::TcpZerocopySendRecord; + namespace { struct grpc_tcp { grpc_endpoint base; @@ -142,6 +412,8 @@ struct grpc_tcp { bool ts_capable; /* Cache whether we can set timestamping options */ gpr_atm stop_error_notification; /* Set to 1 if we do not want to be notified on errors anymore */ + TcpZerocopySendCtx tcp_zerocopy_send_ctx; + TcpZerocopySendRecord* current_zerocopy_send = nullptr; }; struct backup_poller { @@ -151,6 +423,8 @@ struct backup_poller { } // namespace +static void ZerocopyDisableAndWaitForRemaining(grpc_tcp* tcp); + #define BACKUP_POLLER_POLLSET(b) ((grpc_pollset*)((b) + 1)) static gpr_atm g_uncovered_notifications_pending; @@ -339,6 +613,7 @@ static void tcp_handle_write(void* arg /* grpc_tcp */, grpc_error* error); static void tcp_shutdown(grpc_endpoint* ep, grpc_error* why) { grpc_tcp* tcp = reinterpret_cast(ep); + ZerocopyDisableAndWaitForRemaining(tcp); grpc_fd_shutdown(tcp->em_fd, why); grpc_resource_user_shutdown(tcp->resource_user); } @@ -357,6 +632,7 @@ static void tcp_free(grpc_tcp* tcp) { gpr_mu_unlock(&tcp->tb_mu); tcp->outgoing_buffer_arg = nullptr; gpr_mu_destroy(&tcp->tb_mu); + tcp->tcp_zerocopy_send_ctx.~TcpZerocopySendCtx(); gpr_free(tcp); } @@ -390,6 +666,7 @@ static void tcp_destroy(grpc_endpoint* ep) { grpc_tcp* tcp = reinterpret_cast(ep); grpc_slice_buffer_reset_and_unref_internal(&tcp->last_read_buffer); if (grpc_event_engine_can_track_errors()) { + ZerocopyDisableAndWaitForRemaining(tcp); gpr_atm_no_barrier_store(&tcp->stop_error_notification, true); grpc_fd_set_error(tcp->em_fd); } @@ -652,13 +929,13 @@ static void tcp_read(grpc_endpoint* ep, grpc_slice_buffer* incoming_buffer, /* A wrapper around sendmsg. It sends \a msg over \a fd and returns the number * of bytes sent. */ -ssize_t tcp_send(int fd, const struct msghdr* msg) { +ssize_t tcp_send(int fd, const struct msghdr* msg, int additional_flags = 0) { GPR_TIMER_SCOPE("sendmsg", 1); ssize_t sent_length; do { /* TODO(klempner): Cork if this is a partial write */ GRPC_STATS_INC_SYSCALL_WRITE(); - sent_length = sendmsg(fd, msg, SENDMSG_FLAGS); + sent_length = sendmsg(fd, msg, SENDMSG_FLAGS | additional_flags); } while (sent_length < 0 && errno == EINTR); return sent_length; } @@ -671,16 +948,52 @@ ssize_t tcp_send(int fd, const struct msghdr* msg) { */ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, size_t sending_length, - ssize_t* sent_length); + ssize_t* sent_length, + int additional_flags = 0); /** The callback function to be invoked when we get an error on the socket. */ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error); +static TcpZerocopySendRecord* tcp_get_send_zerocopy_record( + grpc_tcp* tcp, grpc_slice_buffer* buf); + #ifdef GRPC_LINUX_ERRQUEUE +static bool process_errors(grpc_tcp* tcp); + +static TcpZerocopySendRecord* tcp_get_send_zerocopy_record( + grpc_tcp* tcp, grpc_slice_buffer* buf) { + TcpZerocopySendRecord* zerocopy_send_record = nullptr; + const bool use_zerocopy = + tcp->tcp_zerocopy_send_ctx.enabled() && + tcp->tcp_zerocopy_send_ctx.threshold_bytes() < buf->length; + if (use_zerocopy) { + zerocopy_send_record = tcp->tcp_zerocopy_send_ctx.GetSendRecord(); + if (zerocopy_send_record == nullptr) { + process_errors(tcp); + zerocopy_send_record = tcp->tcp_zerocopy_send_ctx.GetSendRecord(); + } + if (zerocopy_send_record != nullptr) { + zerocopy_send_record->PrepareForSends(buf); + GPR_DEBUG_ASSERT(buf->count == 0); + GPR_DEBUG_ASSERT(buf->length == 0); + tcp->outgoing_byte_idx = 0; + tcp->outgoing_buffer = nullptr; + } + } + return zerocopy_send_record; +} + +static void ZerocopyDisableAndWaitForRemaining(grpc_tcp* tcp) { + tcp->tcp_zerocopy_send_ctx.Shutdown(); + while (!tcp->tcp_zerocopy_send_ctx.AllSendRecordsEmpty()) { + process_errors(tcp); + } +} static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, size_t sending_length, - ssize_t* sent_length) { + ssize_t* sent_length, + int additional_flags) { if (!tcp->socket_ts_enabled) { uint32_t opt = grpc_core::kTimestampingSocketOptions; if (setsockopt(tcp->fd, SOL_SOCKET, SO_TIMESTAMPING, @@ -708,7 +1021,7 @@ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, msg->msg_controllen = CMSG_SPACE(sizeof(uint32_t)); /* If there was an error on sendmsg the logic in tcp_flush will handle it. */ - ssize_t length = tcp_send(tcp->fd, msg); + ssize_t length = tcp_send(tcp->fd, msg, additional_flags); *sent_length = length; /* Only save timestamps if all the bytes were taken by sendmsg. */ if (sending_length == static_cast(length)) { @@ -722,6 +1035,43 @@ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, return true; } +static void UnrefMaybePutZerocopySendRecord(grpc_tcp* tcp, + TcpZerocopySendRecord* record, + uint32_t seq, const char* tag); +// Reads \a cmsg to process zerocopy control messages. +static void process_zerocopy(grpc_tcp* tcp, struct cmsghdr* cmsg) { + GPR_DEBUG_ASSERT(cmsg); + auto serr = reinterpret_cast(CMSG_DATA(cmsg)); + GPR_DEBUG_ASSERT(serr->ee_errno == 0); + GPR_DEBUG_ASSERT(serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY); + const uint32_t lo = serr->ee_info; + const uint32_t hi = serr->ee_data; + for (uint32_t seq = lo; seq <= hi; ++seq) { + // TODO(arjunroy): It's likely that lo and hi refer to zerocopy sequence + // numbers that are generated by a single call to grpc_endpoint_write; ie. + // we can batch the unref operation. So, check if record is the same for + // both; if so, batch the unref/put. + TcpZerocopySendRecord* record = + tcp->tcp_zerocopy_send_ctx.ReleaseSendRecord(seq); + GPR_DEBUG_ASSERT(record); + UnrefMaybePutZerocopySendRecord(tcp, record, seq, "CALLBACK RCVD"); + } +} + +// Whether the cmsg received from error queue is of the IPv4 or IPv6 levels. +static bool CmsgIsIpLevel(const cmsghdr& cmsg) { + return (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR) || + (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR); +} + +static bool CmsgIsZeroCopy(const cmsghdr& cmsg) { + if (!CmsgIsIpLevel(cmsg)) { + return false; + } + auto serr = reinterpret_cast CMSG_DATA(&cmsg); + return serr->ee_errno == 0 && serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY; +} + /** Reads \a cmsg to derive timestamps from the control messages. If a valid * timestamp is found, the traced buffer list is updated with this timestamp. * The caller of this function should be looping on the control messages found @@ -783,73 +1133,76 @@ struct cmsghdr* process_timestamp(grpc_tcp* tcp, msghdr* msg, /** For linux platforms, reads the socket's error queue and processes error * messages from the queue. */ -static void process_errors(grpc_tcp* tcp) { +static bool process_errors(grpc_tcp* tcp) { + bool processed_err = false; + struct iovec iov; + iov.iov_base = nullptr; + iov.iov_len = 0; + struct msghdr msg; + msg.msg_name = nullptr; + msg.msg_namelen = 0; + msg.msg_iov = &iov; + msg.msg_iovlen = 0; + msg.msg_flags = 0; + /* Allocate enough space so we don't need to keep increasing this as size + * of OPT_STATS increase */ + constexpr size_t cmsg_alloc_space = + CMSG_SPACE(sizeof(grpc_core::scm_timestamping)) + + CMSG_SPACE(sizeof(sock_extended_err) + sizeof(sockaddr_in)) + + CMSG_SPACE(32 * NLA_ALIGN(NLA_HDRLEN + sizeof(uint64_t))); + /* Allocate aligned space for cmsgs received along with timestamps */ + union { + char rbuf[cmsg_alloc_space]; + struct cmsghdr align; + } aligned_buf; + msg.msg_control = aligned_buf.rbuf; + msg.msg_controllen = sizeof(aligned_buf.rbuf); + int r, saved_errno; while (true) { - struct iovec iov; - iov.iov_base = nullptr; - iov.iov_len = 0; - struct msghdr msg; - msg.msg_name = nullptr; - msg.msg_namelen = 0; - msg.msg_iov = &iov; - msg.msg_iovlen = 0; - msg.msg_flags = 0; - - /* Allocate enough space so we don't need to keep increasing this as size - * of OPT_STATS increase */ - constexpr size_t cmsg_alloc_space = - CMSG_SPACE(sizeof(grpc_core::scm_timestamping)) + - CMSG_SPACE(sizeof(sock_extended_err) + sizeof(sockaddr_in)) + - CMSG_SPACE(32 * NLA_ALIGN(NLA_HDRLEN + sizeof(uint64_t))); - /* Allocate aligned space for cmsgs received along with timestamps */ - union { - char rbuf[cmsg_alloc_space]; - struct cmsghdr align; - } aligned_buf; - memset(&aligned_buf, 0, sizeof(aligned_buf)); - - msg.msg_control = aligned_buf.rbuf; - msg.msg_controllen = sizeof(aligned_buf.rbuf); - - int r, saved_errno; do { r = recvmsg(tcp->fd, &msg, MSG_ERRQUEUE); saved_errno = errno; } while (r < 0 && saved_errno == EINTR); if (r == -1 && saved_errno == EAGAIN) { - return; /* No more errors to process */ + return processed_err; /* No more errors to process */ } if (r == -1) { - return; + return processed_err; } - if ((msg.msg_flags & MSG_CTRUNC) != 0) { + if (GPR_UNLIKELY((msg.msg_flags & MSG_CTRUNC) != 0)) { gpr_log(GPR_ERROR, "Error message was truncated."); } if (msg.msg_controllen == 0) { /* There was no control message found. It was probably spurious. */ - return; + return processed_err; } bool seen = false; for (auto cmsg = CMSG_FIRSTHDR(&msg); cmsg && cmsg->cmsg_len; cmsg = CMSG_NXTHDR(&msg, cmsg)) { - if (cmsg->cmsg_level != SOL_SOCKET || - cmsg->cmsg_type != SCM_TIMESTAMPING) { - /* Got a control message that is not a timestamp. Don't know how to - * handle this. */ + if (CmsgIsZeroCopy(*cmsg)) { + process_zerocopy(tcp, cmsg); + seen = true; + processed_err = true; + } else if (cmsg->cmsg_level == SOL_SOCKET && + cmsg->cmsg_type == SCM_TIMESTAMPING) { + cmsg = process_timestamp(tcp, &msg, cmsg); + seen = true; + processed_err = true; + } else { + /* Got a control message that is not a timestamp or zerocopy. Don't know + * how to handle this. */ if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { gpr_log(GPR_INFO, "unknown control message cmsg_level:%d cmsg_type:%d", cmsg->cmsg_level, cmsg->cmsg_type); } - return; + return processed_err; } - cmsg = process_timestamp(tcp, &msg, cmsg); - seen = true; } if (!seen) { - return; + return processed_err; } } } @@ -870,18 +1223,28 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error) { /* We are still interested in collecting timestamps, so let's try reading * them. */ - process_errors(tcp); + bool processed = process_errors(tcp); /* This might not a timestamps error. Set the read and write closures to be * ready. */ - grpc_fd_set_readable(tcp->em_fd); - grpc_fd_set_writable(tcp->em_fd); + if (!processed) { + grpc_fd_set_readable(tcp->em_fd); + grpc_fd_set_writable(tcp->em_fd); + } grpc_fd_notify_on_error(tcp->em_fd, &tcp->error_closure); } #else /* GRPC_LINUX_ERRQUEUE */ +static TcpZerocopySendRecord* tcp_get_send_zerocopy_record( + grpc_tcp* tcp, grpc_slice_buffer* buf) { + return nullptr; +} + +static void ZerocopyDisableAndWaitForRemaining(grpc_tcp* tcp) {} + static bool tcp_write_with_timestamps(grpc_tcp* /*tcp*/, struct msghdr* /*msg*/, size_t /*sending_length*/, - ssize_t* /*sent_length*/) { + ssize_t* /*sent_length*/, + int /*additional_flags*/) { gpr_log(GPR_ERROR, "Write with timestamps not supported for this platform"); GPR_ASSERT(0); return false; @@ -907,12 +1270,138 @@ void tcp_shutdown_buffer_list(grpc_tcp* tcp) { } } -/* returns true if done, false if pending; if returning true, *error is set */ #if defined(IOV_MAX) && IOV_MAX < 1000 #define MAX_WRITE_IOVEC IOV_MAX #else #define MAX_WRITE_IOVEC 1000 #endif +msg_iovlen_type TcpZerocopySendRecord::PopulateIovs(size_t* unwind_slice_idx, + size_t* unwind_byte_idx, + size_t* sending_length, + iovec* iov) { + msg_iovlen_type iov_size; + *unwind_slice_idx = out_offset_.slice_idx; + *unwind_byte_idx = out_offset_.byte_idx; + for (iov_size = 0; + out_offset_.slice_idx != buf_.count && iov_size != MAX_WRITE_IOVEC; + iov_size++) { + iov[iov_size].iov_base = + GRPC_SLICE_START_PTR(buf_.slices[out_offset_.slice_idx]) + + out_offset_.byte_idx; + iov[iov_size].iov_len = + GRPC_SLICE_LENGTH(buf_.slices[out_offset_.slice_idx]) - + out_offset_.byte_idx; + *sending_length += iov[iov_size].iov_len; + ++(out_offset_.slice_idx); + out_offset_.byte_idx = 0; + } + GPR_DEBUG_ASSERT(iov_size > 0); + return iov_size; +} + +void TcpZerocopySendRecord::UpdateOffsetForBytesSent(size_t sending_length, + size_t actually_sent) { + size_t trailing = sending_length - actually_sent; + while (trailing > 0) { + size_t slice_length; + out_offset_.slice_idx--; + slice_length = GRPC_SLICE_LENGTH(buf_.slices[out_offset_.slice_idx]); + if (slice_length > trailing) { + out_offset_.byte_idx = slice_length - trailing; + break; + } else { + trailing -= slice_length; + } + } +} + +// returns true if done, false if pending; if returning true, *error is set +static bool do_tcp_flush_zerocopy(grpc_tcp* tcp, TcpZerocopySendRecord* record, + grpc_error** error) { + struct msghdr msg; + struct iovec iov[MAX_WRITE_IOVEC]; + msg_iovlen_type iov_size; + ssize_t sent_length = 0; + size_t sending_length; + size_t unwind_slice_idx; + size_t unwind_byte_idx; + while (true) { + sending_length = 0; + iov_size = record->PopulateIovs(&unwind_slice_idx, &unwind_byte_idx, + &sending_length, iov); + msg.msg_name = nullptr; + msg.msg_namelen = 0; + msg.msg_iov = iov; + msg.msg_iovlen = iov_size; + msg.msg_flags = 0; + bool tried_sending_message = false; + // Before calling sendmsg (with or without timestamps): we + // take a single ref on the zerocopy send record. + tcp->tcp_zerocopy_send_ctx.NoteSend(record); + if (tcp->outgoing_buffer_arg != nullptr) { + if (!tcp->ts_capable || + !tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length, + MSG_ZEROCOPY)) { + /* We could not set socket options to collect Fathom timestamps. + * Fallback on writing without timestamps. */ + tcp->ts_capable = false; + tcp_shutdown_buffer_list(tcp); + } else { + tried_sending_message = true; + } + } + if (!tried_sending_message) { + msg.msg_control = nullptr; + msg.msg_controllen = 0; + GRPC_STATS_INC_TCP_WRITE_SIZE(sending_length); + GRPC_STATS_INC_TCP_WRITE_IOV_SIZE(iov_size); + sent_length = tcp_send(tcp->fd, &msg, MSG_ZEROCOPY); + } + if (sent_length < 0) { + // If this particular send failed, drop ref taken earlier in this method. + tcp->tcp_zerocopy_send_ctx.UndoSend(); + if (errno == EAGAIN) { + record->UnwindIfThrottled(unwind_slice_idx, unwind_byte_idx); + return false; + } else if (errno == EPIPE) { + *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "sendmsg"), tcp); + tcp_shutdown_buffer_list(tcp); + return true; + } else { + *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "sendmsg"), tcp); + tcp_shutdown_buffer_list(tcp); + return true; + } + } + tcp->bytes_counter += sent_length; + record->UpdateOffsetForBytesSent(sending_length, + static_cast(sent_length)); + if (record->AllSlicesSent()) { + *error = GRPC_ERROR_NONE; + return true; + } + } +} + +static void UnrefMaybePutZerocopySendRecord(grpc_tcp* tcp, + TcpZerocopySendRecord* record, + uint32_t seq, const char* tag) { + if (record->Unref()) { + tcp->tcp_zerocopy_send_ctx.PutSendRecord(record); + } +} + +static bool tcp_flush_zerocopy(grpc_tcp* tcp, TcpZerocopySendRecord* record, + grpc_error** error) { + bool done = do_tcp_flush_zerocopy(tcp, record, error); + if (done) { + // Either we encountered an error, or we successfully sent all the bytes. + // In either case, we're done with this record. + UnrefMaybePutZerocopySendRecord(tcp, record, 0, "flush_done"); + } + return done; +} + static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) { struct msghdr msg; struct iovec iov[MAX_WRITE_IOVEC]; @@ -927,7 +1416,7 @@ static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) { // buffer as we write size_t outgoing_slice_idx = 0; - for (;;) { + while (true) { sending_length = 0; unwind_slice_idx = outgoing_slice_idx; unwind_byte_idx = tcp->outgoing_byte_idx; @@ -1027,12 +1516,21 @@ static void tcp_handle_write(void* arg /* grpc_tcp */, grpc_error* error) { if (error != GRPC_ERROR_NONE) { cb = tcp->write_cb; tcp->write_cb = nullptr; + if (tcp->current_zerocopy_send != nullptr) { + UnrefMaybePutZerocopySendRecord(tcp, tcp->current_zerocopy_send, 0, + "handle_write_err"); + tcp->current_zerocopy_send = nullptr; + } grpc_core::Closure::Run(DEBUG_LOCATION, cb, GRPC_ERROR_REF(error)); TCP_UNREF(tcp, "write"); return; } - if (!tcp_flush(tcp, &error)) { + bool flush_result = + tcp->current_zerocopy_send != nullptr + ? tcp_flush_zerocopy(tcp, tcp->current_zerocopy_send, &error) + : tcp_flush(tcp, &error); + if (!flush_result) { if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { gpr_log(GPR_INFO, "write: delayed"); } @@ -1042,6 +1540,7 @@ static void tcp_handle_write(void* arg /* grpc_tcp */, grpc_error* error) { } else { cb = tcp->write_cb; tcp->write_cb = nullptr; + tcp->current_zerocopy_send = nullptr; if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { const char* str = grpc_error_string(error); gpr_log(GPR_INFO, "write: %s", str); @@ -1057,6 +1556,7 @@ static void tcp_write(grpc_endpoint* ep, grpc_slice_buffer* buf, GPR_TIMER_SCOPE("tcp_write", 0); grpc_tcp* tcp = reinterpret_cast(ep); grpc_error* error = GRPC_ERROR_NONE; + TcpZerocopySendRecord* zerocopy_send_record = nullptr; if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { size_t i; @@ -1073,8 +1573,8 @@ static void tcp_write(grpc_endpoint* ep, grpc_slice_buffer* buf, } GPR_ASSERT(tcp->write_cb == nullptr); + GPR_DEBUG_ASSERT(tcp->current_zerocopy_send == nullptr); - tcp->outgoing_buffer_arg = arg; if (buf->length == 0) { grpc_core::Closure::Run( DEBUG_LOCATION, cb, @@ -1085,15 +1585,26 @@ static void tcp_write(grpc_endpoint* ep, grpc_slice_buffer* buf, tcp_shutdown_buffer_list(tcp); return; } - tcp->outgoing_buffer = buf; - tcp->outgoing_byte_idx = 0; + + zerocopy_send_record = tcp_get_send_zerocopy_record(tcp, buf); + if (zerocopy_send_record == nullptr) { + // Either not enough bytes, or couldn't allocate a zerocopy context. + tcp->outgoing_buffer = buf; + tcp->outgoing_byte_idx = 0; + } + tcp->outgoing_buffer_arg = arg; if (arg) { GPR_ASSERT(grpc_event_engine_can_track_errors()); } - if (!tcp_flush(tcp, &error)) { + bool flush_result = + zerocopy_send_record != nullptr + ? tcp_flush_zerocopy(tcp, zerocopy_send_record, &error) + : tcp_flush(tcp, &error); + if (!flush_result) { TCP_REF(tcp, "write"); tcp->write_cb = cb; + tcp->current_zerocopy_send = zerocopy_send_record; if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { gpr_log(GPR_INFO, "write: delayed"); } @@ -1121,6 +1632,7 @@ static void tcp_add_to_pollset_set(grpc_endpoint* ep, static void tcp_delete_from_pollset_set(grpc_endpoint* ep, grpc_pollset_set* pollset_set) { grpc_tcp* tcp = reinterpret_cast(ep); + ZerocopyDisableAndWaitForRemaining(tcp); grpc_pollset_set_del_fd(pollset_set, tcp->em_fd); } @@ -1172,9 +1684,15 @@ static const grpc_endpoint_vtable vtable = {tcp_read, grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd, const grpc_channel_args* channel_args, const char* peer_string) { + static constexpr bool kZerocpTxEnabledDefault = false; int tcp_read_chunk_size = GRPC_TCP_DEFAULT_READ_SLICE_SIZE; int tcp_max_read_chunk_size = 4 * 1024 * 1024; int tcp_min_read_chunk_size = 256; + bool tcp_tx_zerocopy_enabled = kZerocpTxEnabledDefault; + int tcp_tx_zerocopy_send_bytes_thresh = + grpc_core::TcpZerocopySendCtx::kDefaultSendBytesThreshold; + int tcp_tx_zerocopy_max_simult_sends = + grpc_core::TcpZerocopySendCtx::kDefaultMaxSends; grpc_resource_quota* resource_quota = grpc_resource_quota_create(nullptr); if (channel_args != nullptr) { for (size_t i = 0; i < channel_args->num_args; i++) { @@ -1199,6 +1717,23 @@ grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd, resource_quota = grpc_resource_quota_ref_internal(static_cast( channel_args->args[i].value.pointer.p)); + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_TCP_TX_ZEROCOPY_ENABLED)) { + tcp_tx_zerocopy_enabled = grpc_channel_arg_get_bool( + &channel_args->args[i], kZerocpTxEnabledDefault); + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_TCP_TX_ZEROCOPY_SEND_BYTES_THRESHOLD)) { + grpc_integer_options options = { + grpc_core::TcpZerocopySendCtx::kDefaultSendBytesThreshold, 0, + INT_MAX}; + tcp_tx_zerocopy_send_bytes_thresh = + grpc_channel_arg_get_integer(&channel_args->args[i], options); + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_TCP_TX_ZEROCOPY_MAX_SIMULT_SENDS)) { + grpc_integer_options options = { + grpc_core::TcpZerocopySendCtx::kDefaultMaxSends, 0, INT_MAX}; + tcp_tx_zerocopy_max_simult_sends = + grpc_channel_arg_get_integer(&channel_args->args[i], options); } } } @@ -1215,6 +1750,7 @@ grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd, tcp->fd = grpc_fd_wrapped_fd(em_fd); tcp->read_cb = nullptr; tcp->write_cb = nullptr; + tcp->current_zerocopy_send = nullptr; tcp->release_fd_cb = nullptr; tcp->release_fd = nullptr; tcp->incoming_buffer = nullptr; @@ -1228,6 +1764,20 @@ grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd, tcp->socket_ts_enabled = false; tcp->ts_capable = true; tcp->outgoing_buffer_arg = nullptr; + new (&tcp->tcp_zerocopy_send_ctx) TcpZerocopySendCtx( + tcp_tx_zerocopy_max_simult_sends, tcp_tx_zerocopy_send_bytes_thresh); + if (tcp_tx_zerocopy_enabled && !tcp->tcp_zerocopy_send_ctx.memory_limited()) { +#ifdef GRPC_LINUX_ERRQUEUE + const int enable = 1; + auto err = + setsockopt(tcp->fd, SOL_SOCKET, SO_ZEROCOPY, &enable, sizeof(enable)); + if (err == 0) { + tcp->tcp_zerocopy_send_ctx.set_enabled(true); + } else { + gpr_log(GPR_ERROR, "Failed to set zerocopy options on the socket."); + } +#endif + } /* paired with unref in grpc_tcp_destroy */ new (&tcp->refcount) grpc_core::RefCount(1, &grpc_tcp_trace); gpr_atm_no_barrier_store(&tcp->shutdown_count, 0); @@ -1294,6 +1844,7 @@ void grpc_tcp_destroy_and_release_fd(grpc_endpoint* ep, int* fd, grpc_slice_buffer_reset_and_unref_internal(&tcp->last_read_buffer); if (grpc_event_engine_can_track_errors()) { /* Stop errors notification. */ + ZerocopyDisableAndWaitForRemaining(tcp); gpr_atm_no_barrier_store(&tcp->stop_error_notification, true); grpc_fd_set_error(tcp->em_fd); } diff --git a/src/core/lib/iomgr/tcp_server_utils_posix_common.cc b/src/core/lib/iomgr/tcp_server_utils_posix_common.cc index ee1cd5c1027..da18cc39c51 100644 --- a/src/core/lib/iomgr/tcp_server_utils_posix_common.cc +++ b/src/core/lib/iomgr/tcp_server_utils_posix_common.cc @@ -157,6 +157,14 @@ grpc_error* grpc_tcp_server_prepare_socket(grpc_tcp_server* s, int fd, if (err != GRPC_ERROR_NONE) goto error; } +#ifdef GRPC_LINUX_ERRQUEUE + err = grpc_set_socket_zerocopy(fd); + if (err != GRPC_ERROR_NONE) { + /* it's not fatal, so just log it. */ + gpr_log(GPR_DEBUG, "Node does not support SO_ZEROCOPY, continuing."); + GRPC_ERROR_UNREF(err); + } +#endif err = grpc_set_socket_nonblocking(fd, 1); if (err != GRPC_ERROR_NONE) goto error; err = grpc_set_socket_cloexec(fd, 1);