Merge remote-tracking branch 'upstream/master' into grpc_security_level_negotiation

pull/21215/head
Yihua Zhang 5 years ago
commit 62005e4290
  1. include/grpc/impl/codegen/grpc_types.h (14)
  2. include/grpcpp/impl/codegen/callback_common.h (5)
  3. include/grpcpp/server_impl.h (7)
  4. src/core/lib/gpr/time_precise.cc (2)
  5. src/core/lib/iomgr/executor.cc (2)
  6. src/core/lib/iomgr/socket_utils_common_posix.cc (14)
  7. src/core/lib/iomgr/socket_utils_posix.h (12)
  8. src/core/lib/iomgr/tcp_posix.cc (663)
  9. src/core/lib/iomgr/tcp_server_utils_posix_common.cc (8)
  10. src/cpp/server/server_builder.cc (32)
  11. src/cpp/server/server_cc.cc (11)
  12. src/python/grpcio/grpc/__init__.py (8)
  13. src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi (13)
  14. src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi (114)
  15. src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi (37)
  16. src/python/grpcio/grpc/experimental/aio/_call.py (188)
  17. src/python/grpcio_tests/tests_aio/unit/call_test.py (106)
  18. tools/distrib/yapf_code.sh (5)
  19. tools/run_tests/sanity/sanity_tests.yaml (2)

@@ -323,6 +323,20 @@ typedef struct {
   "grpc.experimental.tcp_min_read_chunk_size"
 #define GRPC_ARG_TCP_MAX_READ_CHUNK_SIZE \
   "grpc.experimental.tcp_max_read_chunk_size"
+/* TCP TX Zerocopy enable state: zero is disabled, non-zero is enabled. By
+   default, it is disabled. */
+#define GRPC_ARG_TCP_TX_ZEROCOPY_ENABLED \
+  "grpc.experimental.tcp_tx_zerocopy_enabled"
+/* TCP TX Zerocopy send threshold: only zerocopy if >= this many bytes sent. By
+   default, this is set to 16KB. */
+#define GRPC_ARG_TCP_TX_ZEROCOPY_SEND_BYTES_THRESHOLD \
+  "grpc.experimental.tcp_tx_zerocopy_send_bytes_threshold"
+/* TCP TX Zerocopy max simultaneous sends: limit for maximum number of pending
+   calls to tcp_write() using zerocopy. A tcp_write() is considered pending
+   until the kernel performs the zerocopy-done callback for all sendmsg() calls
+   issued by the tcp_write(). By default, this is set to 4. */
+#define GRPC_ARG_TCP_TX_ZEROCOPY_MAX_SIMULT_SENDS \
+  "grpc.experimental.tcp_tx_zerocopy_max_simultaneous_sends"
 /* Timeout in milliseconds to use for calls to the grpclb load balancer.
    If 0 or unset, the balancer calls will have no deadline. */
 #define GRPC_ARG_GRPCLB_CALL_TIMEOUT_MS "grpc.grpclb_call_timeout_ms"
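Note: the three new keys are ordinary channel arguments, so a client or server built against this change can opt in per channel. Below is a minimal C++ sketch of that wiring (illustrative only; the strings mirror the GRPC_ARG_TCP_TX_ZEROCOPY_* definitions above, and the threshold/limit values are arbitrary examples, not tuning advice):

#include <grpcpp/create_channel.h>
#include <grpcpp/grpcpp.h>

std::shared_ptr<grpc::Channel> MakeZerocopyChannel(const std::string& target) {
  grpc::ChannelArguments args;
  // Same strings as GRPC_ARG_TCP_TX_ZEROCOPY_ENABLED /
  // ..._SEND_BYTES_THRESHOLD / ..._MAX_SIMULT_SENDS defined above.
  args.SetInt("grpc.experimental.tcp_tx_zerocopy_enabled", 1);
  args.SetInt("grpc.experimental.tcp_tx_zerocopy_send_bytes_threshold",
              16 * 1024);  // only use zerocopy for writes of at least 16KB
  args.SetInt("grpc.experimental.tcp_tx_zerocopy_max_simultaneous_sends", 4);
  return grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(),
                                   args);
}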

@@ -150,11 +150,6 @@ class CallbackWithSuccessTag
   CallbackWithSuccessTag() : call_(nullptr) {}
 
-  CallbackWithSuccessTag(grpc_call* call, std::function<void(bool)> f,
-                         CompletionQueueTag* ops, bool can_inline) {
-    Set(call, f, ops, can_inline);
-  }
-
   CallbackWithSuccessTag(const CallbackWithSuccessTag&) = delete;
   CallbackWithSuccessTag& operator=(const CallbackWithSuccessTag&) = delete;

@@ -163,9 +163,6 @@ class Server : public grpc::ServerInterface, private grpc::GrpcLibraryCodegen {
   ///
   /// Server constructors. To be used by \a ServerBuilder only.
   ///
-  /// \param max_message_size Maximum message length that the channel can
-  /// receive.
-  ///
   /// \param args The channel args
   ///
   /// \param sync_server_cqs The completion queues to use if the server is a
@@ -182,7 +179,7 @@ class Server : public grpc::ServerInterface, private grpc::GrpcLibraryCodegen {
   ///
   /// \param sync_cq_timeout_msec The timeout to use when calling AsyncNext() on
   /// server completion queues passed via sync_server_cqs param.
-  Server(int max_message_size, ChannelArguments* args,
+  Server(ChannelArguments* args,
          std::shared_ptr<std::vector<std::unique_ptr<ServerCompletionQueue>>>
              sync_server_cqs,
         int min_pollers, int max_pollers, int sync_cq_timeout_msec,
@@ -306,7 +303,7 @@ class Server : public grpc::ServerInterface, private grpc::GrpcLibraryCodegen {
       std::unique_ptr<grpc::experimental::ServerInterceptorFactoryInterface>>
       interceptor_creators_;
-  const int max_receive_message_size_;
+  int max_receive_message_size_;
   /// The following completion queues are ONLY used in case of Sync API
   /// i.e. if the server has any services with sync methods. The server uses

@@ -31,7 +31,7 @@
 #include "src/core/lib/gpr/time_precise.h"
 
-#if GPR_CYCLE_COUNTER_RDTSC_32 or GPR_CYCLE_COUNTER_RDTSC_64
+#if GPR_CYCLE_COUNTER_RDTSC_32 || GPR_CYCLE_COUNTER_RDTSC_64
 #if GPR_LINUX
 static bool read_freq_from_kernel(double* freq) {
   // Google production kernel export the frequency for us in kHz.

@@ -143,7 +143,7 @@ void Executor::SetThreading(bool threading) {
   if (threading) {
     if (curr_num_threads > 0) {
-      EXECUTOR_TRACE("(%s) SetThreading(true). curr_num_threads == 0", name_);
+      EXECUTOR_TRACE("(%s) SetThreading(true). curr_num_threads > 0", name_);
       return;
     }

@@ -50,6 +50,20 @@
 #include "src/core/lib/iomgr/sockaddr.h"
 #include "src/core/lib/iomgr/sockaddr_utils.h"
 
+/* set a socket to use zerocopy */
+grpc_error* grpc_set_socket_zerocopy(int fd) {
+#ifdef GRPC_LINUX_ERRQUEUE
+  const int enable = 1;
+  auto err = setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &enable, sizeof(enable));
+  if (err != 0) {
+    return GRPC_OS_ERROR(errno, "setsockopt(SO_ZEROCOPY)");
+  }
+  return GRPC_ERROR_NONE;
+#else
+  return GRPC_OS_ERROR(ENOSYS, "setsockopt(SO_ZEROCOPY)");
+#endif
+}
+
 /* set a socket to non blocking mode */
 grpc_error* grpc_set_socket_nonblocking(int fd, int non_blocking) {
   int oldflags = fcntl(fd, F_GETFL, 0);

@@ -31,10 +31,22 @@
 #include "src/core/lib/iomgr/socket_factory_posix.h"
 #include "src/core/lib/iomgr/socket_mutator.h"
 
+#ifdef GRPC_LINUX_ERRQUEUE
+#ifndef SO_ZEROCOPY
+#define SO_ZEROCOPY 60
+#endif
+#ifndef SO_EE_ORIGIN_ZEROCOPY
+#define SO_EE_ORIGIN_ZEROCOPY 5
+#endif
+#endif /* ifdef GRPC_LINUX_ERRQUEUE */
+
 /* a wrapper for accept or accept4 */
 int grpc_accept4(int sockfd, grpc_resolved_address* resolved_addr, int nonblock,
                  int cloexec);
 
+/* set a socket to use zerocopy */
+grpc_error* grpc_set_socket_zerocopy(int fd);
+
 /* set a socket to non blocking mode */
 grpc_error* grpc_set_socket_nonblocking(int fd, int non_blocking);

@@ -36,6 +36,7 @@
 #include <sys/types.h>
 #include <unistd.h>
 #include <algorithm>
+#include <unordered_map>
 
 #include <grpc/slice.h>
 #include <grpc/support/alloc.h>
@@ -49,9 +50,11 @@
 #include "src/core/lib/debug/trace.h"
 #include "src/core/lib/gpr/string.h"
 #include "src/core/lib/gpr/useful.h"
+#include "src/core/lib/gprpp/sync.h"
 #include "src/core/lib/iomgr/buffer_list.h"
 #include "src/core/lib/iomgr/ev_posix.h"
 #include "src/core/lib/iomgr/executor.h"
+#include "src/core/lib/iomgr/socket_utils_posix.h"
 #include "src/core/lib/profiling/timers.h"
 #include "src/core/lib/slice/slice_internal.h"
 #include "src/core/lib/slice/slice_string_helpers.h"
@@ -71,6 +74,15 @@
 #define SENDMSG_FLAGS 0
 #endif
 
+// TCP zero copy sendmsg flag.
+// NB: We define this here as a fallback in case we're using an older set of
+// library headers that has not defined MSG_ZEROCOPY. Since this constant is
+// part of the kernel, we are guaranteed it will never change/disagree so
+// defining it here is safe.
+#ifndef MSG_ZEROCOPY
+#define MSG_ZEROCOPY 0x4000000
+#endif
+
 #ifdef GRPC_MSG_IOVLEN_TYPE
 typedef GRPC_MSG_IOVLEN_TYPE msg_iovlen_type;
 #else
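For readers who have not used the kernel API being wrapped in this file: the contract (see Linux Documentation/networking/msg_zerocopy.rst) is to enable SO_ZEROCOPY on the socket, pass MSG_ZEROCOPY to sendmsg(), keep the payload pages alive and unmodified, and then read completion notifications from the socket error queue; each notification is a sock_extended_err with ee_origin == SO_EE_ORIGIN_ZEROCOPY whose [ee_info, ee_data] range covers the implicit sequence numbers of the completed sendmsg() calls. A minimal standalone sketch of that flow follows (not gRPC code; error handling and polling are simplified assumptions, and the fd is assumed to be a connected TCP socket on a kernel that supports MSG_ZEROCOPY):

#include <errno.h>
#include <linux/errqueue.h>
#include <stddef.h>
#include <sys/socket.h>
#include <sys/uio.h>

bool SendOnceWithZerocopy(int fd, const void* data, size_t len) {
  const int enable = 1;
  if (setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &enable, sizeof(enable)) != 0) {
    return false;  // option unsupported: fall back to a regular send
  }
  struct iovec iov = {const_cast<void*>(data), len};
  struct msghdr msg = {};
  msg.msg_iov = &iov;
  msg.msg_iovlen = 1;
  if (sendmsg(fd, &msg, MSG_ZEROCOPY) < 0) return false;
  // `data` must stay valid until the kernel posts a completion on the error
  // queue (a sock_extended_err with ee_origin == SO_EE_ORIGIN_ZEROCOPY).
  char control[512];
  struct msghdr errq = {};
  errq.msg_control = control;
  errq.msg_controllen = sizeof(control);
  while (recvmsg(fd, &errq, MSG_ERRQUEUE) < 0 && errno == EAGAIN) {
    // Real code waits for POLLERR (as tcp_handle_error below does) instead of
    // spinning on the error queue.
  }
  return true;
}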
@@ -79,6 +91,264 @@ typedef size_t msg_iovlen_type;
 extern grpc_core::TraceFlag grpc_tcp_trace;
 
+namespace grpc_core {
+
+class TcpZerocopySendRecord {
+ public:
+  TcpZerocopySendRecord() { grpc_slice_buffer_init(&buf_); }
+
+  ~TcpZerocopySendRecord() {
+    AssertEmpty();
+    grpc_slice_buffer_destroy_internal(&buf_);
+  }
+
+  // Given the slices that we wish to send, and the current offset into the
+  // slice buffer (indicating which have already been sent), populate an iovec
+  // array that will be used for a zerocopy enabled sendmsg().
+  msg_iovlen_type PopulateIovs(size_t* unwind_slice_idx,
+                               size_t* unwind_byte_idx, size_t* sending_length,
+                               iovec* iov);
+
+  // A sendmsg() may not be able to send the bytes that we requested at this
+  // time, returning EAGAIN (possibly due to backpressure). In this case,
+  // unwind the offset into the slice buffer so we retry sending these bytes.
+  void UnwindIfThrottled(size_t unwind_slice_idx, size_t unwind_byte_idx) {
+    out_offset_.byte_idx = unwind_byte_idx;
+    out_offset_.slice_idx = unwind_slice_idx;
+  }
+
+  // Update the offset into the slice buffer based on how much we wanted to sent
+  // vs. what sendmsg() actually sent (which may be lower, possibly due to
+  // backpressure).
+  void UpdateOffsetForBytesSent(size_t sending_length, size_t actually_sent);
+
+  // Indicates whether all underlying data has been sent or not.
+  bool AllSlicesSent() { return out_offset_.slice_idx == buf_.count; }
+
+  // Reset this structure for a new tcp_write() with zerocopy.
+  void PrepareForSends(grpc_slice_buffer* slices_to_send) {
+    AssertEmpty();
+    out_offset_.slice_idx = 0;
+    out_offset_.byte_idx = 0;
+    grpc_slice_buffer_swap(slices_to_send, &buf_);
+    Ref();
+  }
+
+  // References: 1 reference per sendmsg(), and 1 for the tcp_write().
+  void Ref() { ref_.FetchAdd(1, MemoryOrder::RELAXED); }
+
+  // Unref: called when we get an error queue notification for a sendmsg(), if a
+  // sendmsg() failed or when tcp_write() is done.
+  bool Unref() {
+    const intptr_t prior = ref_.FetchSub(1, MemoryOrder::ACQ_REL);
+    GPR_DEBUG_ASSERT(prior > 0);
+    if (prior == 1) {
+      AllSendsComplete();
+      return true;
+    }
+    return false;
+  }
+
+ private:
+  struct OutgoingOffset {
+    size_t slice_idx = 0;
+    size_t byte_idx = 0;
+  };
+
+  void AssertEmpty() {
+    GPR_DEBUG_ASSERT(buf_.count == 0);
+    GPR_DEBUG_ASSERT(buf_.length == 0);
+    GPR_DEBUG_ASSERT(ref_.Load(MemoryOrder::RELAXED) == 0);
+  }
+
+  // When all sendmsg() calls associated with this tcp_write() have been
+  // completed (ie. we have received the notifications for each sequence number
+  // for each sendmsg()) and all reference counts have been dropped, drop our
+  // reference to the underlying data since we no longer need it.
+  void AllSendsComplete() {
+    GPR_DEBUG_ASSERT(ref_.Load(MemoryOrder::RELAXED) == 0);
+    grpc_slice_buffer_reset_and_unref_internal(&buf_);
+  }
+
+  grpc_slice_buffer buf_;
+  Atomic<intptr_t> ref_;
+  OutgoingOffset out_offset_;
+};
+
+class TcpZerocopySendCtx {
+ public:
+  static constexpr int kDefaultMaxSends = 4;
+  static constexpr size_t kDefaultSendBytesThreshold = 16 * 1024;  // 16KB
+
+  TcpZerocopySendCtx(int max_sends = kDefaultMaxSends,
+                     size_t send_bytes_threshold = kDefaultSendBytesThreshold)
+      : max_sends_(max_sends),
+        free_send_records_size_(max_sends),
+        threshold_bytes_(send_bytes_threshold) {
+    send_records_ = static_cast<TcpZerocopySendRecord*>(
+        gpr_malloc(max_sends * sizeof(*send_records_)));
+    free_send_records_ = static_cast<TcpZerocopySendRecord**>(
+        gpr_malloc(max_sends * sizeof(*free_send_records_)));
+    if (send_records_ == nullptr || free_send_records_ == nullptr) {
+      gpr_free(send_records_);
+      gpr_free(free_send_records_);
+      gpr_log(GPR_INFO, "Disabling TCP TX zerocopy due to memory pressure.\n");
+      memory_limited_ = true;
+    } else {
+      for (int idx = 0; idx < max_sends_; ++idx) {
+        new (send_records_ + idx) TcpZerocopySendRecord();
+        free_send_records_[idx] = send_records_ + idx;
+      }
+    }
+  }
+
+  ~TcpZerocopySendCtx() {
+    if (send_records_ != nullptr) {
+      for (int idx = 0; idx < max_sends_; ++idx) {
+        send_records_[idx].~TcpZerocopySendRecord();
+      }
+    }
+    gpr_free(send_records_);
+    gpr_free(free_send_records_);
+  }
+
+  // True if we were unable to allocate the various bookkeeping structures at
+  // transport initialization time. If memory limited, we do not zerocopy.
+  bool memory_limited() const { return memory_limited_; }
+
+  // TCP send zerocopy maintains an implicit sequence number for every
+  // successful sendmsg() with zerocopy enabled; the kernel later gives us an
+  // error queue notification with this sequence number indicating that the
+  // underlying data buffers that we sent can now be released. Once that
+  // notification is received, we can release the buffers associated with this
+  // zerocopy send record. Here, we associate the sequence number with the data
+  // buffers that were sent with the corresponding call to sendmsg().
+  void NoteSend(TcpZerocopySendRecord* record) {
+    record->Ref();
+    AssociateSeqWithSendRecord(last_send_, record);
+    ++last_send_;
+  }
+
+  // If sendmsg() actually failed, though, we need to revert the sequence number
+  // that we speculatively bumped before calling sendmsg(). Note that we bump
+  // this sequence number and perform relevant bookkeeping (see: NoteSend())
+  // *before* calling sendmsg() since, if we called it *after* sendmsg(), then
+  // there is a possible race with the release notification which could occur on
+  // another thread before we do the necessary bookkeeping. Hence, calling
+  // NoteSend() *before* sendmsg() and implementing an undo function is needed.
+  void UndoSend() {
+    --last_send_;
+    if (ReleaseSendRecord(last_send_)->Unref()) {
+      // We should still be holding the ref taken by tcp_write().
+      GPR_DEBUG_ASSERT(0);
+    }
+  }
+
+  // Simply associate this send record (and the underlying sent data buffers)
+  // with the implicit sequence number for this zerocopy sendmsg().
+  void AssociateSeqWithSendRecord(uint32_t seq, TcpZerocopySendRecord* record) {
+    MutexLock guard(&lock_);
+    ctx_lookup_.emplace(seq, record);
+  }
+
+  // Get a send record for a send that we wish to do with zerocopy.
+  TcpZerocopySendRecord* GetSendRecord() {
+    MutexLock guard(&lock_);
+    return TryGetSendRecordLocked();
+  }
+
+  // A given send record corresponds to a single tcp_write() with zerocopy
+  // enabled. This can result in several sendmsg() calls to flush all of the
+  // data to wire. Each sendmsg() takes a reference on the
+  // TcpZerocopySendRecord, and corresponds to a single sequence number.
+  // ReleaseSendRecord releases a reference on TcpZerocopySendRecord for a
+  // single sequence number. This is called either when we receive the relevant
+  // error queue notification (saying that we can discard the underlying
+  // buffers for this sendmsg()) is received from the kernel - or, in case
+  // sendmsg() was unsuccessful to begin with.
+  TcpZerocopySendRecord* ReleaseSendRecord(uint32_t seq) {
+    MutexLock guard(&lock_);
+    return ReleaseSendRecordLocked(seq);
+  }
+
+  // After all the references to a TcpZerocopySendRecord are released, we can
+  // add it back to the pool (of size max_sends_). Note that we can only have
+  // max_sends_ tcp_write() instances with zerocopy enabled in flight at the
+  // same time.
+  void PutSendRecord(TcpZerocopySendRecord* record) {
+    GPR_DEBUG_ASSERT(record >= send_records_ &&
+                     record < send_records_ + max_sends_);
+    MutexLock guard(&lock_);
+    PutSendRecordLocked(record);
+  }
+
+  // Indicate that we are disposing of this zerocopy context. This indicator
+  // will prevent new zerocopy writes from being issued.
+  void Shutdown() { shutdown_.Store(true, MemoryOrder::RELEASE); }
+
+  // Indicates that there are no inflight tcp_write() instances with zerocopy
+  // enabled.
+  bool AllSendRecordsEmpty() {
+    MutexLock guard(&lock_);
+    return free_send_records_size_ == max_sends_;
+  }
+
+  bool enabled() const { return enabled_; }
+
+  void set_enabled(bool enabled) {
+    GPR_DEBUG_ASSERT(!enabled || !memory_limited());
+    enabled_ = enabled;
+  }
+
+  // Only use zerocopy if we are sending at least this many bytes. The
+  // additional overhead of reading the error queue for notifications means that
+  // zerocopy is not useful for small transfers.
+  size_t threshold_bytes() const { return threshold_bytes_; }
+
+ private:
+  TcpZerocopySendRecord* ReleaseSendRecordLocked(uint32_t seq) {
+    auto iter = ctx_lookup_.find(seq);
+    GPR_DEBUG_ASSERT(iter != ctx_lookup_.end());
+    TcpZerocopySendRecord* record = iter->second;
+    ctx_lookup_.erase(iter);
+    return record;
+  }
+
+  TcpZerocopySendRecord* TryGetSendRecordLocked() {
+    if (shutdown_.Load(MemoryOrder::ACQUIRE)) {
+      return nullptr;
+    }
+    if (free_send_records_size_ == 0) {
+      return nullptr;
+    }
+    free_send_records_size_--;
+    return free_send_records_[free_send_records_size_];
+  }
+
+  void PutSendRecordLocked(TcpZerocopySendRecord* record) {
+    GPR_DEBUG_ASSERT(free_send_records_size_ < max_sends_);
+    free_send_records_[free_send_records_size_] = record;
+    free_send_records_size_++;
+  }
+
+  TcpZerocopySendRecord* send_records_;
+  TcpZerocopySendRecord** free_send_records_;
+  int max_sends_;
+  int free_send_records_size_;
+  Mutex lock_;
+  uint32_t last_send_ = 0;
+  Atomic<bool> shutdown_;
+  bool enabled_ = false;
+  size_t threshold_bytes_ = kDefaultSendBytesThreshold;
+  std::unordered_map<uint32_t, TcpZerocopySendRecord*> ctx_lookup_;
+  bool memory_limited_ = false;
+};
+
+}  // namespace grpc_core
+
+using grpc_core::TcpZerocopySendCtx;
+using grpc_core::TcpZerocopySendRecord;
+
 namespace {
 struct grpc_tcp {
   grpc_endpoint base;
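Taken together, the two classes above reduce buffer lifetime management to a small counting protocol: the tcp_write() holds one reference on a send record, every in-flight sendmsg() holds one more, and whoever drops the count to zero may release the slices and return the record to the pool. A self-contained illustration of that protocol with std::atomic follows (a generic sketch under those assumptions, not the gRPC types themselves):

#include <atomic>
#include <cassert>
#include <cstdio>

class SendCompletion {
 public:
  void Start() { refs_.store(1, std::memory_order_relaxed); }  // writer's ref
  void NoteSend() { refs_.fetch_add(1, std::memory_order_relaxed); }  // per send
  // Returns true when the caller dropped the final reference.
  bool Unref() {
    const int prior = refs_.fetch_sub(1, std::memory_order_acq_rel);
    assert(prior > 0);
    if (prior == 1) {
      std::puts("all sends complete: payload buffers may be released");
      return true;
    }
    return false;
  }

 private:
  std::atomic<int> refs_{0};
};

int main() {
  SendCompletion record;
  record.Start();     // analogous to PrepareForSends() taking the writer's ref
  record.NoteSend();  // first sendmsg(MSG_ZEROCOPY)
  record.NoteSend();  // second sendmsg(MSG_ZEROCOPY)
  record.Unref();     // error-queue completion for the first send
  record.Unref();     // error-queue completion for the second send
  record.Unref();     // writer finished flushing: final ref, buffers released
  return 0;
}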
@@ -142,6 +412,8 @@ struct grpc_tcp {
   bool ts_capable; /* Cache whether we can set timestamping options */
   gpr_atm stop_error_notification; /* Set to 1 if we do not want to be notified
                                       on errors anymore */
+  TcpZerocopySendCtx tcp_zerocopy_send_ctx;
+  TcpZerocopySendRecord* current_zerocopy_send = nullptr;
 };
@@ -151,6 +423,8 @@
 }  // namespace
 
+static void ZerocopyDisableAndWaitForRemaining(grpc_tcp* tcp);
+
 #define BACKUP_POLLER_POLLSET(b) ((grpc_pollset*)((b) + 1))
 
 static gpr_atm g_uncovered_notifications_pending;
@@ -339,6 +613,7 @@ static void tcp_handle_write(void* arg /* grpc_tcp */, grpc_error* error);
 static void tcp_shutdown(grpc_endpoint* ep, grpc_error* why) {
   grpc_tcp* tcp = reinterpret_cast<grpc_tcp*>(ep);
+  ZerocopyDisableAndWaitForRemaining(tcp);
   grpc_fd_shutdown(tcp->em_fd, why);
   grpc_resource_user_shutdown(tcp->resource_user);
 }
@@ -357,6 +632,7 @@ static void tcp_free(grpc_tcp* tcp) {
   gpr_mu_unlock(&tcp->tb_mu);
   tcp->outgoing_buffer_arg = nullptr;
   gpr_mu_destroy(&tcp->tb_mu);
+  tcp->tcp_zerocopy_send_ctx.~TcpZerocopySendCtx();
   gpr_free(tcp);
 }
@@ -390,6 +666,7 @@ static void tcp_destroy(grpc_endpoint* ep) {
   grpc_tcp* tcp = reinterpret_cast<grpc_tcp*>(ep);
   grpc_slice_buffer_reset_and_unref_internal(&tcp->last_read_buffer);
   if (grpc_event_engine_can_track_errors()) {
+    ZerocopyDisableAndWaitForRemaining(tcp);
     gpr_atm_no_barrier_store(&tcp->stop_error_notification, true);
     grpc_fd_set_error(tcp->em_fd);
   }
@@ -652,13 +929,13 @@ static void tcp_read(grpc_endpoint* ep, grpc_slice_buffer* incoming_buffer,
 /* A wrapper around sendmsg. It sends \a msg over \a fd and returns the number
  * of bytes sent. */
-ssize_t tcp_send(int fd, const struct msghdr* msg) {
+ssize_t tcp_send(int fd, const struct msghdr* msg, int additional_flags = 0) {
   GPR_TIMER_SCOPE("sendmsg", 1);
   ssize_t sent_length;
   do {
     /* TODO(klempner): Cork if this is a partial write */
     GRPC_STATS_INC_SYSCALL_WRITE();
-    sent_length = sendmsg(fd, msg, SENDMSG_FLAGS);
+    sent_length = sendmsg(fd, msg, SENDMSG_FLAGS | additional_flags);
   } while (sent_length < 0 && errno == EINTR);
   return sent_length;
 }
@@ -671,16 +948,52 @@ ssize_t tcp_send(int fd, const struct msghdr* msg) {
  */
 static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg,
                                       size_t sending_length,
-                                      ssize_t* sent_length);
+                                      ssize_t* sent_length,
+                                      int additional_flags = 0);
 
 /** The callback function to be invoked when we get an error on the socket. */
 static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error);
 
+static TcpZerocopySendRecord* tcp_get_send_zerocopy_record(
+    grpc_tcp* tcp, grpc_slice_buffer* buf);
+
 #ifdef GRPC_LINUX_ERRQUEUE
+static bool process_errors(grpc_tcp* tcp);
+
+static TcpZerocopySendRecord* tcp_get_send_zerocopy_record(
+    grpc_tcp* tcp, grpc_slice_buffer* buf) {
+  TcpZerocopySendRecord* zerocopy_send_record = nullptr;
+  const bool use_zerocopy =
+      tcp->tcp_zerocopy_send_ctx.enabled() &&
+      tcp->tcp_zerocopy_send_ctx.threshold_bytes() < buf->length;
+  if (use_zerocopy) {
+    zerocopy_send_record = tcp->tcp_zerocopy_send_ctx.GetSendRecord();
+    if (zerocopy_send_record == nullptr) {
+      process_errors(tcp);
+      zerocopy_send_record = tcp->tcp_zerocopy_send_ctx.GetSendRecord();
+    }
+    if (zerocopy_send_record != nullptr) {
+      zerocopy_send_record->PrepareForSends(buf);
+      GPR_DEBUG_ASSERT(buf->count == 0);
+      GPR_DEBUG_ASSERT(buf->length == 0);
+      tcp->outgoing_byte_idx = 0;
+      tcp->outgoing_buffer = nullptr;
+    }
+  }
+  return zerocopy_send_record;
+}
+
+static void ZerocopyDisableAndWaitForRemaining(grpc_tcp* tcp) {
+  tcp->tcp_zerocopy_send_ctx.Shutdown();
+  while (!tcp->tcp_zerocopy_send_ctx.AllSendRecordsEmpty()) {
+    process_errors(tcp);
+  }
+}
+
 static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg,
                                       size_t sending_length,
-                                      ssize_t* sent_length) {
+                                      ssize_t* sent_length,
+                                      int additional_flags) {
   if (!tcp->socket_ts_enabled) {
     uint32_t opt = grpc_core::kTimestampingSocketOptions;
     if (setsockopt(tcp->fd, SOL_SOCKET, SO_TIMESTAMPING,
@@ -708,7 +1021,7 @@ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg,
   msg->msg_controllen = CMSG_SPACE(sizeof(uint32_t));
 
   /* If there was an error on sendmsg the logic in tcp_flush will handle it. */
-  ssize_t length = tcp_send(tcp->fd, msg);
+  ssize_t length = tcp_send(tcp->fd, msg, additional_flags);
   *sent_length = length;
 
   /* Only save timestamps if all the bytes were taken by sendmsg. */
   if (sending_length == static_cast<size_t>(length)) {
@@ -722,6 +1035,43 @@ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg,
   return true;
 }
 
+static void UnrefMaybePutZerocopySendRecord(grpc_tcp* tcp,
+                                            TcpZerocopySendRecord* record,
+                                            uint32_t seq, const char* tag);
+
+// Reads \a cmsg to process zerocopy control messages.
+static void process_zerocopy(grpc_tcp* tcp, struct cmsghdr* cmsg) {
+  GPR_DEBUG_ASSERT(cmsg);
+  auto serr = reinterpret_cast<struct sock_extended_err*>(CMSG_DATA(cmsg));
+  GPR_DEBUG_ASSERT(serr->ee_errno == 0);
+  GPR_DEBUG_ASSERT(serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY);
+  const uint32_t lo = serr->ee_info;
+  const uint32_t hi = serr->ee_data;
+  for (uint32_t seq = lo; seq <= hi; ++seq) {
+    // TODO(arjunroy): It's likely that lo and hi refer to zerocopy sequence
+    // numbers that are generated by a single call to grpc_endpoint_write; ie.
+    // we can batch the unref operation. So, check if record is the same for
+    // both; if so, batch the unref/put.
+    TcpZerocopySendRecord* record =
+        tcp->tcp_zerocopy_send_ctx.ReleaseSendRecord(seq);
+    GPR_DEBUG_ASSERT(record);
+    UnrefMaybePutZerocopySendRecord(tcp, record, seq, "CALLBACK RCVD");
+  }
+}
+
+// Whether the cmsg received from error queue is of the IPv4 or IPv6 levels.
+static bool CmsgIsIpLevel(const cmsghdr& cmsg) {
+  return (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR) ||
+         (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR);
+}
+
+static bool CmsgIsZeroCopy(const cmsghdr& cmsg) {
+  if (!CmsgIsIpLevel(cmsg)) {
+    return false;
+  }
+  auto serr = reinterpret_cast<const sock_extended_err*>(CMSG_DATA(&cmsg));
+  return serr->ee_errno == 0 && serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY;
+}
+
 /** Reads \a cmsg to derive timestamps from the control messages. If a valid
  * timestamp is found, the traced buffer list is updated with this timestamp.
  * The caller of this function should be looping on the control messages found
@@ -783,73 +1133,76 @@ struct cmsghdr* process_timestamp(grpc_tcp* tcp, msghdr* msg,
 /** For linux platforms, reads the socket's error queue and processes error
  * messages from the queue.
  */
-static void process_errors(grpc_tcp* tcp) {
-  struct iovec iov;
-  iov.iov_base = nullptr;
-  iov.iov_len = 0;
-  struct msghdr msg;
-  msg.msg_name = nullptr;
-  msg.msg_namelen = 0;
-  msg.msg_iov = &iov;
-  msg.msg_iovlen = 0;
-  msg.msg_flags = 0;
-  /* Allocate enough space so we don't need to keep increasing this as size
-   * of OPT_STATS increase */
-  constexpr size_t cmsg_alloc_space =
-      CMSG_SPACE(sizeof(grpc_core::scm_timestamping)) +
-      CMSG_SPACE(sizeof(sock_extended_err) + sizeof(sockaddr_in)) +
-      CMSG_SPACE(32 * NLA_ALIGN(NLA_HDRLEN + sizeof(uint64_t)));
-  /* Allocate aligned space for cmsgs received along with timestamps */
-  union {
-    char rbuf[cmsg_alloc_space];
-    struct cmsghdr align;
-  } aligned_buf;
-  msg.msg_control = aligned_buf.rbuf;
-  msg.msg_controllen = sizeof(aligned_buf.rbuf);
-  int r, saved_errno;
+static bool process_errors(grpc_tcp* tcp) {
+  bool processed_err = false;
   while (true) {
+    struct iovec iov;
+    iov.iov_base = nullptr;
+    iov.iov_len = 0;
+    struct msghdr msg;
+    msg.msg_name = nullptr;
+    msg.msg_namelen = 0;
+    msg.msg_iov = &iov;
+    msg.msg_iovlen = 0;
+    msg.msg_flags = 0;
+    /* Allocate enough space so we don't need to keep increasing this as size
+     * of OPT_STATS increase */
+    constexpr size_t cmsg_alloc_space =
+        CMSG_SPACE(sizeof(grpc_core::scm_timestamping)) +
+        CMSG_SPACE(sizeof(sock_extended_err) + sizeof(sockaddr_in)) +
+        CMSG_SPACE(32 * NLA_ALIGN(NLA_HDRLEN + sizeof(uint64_t)));
+    /* Allocate aligned space for cmsgs received along with timestamps */
+    union {
+      char rbuf[cmsg_alloc_space];
+      struct cmsghdr align;
+    } aligned_buf;
+    memset(&aligned_buf, 0, sizeof(aligned_buf));
+    msg.msg_control = aligned_buf.rbuf;
+    msg.msg_controllen = sizeof(aligned_buf.rbuf);
+    int r, saved_errno;
     do {
       r = recvmsg(tcp->fd, &msg, MSG_ERRQUEUE);
       saved_errno = errno;
     } while (r < 0 && saved_errno == EINTR);
 
     if (r == -1 && saved_errno == EAGAIN) {
-      return; /* No more errors to process */
+      return processed_err; /* No more errors to process */
     }
     if (r == -1) {
-      return;
+      return processed_err;
     }
-    if ((msg.msg_flags & MSG_CTRUNC) != 0) {
+    if (GPR_UNLIKELY((msg.msg_flags & MSG_CTRUNC) != 0)) {
      gpr_log(GPR_ERROR, "Error message was truncated.");
     }
 
     if (msg.msg_controllen == 0) {
       /* There was no control message found. It was probably spurious. */
-      return;
+      return processed_err;
     }
     bool seen = false;
     for (auto cmsg = CMSG_FIRSTHDR(&msg); cmsg && cmsg->cmsg_len;
          cmsg = CMSG_NXTHDR(&msg, cmsg)) {
-      if (cmsg->cmsg_level != SOL_SOCKET ||
-          cmsg->cmsg_type != SCM_TIMESTAMPING) {
-        /* Got a control message that is not a timestamp. Don't know how to
-         * handle this. */
+      if (CmsgIsZeroCopy(*cmsg)) {
+        process_zerocopy(tcp, cmsg);
+        seen = true;
+        processed_err = true;
+      } else if (cmsg->cmsg_level == SOL_SOCKET &&
+                 cmsg->cmsg_type == SCM_TIMESTAMPING) {
+        cmsg = process_timestamp(tcp, &msg, cmsg);
+        seen = true;
+        processed_err = true;
+      } else {
+        /* Got a control message that is not a timestamp or zerocopy. Don't know
+         * how to handle this. */
         if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) {
           gpr_log(GPR_INFO,
                   "unknown control message cmsg_level:%d cmsg_type:%d",
                   cmsg->cmsg_level, cmsg->cmsg_type);
         }
-        return;
+        return processed_err;
       }
-      cmsg = process_timestamp(tcp, &msg, cmsg);
-      seen = true;
     }
     if (!seen) {
-      return;
+      return processed_err;
     }
   }
 }
@@ -870,18 +1223,28 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error) {
   /* We are still interested in collecting timestamps, so let's try reading
    * them. */
-  process_errors(tcp);
+  bool processed = process_errors(tcp);
   /* This might not a timestamps error. Set the read and write closures to be
    * ready. */
-  grpc_fd_set_readable(tcp->em_fd);
-  grpc_fd_set_writable(tcp->em_fd);
+  if (!processed) {
+    grpc_fd_set_readable(tcp->em_fd);
+    grpc_fd_set_writable(tcp->em_fd);
+  }
   grpc_fd_notify_on_error(tcp->em_fd, &tcp->error_closure);
 }
 
 #else /* GRPC_LINUX_ERRQUEUE */
+static TcpZerocopySendRecord* tcp_get_send_zerocopy_record(
+    grpc_tcp* tcp, grpc_slice_buffer* buf) {
+  return nullptr;
+}
+
+static void ZerocopyDisableAndWaitForRemaining(grpc_tcp* tcp) {}
+
 static bool tcp_write_with_timestamps(grpc_tcp* /*tcp*/, struct msghdr* /*msg*/,
                                       size_t /*sending_length*/,
-                                      ssize_t* /*sent_length*/) {
+                                      ssize_t* /*sent_length*/,
+                                      int /*additional_flags*/) {
   gpr_log(GPR_ERROR, "Write with timestamps not supported for this platform");
   GPR_ASSERT(0);
   return false;
@@ -907,12 +1270,138 @@ void tcp_shutdown_buffer_list(grpc_tcp* tcp) {
   }
 }
 
-/* returns true if done, false if pending; if returning true, *error is set */
 #if defined(IOV_MAX) && IOV_MAX < 1000
 #define MAX_WRITE_IOVEC IOV_MAX
 #else
 #define MAX_WRITE_IOVEC 1000
 #endif
+
+msg_iovlen_type TcpZerocopySendRecord::PopulateIovs(size_t* unwind_slice_idx,
+                                                    size_t* unwind_byte_idx,
+                                                    size_t* sending_length,
+                                                    iovec* iov) {
+  msg_iovlen_type iov_size;
+  *unwind_slice_idx = out_offset_.slice_idx;
+  *unwind_byte_idx = out_offset_.byte_idx;
+  for (iov_size = 0;
+       out_offset_.slice_idx != buf_.count && iov_size != MAX_WRITE_IOVEC;
+       iov_size++) {
+    iov[iov_size].iov_base =
+        GRPC_SLICE_START_PTR(buf_.slices[out_offset_.slice_idx]) +
+        out_offset_.byte_idx;
+    iov[iov_size].iov_len =
+        GRPC_SLICE_LENGTH(buf_.slices[out_offset_.slice_idx]) -
+        out_offset_.byte_idx;
+    *sending_length += iov[iov_size].iov_len;
+    ++(out_offset_.slice_idx);
+    out_offset_.byte_idx = 0;
+  }
+  GPR_DEBUG_ASSERT(iov_size > 0);
+  return iov_size;
+}
+
+void TcpZerocopySendRecord::UpdateOffsetForBytesSent(size_t sending_length,
+                                                     size_t actually_sent) {
+  size_t trailing = sending_length - actually_sent;
+  while (trailing > 0) {
+    size_t slice_length;
+    out_offset_.slice_idx--;
+    slice_length = GRPC_SLICE_LENGTH(buf_.slices[out_offset_.slice_idx]);
+    if (slice_length > trailing) {
+      out_offset_.byte_idx = slice_length - trailing;
+      break;
+    } else {
+      trailing -= slice_length;
+    }
+  }
+}
+
+// returns true if done, false if pending; if returning true, *error is set
+static bool do_tcp_flush_zerocopy(grpc_tcp* tcp, TcpZerocopySendRecord* record,
+                                  grpc_error** error) {
+  struct msghdr msg;
+  struct iovec iov[MAX_WRITE_IOVEC];
+  msg_iovlen_type iov_size;
+  ssize_t sent_length = 0;
+  size_t sending_length;
+  size_t unwind_slice_idx;
+  size_t unwind_byte_idx;
+  while (true) {
+    sending_length = 0;
+    iov_size = record->PopulateIovs(&unwind_slice_idx, &unwind_byte_idx,
+                                    &sending_length, iov);
+    msg.msg_name = nullptr;
+    msg.msg_namelen = 0;
+    msg.msg_iov = iov;
+    msg.msg_iovlen = iov_size;
+    msg.msg_flags = 0;
+    bool tried_sending_message = false;
+    // Before calling sendmsg (with or without timestamps): we
+    // take a single ref on the zerocopy send record.
+    tcp->tcp_zerocopy_send_ctx.NoteSend(record);
+    if (tcp->outgoing_buffer_arg != nullptr) {
+      if (!tcp->ts_capable ||
+          !tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length,
+                                     MSG_ZEROCOPY)) {
+        /* We could not set socket options to collect Fathom timestamps.
+         * Fallback on writing without timestamps. */
+        tcp->ts_capable = false;
+        tcp_shutdown_buffer_list(tcp);
+      } else {
+        tried_sending_message = true;
+      }
+    }
+    if (!tried_sending_message) {
+      msg.msg_control = nullptr;
+      msg.msg_controllen = 0;
+      GRPC_STATS_INC_TCP_WRITE_SIZE(sending_length);
+      GRPC_STATS_INC_TCP_WRITE_IOV_SIZE(iov_size);
+      sent_length = tcp_send(tcp->fd, &msg, MSG_ZEROCOPY);
+    }
+    if (sent_length < 0) {
+      // If this particular send failed, drop ref taken earlier in this method.
+      tcp->tcp_zerocopy_send_ctx.UndoSend();
+      if (errno == EAGAIN) {
+        record->UnwindIfThrottled(unwind_slice_idx, unwind_byte_idx);
+        return false;
+      } else if (errno == EPIPE) {
+        *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "sendmsg"), tcp);
+        tcp_shutdown_buffer_list(tcp);
+        return true;
+      } else {
+        *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "sendmsg"), tcp);
+        tcp_shutdown_buffer_list(tcp);
+        return true;
+      }
+    }
+    tcp->bytes_counter += sent_length;
+    record->UpdateOffsetForBytesSent(sending_length,
+                                     static_cast<size_t>(sent_length));
+    if (record->AllSlicesSent()) {
+      *error = GRPC_ERROR_NONE;
+      return true;
+    }
+  }
+}
+
+static void UnrefMaybePutZerocopySendRecord(grpc_tcp* tcp,
+                                            TcpZerocopySendRecord* record,
+                                            uint32_t seq, const char* tag) {
+  if (record->Unref()) {
+    tcp->tcp_zerocopy_send_ctx.PutSendRecord(record);
+  }
+}
+
+static bool tcp_flush_zerocopy(grpc_tcp* tcp, TcpZerocopySendRecord* record,
+                               grpc_error** error) {
+  bool done = do_tcp_flush_zerocopy(tcp, record, error);
+  if (done) {
+    // Either we encountered an error, or we successfully sent all the bytes.
+    // In either case, we're done with this record.
+    UnrefMaybePutZerocopySendRecord(tcp, record, 0, "flush_done");
+  }
+  return done;
+}
+
 static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) {
   struct msghdr msg;
   struct iovec iov[MAX_WRITE_IOVEC];
@@ -927,7 +1416,7 @@ static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) {
   // buffer as we write
   size_t outgoing_slice_idx = 0;
 
-  for (;;) {
+  while (true) {
     sending_length = 0;
     unwind_slice_idx = outgoing_slice_idx;
     unwind_byte_idx = tcp->outgoing_byte_idx;
@@ -1027,12 +1516,21 @@ static void tcp_handle_write(void* arg /* grpc_tcp */, grpc_error* error) {
   if (error != GRPC_ERROR_NONE) {
     cb = tcp->write_cb;
     tcp->write_cb = nullptr;
+    if (tcp->current_zerocopy_send != nullptr) {
+      UnrefMaybePutZerocopySendRecord(tcp, tcp->current_zerocopy_send, 0,
+                                      "handle_write_err");
+      tcp->current_zerocopy_send = nullptr;
+    }
     grpc_core::Closure::Run(DEBUG_LOCATION, cb, GRPC_ERROR_REF(error));
     TCP_UNREF(tcp, "write");
     return;
   }
 
-  if (!tcp_flush(tcp, &error)) {
+  bool flush_result =
+      tcp->current_zerocopy_send != nullptr
+          ? tcp_flush_zerocopy(tcp, tcp->current_zerocopy_send, &error)
+          : tcp_flush(tcp, &error);
+  if (!flush_result) {
     if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) {
       gpr_log(GPR_INFO, "write: delayed");
     }
@@ -1042,6 +1540,7 @@ static void tcp_handle_write(void* arg /* grpc_tcp */, grpc_error* error) {
   } else {
     cb = tcp->write_cb;
     tcp->write_cb = nullptr;
+    tcp->current_zerocopy_send = nullptr;
     if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) {
       const char* str = grpc_error_string(error);
       gpr_log(GPR_INFO, "write: %s", str);
@@ -1057,6 +1556,7 @@ static void tcp_write(grpc_endpoint* ep, grpc_slice_buffer* buf,
   GPR_TIMER_SCOPE("tcp_write", 0);
   grpc_tcp* tcp = reinterpret_cast<grpc_tcp*>(ep);
   grpc_error* error = GRPC_ERROR_NONE;
+  TcpZerocopySendRecord* zerocopy_send_record = nullptr;
 
   if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) {
     size_t i;
@@ -1073,8 +1573,8 @@ static void tcp_write(grpc_endpoint* ep, grpc_slice_buffer* buf,
   }
 
   GPR_ASSERT(tcp->write_cb == nullptr);
-  tcp->outgoing_buffer_arg = arg;
+  GPR_DEBUG_ASSERT(tcp->current_zerocopy_send == nullptr);
 
   if (buf->length == 0) {
     grpc_core::Closure::Run(
         DEBUG_LOCATION, cb,
@@ -1085,15 +1585,26 @@ static void tcp_write(grpc_endpoint* ep, grpc_slice_buffer* buf,
     tcp_shutdown_buffer_list(tcp);
     return;
   }
-  tcp->outgoing_buffer = buf;
-  tcp->outgoing_byte_idx = 0;
+
+  zerocopy_send_record = tcp_get_send_zerocopy_record(tcp, buf);
+  if (zerocopy_send_record == nullptr) {
+    // Either not enough bytes, or couldn't allocate a zerocopy context.
+    tcp->outgoing_buffer = buf;
+    tcp->outgoing_byte_idx = 0;
+  }
+  tcp->outgoing_buffer_arg = arg;
   if (arg) {
     GPR_ASSERT(grpc_event_engine_can_track_errors());
   }
 
-  if (!tcp_flush(tcp, &error)) {
+  bool flush_result =
+      zerocopy_send_record != nullptr
+          ? tcp_flush_zerocopy(tcp, zerocopy_send_record, &error)
+          : tcp_flush(tcp, &error);
+  if (!flush_result) {
     TCP_REF(tcp, "write");
     tcp->write_cb = cb;
+    tcp->current_zerocopy_send = zerocopy_send_record;
     if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) {
       gpr_log(GPR_INFO, "write: delayed");
     }
@@ -1121,6 +1632,7 @@ static void tcp_add_to_pollset_set(grpc_endpoint* ep,
 static void tcp_delete_from_pollset_set(grpc_endpoint* ep,
                                         grpc_pollset_set* pollset_set) {
   grpc_tcp* tcp = reinterpret_cast<grpc_tcp*>(ep);
+  ZerocopyDisableAndWaitForRemaining(tcp);
   grpc_pollset_set_del_fd(pollset_set, tcp->em_fd);
 }
@@ -1172,9 +1684,15 @@ static const grpc_endpoint_vtable vtable = {tcp_read,
 grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd,
                                const grpc_channel_args* channel_args,
                                const char* peer_string) {
+  static constexpr bool kZerocpTxEnabledDefault = false;
   int tcp_read_chunk_size = GRPC_TCP_DEFAULT_READ_SLICE_SIZE;
   int tcp_max_read_chunk_size = 4 * 1024 * 1024;
   int tcp_min_read_chunk_size = 256;
+  bool tcp_tx_zerocopy_enabled = kZerocpTxEnabledDefault;
+  int tcp_tx_zerocopy_send_bytes_thresh =
+      grpc_core::TcpZerocopySendCtx::kDefaultSendBytesThreshold;
+  int tcp_tx_zerocopy_max_simult_sends =
+      grpc_core::TcpZerocopySendCtx::kDefaultMaxSends;
   grpc_resource_quota* resource_quota = grpc_resource_quota_create(nullptr);
   if (channel_args != nullptr) {
     for (size_t i = 0; i < channel_args->num_args; i++) {
@@ -1199,6 +1717,23 @@ grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd,
         resource_quota =
             grpc_resource_quota_ref_internal(static_cast<grpc_resource_quota*>(
                 channel_args->args[i].value.pointer.p));
+      } else if (0 == strcmp(channel_args->args[i].key,
+                             GRPC_ARG_TCP_TX_ZEROCOPY_ENABLED)) {
+        tcp_tx_zerocopy_enabled = grpc_channel_arg_get_bool(
+            &channel_args->args[i], kZerocpTxEnabledDefault);
+      } else if (0 == strcmp(channel_args->args[i].key,
+                             GRPC_ARG_TCP_TX_ZEROCOPY_SEND_BYTES_THRESHOLD)) {
+        grpc_integer_options options = {
+            grpc_core::TcpZerocopySendCtx::kDefaultSendBytesThreshold, 0,
+            INT_MAX};
+        tcp_tx_zerocopy_send_bytes_thresh =
+            grpc_channel_arg_get_integer(&channel_args->args[i], options);
+      } else if (0 == strcmp(channel_args->args[i].key,
+                             GRPC_ARG_TCP_TX_ZEROCOPY_MAX_SIMULT_SENDS)) {
+        grpc_integer_options options = {
+            grpc_core::TcpZerocopySendCtx::kDefaultMaxSends, 0, INT_MAX};
+        tcp_tx_zerocopy_max_simult_sends =
+            grpc_channel_arg_get_integer(&channel_args->args[i], options);
       }
     }
   }
@@ -1215,6 +1750,7 @@ grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd,
   tcp->fd = grpc_fd_wrapped_fd(em_fd);
   tcp->read_cb = nullptr;
   tcp->write_cb = nullptr;
+  tcp->current_zerocopy_send = nullptr;
   tcp->release_fd_cb = nullptr;
   tcp->release_fd = nullptr;
   tcp->incoming_buffer = nullptr;
@@ -1228,6 +1764,20 @@ grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd,
   tcp->socket_ts_enabled = false;
   tcp->ts_capable = true;
   tcp->outgoing_buffer_arg = nullptr;
+  new (&tcp->tcp_zerocopy_send_ctx) TcpZerocopySendCtx(
+      tcp_tx_zerocopy_max_simult_sends, tcp_tx_zerocopy_send_bytes_thresh);
+  if (tcp_tx_zerocopy_enabled && !tcp->tcp_zerocopy_send_ctx.memory_limited()) {
+#ifdef GRPC_LINUX_ERRQUEUE
+    const int enable = 1;
+    auto err =
+        setsockopt(tcp->fd, SOL_SOCKET, SO_ZEROCOPY, &enable, sizeof(enable));
+    if (err == 0) {
+      tcp->tcp_zerocopy_send_ctx.set_enabled(true);
+    } else {
+      gpr_log(GPR_ERROR, "Failed to set zerocopy options on the socket.");
+    }
+#endif
+  }
   /* paired with unref in grpc_tcp_destroy */
   new (&tcp->refcount) grpc_core::RefCount(1, &grpc_tcp_trace);
   gpr_atm_no_barrier_store(&tcp->shutdown_count, 0);
@@ -1294,6 +1844,7 @@ void grpc_tcp_destroy_and_release_fd(grpc_endpoint* ep, int* fd,
   grpc_slice_buffer_reset_and_unref_internal(&tcp->last_read_buffer);
   if (grpc_event_engine_can_track_errors()) {
     /* Stop errors notification. */
+    ZerocopyDisableAndWaitForRemaining(tcp);
     gpr_atm_no_barrier_store(&tcp->stop_error_notification, true);
     grpc_fd_set_error(tcp->em_fd);
   }

@@ -157,6 +157,14 @@ grpc_error* grpc_tcp_server_prepare_socket(grpc_tcp_server* s, int fd,
     if (err != GRPC_ERROR_NONE) goto error;
   }
 
+#ifdef GRPC_LINUX_ERRQUEUE
+  err = grpc_set_socket_zerocopy(fd);
+  if (err != GRPC_ERROR_NONE) {
+    /* it's not fatal, so just log it. */
+    gpr_log(GPR_DEBUG, "Node does not support SO_ZEROCOPY, continuing.");
+    GRPC_ERROR_UNREF(err);
+  }
+#endif
   err = grpc_set_socket_nonblocking(fd, 1);
   if (err != GRPC_ERROR_NONE) goto error;
   err = grpc_set_socket_cloexec(fd, 1);

@@ -26,6 +26,7 @@
 #include <utility>
 
+#include "src/core/lib/channel/channel_args.h"
 #include "src/core/lib/gpr/string.h"
 #include "src/core/lib/gpr/useful.h"
 #include "src/cpp/server/external_connection_acceptor_impl.h"
@@ -218,20 +219,24 @@ ServerBuilder& ServerBuilder::AddListeningPort(
 std::unique_ptr<grpc::Server> ServerBuilder::BuildAndStart() {
   grpc::ChannelArguments args;
   for (const auto& option : options_) {
     option->UpdateArguments(&args);
     option->UpdatePlugins(&plugins_);
   }
 
-  for (const auto& plugin : plugins_) {
-    plugin->UpdateServerBuilder(this);
-    plugin->UpdateChannelArguments(&args);
-  }
-
   if (max_receive_message_size_ >= -1) {
+    grpc_channel_args c_args = args.c_channel_args();
+    const grpc_arg* arg =
+        grpc_channel_args_find(&c_args, GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH);
+    // Some option has set max_receive_message_length and it is also set
+    // directly on the ServerBuilder.
+    if (arg != nullptr) {
+      gpr_log(
+          GPR_ERROR,
+          "gRPC ServerBuilder receives multiple max_receive_message_length");
+    }
     args.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, max_receive_message_size_);
   }
 
   // The default message size is -1 (max), so no need to explicitly set it for
   // -1.
   if (max_send_message_size_ >= 0) {
@@ -254,6 +259,11 @@ std::unique_ptr<grpc::Server> ServerBuilder::BuildAndStart() {
                         grpc_resource_quota_arg_vtable());
   }
 
+  for (const auto& plugin : plugins_) {
+    plugin->UpdateServerBuilder(this);
+    plugin->UpdateChannelArguments(&args);
+  }
+
   // == Determine if the server has any syncrhonous methods ==
   bool has_sync_methods = false;
   for (const auto& value : services_) {
@@ -332,10 +342,10 @@ std::unique_ptr<grpc::Server> ServerBuilder::BuildAndStart() {
   }
 
   std::unique_ptr<grpc::Server> server(new grpc::Server(
-      max_receive_message_size_, &args, sync_server_cqs,
-      sync_server_settings_.min_pollers, sync_server_settings_.max_pollers,
-      sync_server_settings_.cq_timeout_msec, std::move(acceptors_),
-      resource_quota_, std::move(interceptor_creators_)));
+      &args, sync_server_cqs, sync_server_settings_.min_pollers,
+      sync_server_settings_.max_pollers, sync_server_settings_.cq_timeout_msec,
+      std::move(acceptors_), resource_quota_,
+      std::move(interceptor_creators_)));
 
   grpc_impl::ServerInitializer* initializer = server->initializer();
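For context on the new duplicate-argument warning above: setting the receive limit both through a channel argument and through SetMaxReceiveMessageSize() is presumably the situation the gpr_log(GPR_ERROR, ...) call is meant to surface. A hedged C++ sketch of such a configuration follows (illustrative only; the address and sizes are arbitrary, and whether the log actually fires depends on how the channel argument reaches the options list):

#include <grpcpp/grpcpp.h>

void BuildServerWithConflictingLimits() {
  grpc::ServerBuilder builder;
  builder.AddListeningPort("0.0.0.0:50051",
                           grpc::InsecureServerCredentials());
  // Sets grpc.max_receive_message_length once as a channel argument...
  builder.AddChannelArgument(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH,
                             1 * 1024 * 1024);
  // ...and again via the dedicated setter, which is the double-set case the
  // new check in BuildAndStart() reports while still starting the server.
  builder.SetMaxReceiveMessageSize(4 * 1024 * 1024);
  std::unique_ptr<grpc::Server> server = builder.BuildAndStart();
  if (server != nullptr) server->Shutdown();
}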

@@ -23,6 +23,7 @@
 #include <utility>
 
 #include <grpc/grpc.h>
+#include <grpc/impl/codegen/grpc_types.h>
 #include <grpc/support/alloc.h>
 #include <grpc/support/log.h>
 #include <grpcpp/completion_queue.h>
@@ -964,7 +965,7 @@ class Server::SyncRequestThreadManager : public grpc::ThreadManager {
 static grpc::internal::GrpcLibraryInitializer g_gli_initializer;
 Server::Server(
-    int max_receive_message_size, grpc::ChannelArguments* args,
+    grpc::ChannelArguments* args,
     std::shared_ptr<std::vector<std::unique_ptr<grpc::ServerCompletionQueue>>>
         sync_server_cqs,
     int min_pollers, int max_pollers, int sync_cq_timeout_msec,
@@ -976,7 +977,7 @@ Server::Server(
         interceptor_creators)
     : acceptors_(std::move(acceptors)),
       interceptor_creators_(std::move(interceptor_creators)),
-      max_receive_message_size_(max_receive_message_size),
+      max_receive_message_size_(INT_MIN),
       sync_server_cqs_(std::move(sync_server_cqs)),
       started_(false),
       shutdown_(false),
@@ -1026,10 +1027,12 @@ Server::Server(
             static_cast<grpc::HealthCheckServiceInterface*>(
                 channel_args.args[i].value.pointer.p));
       }
-      break;
+    }
+    if (0 ==
+        strcmp(channel_args.args[i].key, GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH)) {
+      max_receive_message_size_ = channel_args.args[i].value.integer;
     }
   }
 
   server_ = grpc_server_create(&channel_args, nullptr);
 }

@@ -1162,7 +1162,13 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
 
     @abc.abstractmethod
     def set_trailing_metadata(self, trailing_metadata):
-        """Sends the trailing metadata for the RPC.
+        """Sets the trailing metadata for the RPC.
+
+        Sets the trailing metadata to be sent upon completion of the RPC.
+
+        If this method is invoked multiple times throughout the lifetime of an
+        RPC, the value supplied in the final invocation will be the value sent
+        over the wire.
 
         This method need not be called by implementations if they have no
         metadata to add to what the gRPC runtime will transmit.

@@ -22,13 +22,12 @@ cdef class _AioCall:
     # time we need access to the event loop.
     object _loop
 
-    # Streaming call only attributes:
-    #
-    # A asyncio.Event that indicates if the status is received on the client side.
-    object _status_received
-    # A tuple of key value pairs representing the initial metadata sent by peer.
-    tuple _initial_metadata
+    # Flag indicates whether cancel being called or not. Cancellation from
+    # Core or peer works perfectly fine with normal procedure. However, we
+    # need this flag to clean up resources for cancellation from the
+    # application layer. Directly cancelling tasks might cause segfault
+    # because Core is holding a pointer for the callback handler.
+    bint _is_locally_cancelled
 
     cdef grpc_call* _create_grpc_call(self, object timeout, bytes method) except *
     cdef void _destroy_grpc_call(self)
-    cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future)

@@ -33,8 +33,7 @@ cdef class _AioCall:
         self._grpc_call_wrapper = GrpcCallWrapper()
         self._loop = asyncio.get_event_loop()
         self._create_grpc_call(deadline, method)
-
-        self._status_received = asyncio.Event(loop=self._loop)
+        self._is_locally_cancelled = False
 
     def __dealloc__(self):
         self._destroy_grpc_call()
@@ -78,17 +77,21 @@ cdef class _AioCall:
         """Destroys the corresponding Core object for this RPC."""
         grpc_call_unref(self._grpc_call_wrapper.call)
 
-    cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future):
-        """Cancels the RPC in Core, and return the final RPC status."""
-        cdef AioRpcStatus status
+    def cancel(self, AioRpcStatus status):
+        """Cancels the RPC in Core with given RPC status.
+
+        Above abstractions must invoke this method to set Core objects into
+        proper state.
+        """
+        self._is_locally_cancelled = True
         cdef object details
         cdef char *c_details
         cdef grpc_call_error error
         # Try to fetch application layer cancellation details in the future.
         # * If cancellation details present, cancel with status;
         # * If details not present, cancel with unknown reason.
-        if cancellation_future.done():
-            status = cancellation_future.result()
+        if status is not None:
             details = str_to_bytes(status.details())
             self._references.append(details)
             c_details = <char *>details
@@ -100,23 +103,13 @@ cdef class _AioCall:
                 NULL,
             )
             assert error == GRPC_CALL_OK
-            return status
         else:
             # By implementation, grpc_call_cancel always return OK
             error = grpc_call_cancel(self._grpc_call_wrapper.call, NULL)
             assert error == GRPC_CALL_OK
-            status = AioRpcStatus(
-                StatusCode.cancelled,
-                _UNKNOWN_CANCELLATION_DETAILS,
-                None,
-                None,
-            )
-            cancellation_future.set_result(status)
-            return status
 
     async def unary_unary(self,
                           bytes request,
-                          object cancellation_future,
                           object initial_metadata_observer,
                           object status_observer):
         """Performs a unary unary RPC.
@@ -145,19 +138,11 @@ cdef class _AioCall:
                               receive_initial_metadata_op, receive_message_op,
                               receive_status_on_client_op)
 
-        try:
-            await execute_batch(self._grpc_call_wrapper,
-                                ops,
-                                self._loop)
-        except asyncio.CancelledError:
-            status = self._cancel_and_create_status(cancellation_future)
-            initial_metadata_observer(None)
-            status_observer(status)
-            raise
-        else:
-            initial_metadata_observer(
-                receive_initial_metadata_op.initial_metadata()
-            )
+        # Executes all operations in one batch.
+        # Might raise CancelledError, handling it in Python UnaryUnaryCall.
+        await execute_batch(self._grpc_call_wrapper,
+                            ops,
+                            self._loop)
+
+        initial_metadata_observer(
+            receive_initial_metadata_op.initial_metadata()
+        )
 
         status = AioRpcStatus(
             receive_status_on_client_op.code(),
@@ -179,6 +164,11 @@ cdef class _AioCall:
         cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
         cdef tuple ops = (op,)
         await execute_batch(self._grpc_call_wrapper, ops, self._loop)
+
+        # Halts if the RPC is locally cancelled
+        if self._is_locally_cancelled:
+            return
+
         cdef AioRpcStatus status = AioRpcStatus(
             op.code(),
             op.details(),
@ -186,52 +176,30 @@ cdef class _AioCall:
op.error_string(), op.error_string(),
) )
status_observer(status) status_observer(status)
self._status_received.set()
def _handle_cancellation_from_application(self,
object cancellation_future,
object status_observer):
def _cancellation_action(finished_future):
if not self._status_received.set():
status = self._cancel_and_create_status(finished_future)
status_observer(status)
self._status_received.set()
cancellation_future.add_done_callback(_cancellation_action) async def receive_serialized_message(self):
"""Receives one single raw message in bytes."""
async def _message_async_generator(self):
cdef bytes received_message cdef bytes received_message
# Infinitely receiving messages, until: # Receives a message. Returns None when the read fails:
# * EOF, no more messages to read; # * EOF, no more messages to read;
# * The client application cancells; # * The client application cancels;
# * The server sends final status. # * The server sends final status.
while True: received_message = await _receive_message(
if self._status_received.is_set(): self._grpc_call_wrapper,
return self._loop
)
received_message = await _receive_message( return received_message
self._grpc_call_wrapper,
self._loop
)
if received_message is None:
# The read operation failed, Core should explain why it fails
await self._status_received.wait()
return
else:
yield received_message
async def unary_stream(self, async def unary_stream(self,
bytes request, bytes request,
object cancellation_future,
object initial_metadata_observer, object initial_metadata_observer,
object status_observer): object status_observer):
"""Actual implementation of the complete unary-stream call. """Implementation of the start of a unary-stream call."""
# Peer may prematurely end this RPC at any point. We need a coroutine
Needs to pay extra attention to the raise mechanism. If we want to # that watches if the server sends the final status.
propagate the final status exception, then we have to raise it. self._loop.create_task(self._handle_status_once_received(status_observer))
Otherwise, it would end normally and raise `StopAsyncIteration()`.
"""
cdef tuple outbound_ops cdef tuple outbound_ops
cdef Operation initial_metadata_op = SendInitialMetadataOperation( cdef Operation initial_metadata_op = SendInitialMetadataOperation(
_EMPTY_METADATA, _EMPTY_METADATA,
@ -248,21 +216,13 @@ cdef class _AioCall:
send_close_op, send_close_op,
) )
# Actually sends out the request message. # Sends out the request message.
await execute_batch(self._grpc_call_wrapper, await execute_batch(self._grpc_call_wrapper,
outbound_ops, outbound_ops,
self._loop) self._loop)
# Peer may prematurely end this RPC at any point. We need a mechanism
# that handles both the normal case and the error case.
self._loop.create_task(self._handle_status_once_received(status_observer))
self._handle_cancellation_from_application(cancellation_future,
status_observer)
# Receives initial metadata. # Receives initial metadata.
initial_metadata_observer( initial_metadata_observer(
await _receive_initial_metadata(self._grpc_call_wrapper, await _receive_initial_metadata(self._grpc_call_wrapper,
self._loop), self._loop),
) )
return self._message_async_generator()
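
The hunks above replace the old cancellation_future plumbing with a direct cancel(status) method plus an _is_locally_cancelled flag that lets the status watcher bail out early. Below is a rough, hedged sketch of that control flow in plain Python; the class and the _core_* helpers are illustrative stand-ins for the Cython/Core pieces, not the real implementation.

    class _AioCallSketch:
        """Simplified model of the new cancellation flow (illustrative only)."""

        def __init__(self):
            self._is_locally_cancelled = False

        def cancel(self, status):
            # Mirrors _AioCall.cancel(): flip the local-cancellation flag first so
            # the status watcher knows the Python layer already produced a status.
            self._is_locally_cancelled = True
            if status is not None:
                # Stand-in for grpc_call_cancel_with_status with the given details.
                self._core_cancel_with_status(status)
            else:
                # Stand-in for grpc_call_cancel (unknown cancellation reason).
                self._core_cancel()

        async def _handle_status_once_received(self, status_observer):
            # Stand-in for the RECV_STATUS_ON_CLIENT batch.
            status = await self._core_receive_status()
            # Halts if the RPC was locally cancelled; the Python Call object has
            # already fed the observer a locally-created status in that case.
            if self._is_locally_cancelled:
                return
            status_observer(status)

        # Placeholders so the sketch is self-contained; the real work happens in Core.
        def _core_cancel_with_status(self, status): ...
        def _core_cancel(self): ...
        async def _core_receive_status(self): ...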

@ -26,38 +26,13 @@ cdef class AioChannel:
def close(self): def close(self):
grpc_channel_destroy(self.channel) grpc_channel_destroy(self.channel)
async def unary_unary(self, def call(self,
bytes method, bytes method,
bytes request, object deadline):
object deadline, """Assembles a Cython Call object.
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Assembles a unary-unary RPC.
Returns: Returns:
The response message in bytes. The _AioCall object.
""" """
cdef _AioCall call = _AioCall(self, deadline, method) cdef _AioCall call = _AioCall(self, deadline, method)
return await call.unary_unary(request, return call
cancellation_future,
initial_metadata_observer,
status_observer)
def unary_stream(self,
bytes method,
bytes request,
object deadline,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Assembles a unary-stream RPC.
Returns:
An async generator that yields raw responses.
"""
cdef _AioCall call = _AioCall(self, deadline, method)
return call.unary_stream(request,
cancellation_future,
initial_metadata_observer,
status_observer)
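
With unary_unary and unary_stream removed from AioChannel, the channel is now only a factory for _AioCall objects, and the Python-level Call classes drive the RPC themselves. A hedged usage sketch of the new factory, matching the signatures in this diff (the method path and the None deadline are made-up values):

    async def invoke_unary_unary(channel, serialized_request,
                                 initial_metadata_observer, status_observer):
        # channel is a cygrpc.AioChannel; call() only assembles a cygrpc._AioCall.
        cython_call = channel.call(b'/test/UnaryCall', None)
        # The Call classes in grpc.experimental.aio._call drive the RPC from here.
        serialized_response = await cython_call.unary_unary(
            serialized_request,
            initial_metadata_observer,
            status_observer,
        )
        return serialized_response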

@ -41,6 +41,8 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tdebug_error_string = "{}"\n' '\tdebug_error_string = "{}"\n'
'>') '>')
_EMPTY_METADATA = tuple()
class AioRpcError(grpc.RpcError): class AioRpcError(grpc.RpcError):
"""An implementation of RpcError to be used by the asynchronous API. """An implementation of RpcError to be used by the asynchronous API.
@ -148,14 +150,14 @@ class Call(_base_call.Call):
_code: grpc.StatusCode _code: grpc.StatusCode
_status: Awaitable[cygrpc.AioRpcStatus] _status: Awaitable[cygrpc.AioRpcStatus]
_initial_metadata: Awaitable[MetadataType] _initial_metadata: Awaitable[MetadataType]
_cancellation: asyncio.Future _locally_cancelled: bool
def __init__(self) -> None: def __init__(self) -> None:
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._code = None self._code = None
self._status = self._loop.create_future() self._status = self._loop.create_future()
self._initial_metadata = self._loop.create_future() self._initial_metadata = self._loop.create_future()
self._cancellation = self._loop.create_future() self._locally_cancelled = False
def cancel(self) -> bool: def cancel(self) -> bool:
"""Placeholder cancellation method. """Placeholder cancellation method.
@ -167,8 +169,7 @@ class Call(_base_call.Call):
raise NotImplementedError() raise NotImplementedError()
def cancelled(self) -> bool: def cancelled(self) -> bool:
return self._cancellation.done( return self._code == grpc.StatusCode.CANCELLED
) or self._code == grpc.StatusCode.CANCELLED
def done(self) -> bool: def done(self) -> bool:
return self._status.done() return self._status.done()
@ -205,14 +206,22 @@ class Call(_base_call.Call):
cancellation (by application) and Core receiving status from peer. We cancellation (by application) and Core receiving status from peer. We
make no promise about which one will win. make no promise about which one will win.
""" """
if self._status.done(): # In case of local cancellation, flip the flag.
return if status.details() is _LOCAL_CANCELLATION_DETAILS:
else: self._locally_cancelled = True
self._status.set_result(status)
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[
status.code()]
async def _raise_rpc_error_if_not_ok(self) -> None: # In case the RPC finished without receiving metadata.
if not self._initial_metadata.done():
self._initial_metadata.set_result(_EMPTY_METADATA)
# Sets final status
self._status.set_result(status)
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
async def _raise_for_status(self) -> None:
if self._locally_cancelled:
raise asyncio.CancelledError()
await self._status
if self._code != grpc.StatusCode.OK: if self._code != grpc.StatusCode.OK:
raise _create_rpc_error(await self.initial_metadata(), raise _create_rpc_error(await self.initial_metadata(),
self._status.result()) self._status.result())
@ -245,12 +254,11 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
Returned when an instance of `UnaryUnaryMultiCallable` object is called. Returned when an instance of `UnaryUnaryMultiCallable` object is called.
""" """
_request: RequestType _request: RequestType
_deadline: Optional[float]
_channel: cygrpc.AioChannel _channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction _request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction _response_deserializer: DeserializingFunction
_call: asyncio.Task _call: asyncio.Task
_cython_call: cygrpc._AioCall
def __init__(self, request: RequestType, deadline: Optional[float], def __init__(self, request: RequestType, deadline: Optional[float],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
@ -258,11 +266,10 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
response_deserializer: DeserializingFunction) -> None: response_deserializer: DeserializingFunction) -> None:
super().__init__() super().__init__()
self._request = request self._request = request
self._deadline = deadline
self._channel = channel self._channel = channel
self._method = method
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._cython_call = self._channel.call(method, deadline)
self._call = self._loop.create_task(self._invoke()) self._call = self._loop.create_task(self._invoke())
def __del__(self) -> None: def __del__(self) -> None:
@ -275,28 +282,30 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
serialized_request = _common.serialize(self._request, serialized_request = _common.serialize(self._request,
self._request_serializer) self._request_serializer)
# NOTE(lidiz) asyncio.CancelledError is not a good transport for # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
# status, since the Task class does not cache the exact # because the asyncio.Task class does not cache the exception object.
# asyncio.CancelledError object. So, the solution is catching the error # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
# in Cython layer, then cancel the RPC and update the status, finally try:
# re-raise the CancelledError. serialized_response = await self._cython_call.unary_unary(
serialized_response = await self._channel.unary_unary( serialized_request,
self._method, self._set_initial_metadata,
serialized_request, self._set_status,
self._deadline, )
self._cancellation, except asyncio.CancelledError:
self._set_initial_metadata, if self._code != grpc.StatusCode.CANCELLED:
self._set_status, self.cancel()
)
await self._raise_rpc_error_if_not_ok() # Raises here if RPC failed or cancelled
await self._raise_for_status()
return _common.deserialize(serialized_response, return _common.deserialize(serialized_response,
self._response_deserializer) self._response_deserializer)
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning.""" """Forwards the application cancellation reasoning."""
if not self._status.done() and not self._cancellation.done(): if not self._status.done():
self._cancellation.set_result(status) self._set_status(status)
self._cython_call.cancel(status)
self._call.cancel() self._call.cancel()
return True return True
else: else:
@ -308,16 +317,17 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
_LOCAL_CANCELLATION_DETAILS, None, None)) _LOCAL_CANCELLATION_DETAILS, None, None))
def __await__(self) -> ResponseType: def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes. """Wait till the ongoing RPC request finishes."""
try:
Returns: response = yield from self._call
Response of the RPC call. except asyncio.CancelledError:
# Even if we caught all other CancelledError, there is still
Raises: # this corner case. If the application cancels immediately after
RpcError: Indicating that the RPC terminated with non-OK status. # the Call object is created, we will observe this
asyncio.CancelledError: Indicating that the RPC was canceled. # `CancelledError`.
""" if not self.cancelled():
response = yield from self._call self.cancel()
raise
return response return response
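
The cancellation handling adopted in _invoke and __await__ treats asyncio.CancelledError purely as a control signal: catch it, make sure the Call is cancelled exactly once, and let _raise_for_status() (or a re-raise) surface the outcome, while the actual status always travels through _set_status. A stripped-down sketch of that pattern, written against a placeholder call object rather than the real class:

    import asyncio
    import grpc

    async def invoke_sketch(call, serialized_request):
        # `call` stands for a UnaryUnaryCall-like object exposing _cython_call,
        # _code, cancel() and _raise_for_status(); it is illustrative only.
        serialized_response = None
        try:
            serialized_response = await call._cython_call.unary_unary(
                serialized_request,
                call._set_initial_metadata,
                call._set_status,
            )
        except asyncio.CancelledError:
            # The wrapping task was cancelled (Call.cancel() or an enclosing
            # task.cancel()); cancel the RPC itself exactly once.
            if call._code != grpc.StatusCode.CANCELLED:
                call.cancel()

        # Re-raises asyncio.CancelledError for local cancellation, or an RpcError
        # when the server terminated the RPC with a non-OK status.
        await call._raise_for_status()
        return serialized_response
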
@ -328,13 +338,11 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
Returned when an instance of `UnaryStreamMultiCallable` object is called. Returned when an instance of `UnaryStreamMultiCallable` object is called.
""" """
_request: RequestType _request: RequestType
_deadline: Optional[float]
_channel: cygrpc.AioChannel _channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction _request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction _response_deserializer: DeserializingFunction
_call: asyncio.Task _cython_call: cygrpc._AioCall
_bytes_aiter: AsyncIterable[bytes] _send_unary_request_task: asyncio.Task
_message_aiter: AsyncIterable[ResponseType] _message_aiter: AsyncIterable[ResponseType]
def __init__(self, request: RequestType, deadline: Optional[float], def __init__(self, request: RequestType, deadline: Optional[float],
@ -343,13 +351,13 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
response_deserializer: DeserializingFunction) -> None: response_deserializer: DeserializingFunction) -> None:
super().__init__() super().__init__()
self._request = request self._request = request
self._deadline = deadline
self._channel = channel self._channel = channel
self._method = method
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._call = self._loop.create_task(self._invoke()) self._send_unary_request_task = self._loop.create_task(
self._message_aiter = self._process() self._send_unary_request())
self._message_aiter = self._fetch_stream_responses()
self._cython_call = self._channel.call(method, deadline)
def __del__(self) -> None: def __del__(self) -> None:
if not self._status.done(): if not self._status.done():
@ -357,32 +365,24 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None)) _GC_CANCELLATION_DETAILS, None, None))
async def _invoke(self) -> ResponseType: async def _send_unary_request(self) -> ResponseType:
serialized_request = _common.serialize(self._request, serialized_request = _common.serialize(self._request,
self._request_serializer) self._request_serializer)
try:
self._bytes_aiter = await self._channel.unary_stream( await self._cython_call.unary_stream(serialized_request,
self._method, self._set_initial_metadata,
serialized_request, self._set_status)
self._deadline, except asyncio.CancelledError:
self._cancellation, if self._code != grpc.StatusCode.CANCELLED:
self._set_initial_metadata, self.cancel()
self._set_status, raise
)
async def _fetch_stream_responses(self) -> ResponseType:
async def _process(self) -> ResponseType: await self._send_unary_request_task
await self._call message = await self._read()
async for serialized_response in self._bytes_aiter: while message:
if self._cancellation.done(): yield message
await self._status message = await self._read()
if self._status.done():
# Raises pre-maturely if final status received here. Generates
# more helpful stack trace for end users.
await self._raise_rpc_error_if_not_ok()
yield _common.deserialize(serialized_response,
self._response_deserializer)
await self._raise_rpc_error_if_not_ok()
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning. """Forwards the application cancellation reasoning.
@ -395,8 +395,15 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
and the client calling "cancel" at the same time, this method respects and the client calling "cancel" at the same time, this method respects
the winner in Core. the winner in Core.
""" """
if not self._status.done() and not self._cancellation.done(): if not self._status.done():
self._cancellation.set_result(status) self._set_status(status)
self._cython_call.cancel(status)
if not self._send_unary_request_task.done():
# Injects CancelledError to the Task. The exception will
# propagate to _fetch_stream_responses as well, if the sending
# is not done.
self._send_unary_request_task.cancel()
return True return True
else: else:
return False return False
@ -409,8 +416,35 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
def __aiter__(self) -> AsyncIterable[ResponseType]: def __aiter__(self) -> AsyncIterable[ResponseType]:
return self._message_aiter return self._message_aiter
async def _read(self) -> ResponseType:
# Wait for the request to be sent
await self._send_unary_request_task
# Reads response message from Core
try:
raw_response = await self._cython_call.receive_serialized_message()
except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED:
self.cancel()
raise
if raw_response is None:
return None
else:
return _common.deserialize(raw_response,
self._response_deserializer)
async def read(self) -> ResponseType: async def read(self) -> ResponseType:
if self._status.done(): if self._status.done():
await self._raise_rpc_error_if_not_ok() await self._raise_for_status()
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
return await self._message_aiter.__anext__()
response_message = await self._read()
if response_message is None:
# If the read operation failed, Core should explain why.
await self._raise_for_status()
# If no exception raised, there is something wrong internally.
assert False, 'Read operation failed with StatusCode.OK'
else:
return response_message
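
The net effect on the public API (and what the updated tests below assert) is that read() on a locally cancelled stream now raises asyncio.CancelledError via _raise_for_status() instead of a grpc.RpcError, while code() and details() still report the cancellation. A small hedged sketch of that behavior; `call` is a placeholder for an aio unary-stream call object:

    import asyncio
    import grpc

    async def read_after_cancel_sketch(call):
        call.cancel()
        try:
            await call.read()
        except asyncio.CancelledError:
            # With this change a locally cancelled stream surfaces CancelledError
            # from read() instead of grpc.RpcError; status is still queryable.
            assert await call.code() == grpc.StatusCode.CANCELLED
            raise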

@ -33,6 +33,8 @@ _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000 _RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000
_UNREACHABLE_TARGET = '0.1:1111' _UNREACHABLE_TARGET = '0.1:1111'
_INFINITE_INTERVAL_US = 2**31 - 1
class TestUnaryUnaryCall(AioTestBase): class TestUnaryUnaryCall(AioTestBase):
@ -119,24 +121,38 @@ class TestUnaryUnaryCall(AioTestBase):
self.assertFalse(call.cancelled()) self.assertFalse(call.cancelled())
# TODO(https://github.com/grpc/grpc/issues/20869) remove sleep.
# Force the loop to execute the RPC task.
await asyncio.sleep(0)
self.assertTrue(call.cancel()) self.assertTrue(call.cancel())
self.assertFalse(call.cancel()) self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError) as exception_context: with self.assertRaises(asyncio.CancelledError):
await call await call
# The info in the RpcError should match the info in Call object.
self.assertTrue(call.cancelled()) self.assertTrue(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(), self.assertEqual(await call.details(),
'Locally cancelled by application!') 'Locally cancelled by application!')
# NOTE(lidiz) The CancelledError is almost always re-created, async def test_cancel_unary_unary_in_task(self):
# so we might not want to use it to transmit data. async with aio.insecure_channel(self._server_target) as channel:
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
call = stub.EmptyCall(messages_pb2.SimpleRequest())
async def another_coro():
coro_started.set()
await call
task = self.loop.create_task(another_coro())
await coro_started.wait()
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(asyncio.CancelledError):
await task
class TestUnaryStreamCall(AioTestBase): class TestUnaryStreamCall(AioTestBase):
@ -175,7 +191,7 @@ class TestUnaryStreamCall(AioTestBase):
call.details()) call.details())
self.assertFalse(call.cancel()) self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(asyncio.CancelledError):
await call.read() await call.read()
self.assertTrue(call.cancelled()) self.assertTrue(call.cancelled())
@ -206,7 +222,7 @@ class TestUnaryStreamCall(AioTestBase):
self.assertFalse(call.cancel()) self.assertFalse(call.cancel())
self.assertFalse(call.cancel()) self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(asyncio.CancelledError):
await call.read() await call.read()
async def test_early_cancel_unary_stream(self): async def test_early_cancel_unary_stream(self):
@ -230,16 +246,11 @@ class TestUnaryStreamCall(AioTestBase):
self.assertTrue(call.cancel()) self.assertTrue(call.cancel())
self.assertFalse(call.cancel()) self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(asyncio.CancelledError):
await call.read() await call.read()
self.assertTrue(call.cancelled()) self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION,
exception_context.exception.details())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details()) call.details())
@ -323,6 +334,69 @@ class TestUnaryStreamCall(AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_cancel_unary_stream_in_task_using_read(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
# Configures the server method to block forever
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_INFINITE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
async def another_coro():
coro_started.set()
await call.read()
task = self.loop.create_task(another_coro())
await coro_started.wait()
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(asyncio.CancelledError):
await task
async def test_cancel_unary_stream_in_task_using_async_for(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
# Configures the server method to block forever
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_INFINITE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
async def another_coro():
coro_started.set()
async for _ in call:
pass
task = self.loop.create_task(another_coro())
await coro_started.wait()
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(asyncio.CancelledError):
await task
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
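
The NOTE(lidiz) comments above point at why this refactor stopped smuggling status through asyncio.CancelledError: the Task machinery re-creates the exception, so anything attached to the original instance is lost by the time `await task` raises. A quick, hedged illustration (behavior observed with CPython 3.7-era asyncio; details vary slightly across versions, and the subclass here is purely hypothetical):

    import asyncio

    class DetailedCancelled(asyncio.CancelledError):
        # Hypothetical subclass trying to carry cancellation details as a payload.
        def __init__(self, details):
            super().__init__(details)
            self.details = details

    async def worker():
        try:
            await asyncio.sleep(3600)
        except asyncio.CancelledError:
            raise DetailedCancelled('locally cancelled with details')

    async def main():
        loop = asyncio.get_event_loop()
        task = loop.create_task(worker())
        await asyncio.sleep(0)      # let worker() start and block in sleep()
        task.cancel()
        try:
            await task
        except asyncio.CancelledError as e:
            # The error observed here is re-created by the Task machinery, so
            # the payload attached in worker() is not what the awaiter receives.
            print(type(e).__name__, e.args)

    asyncio.get_event_loop().run_until_complete(main())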

@ -15,6 +15,9 @@
set -ex set -ex
ACTION=${1:---in-place}
[[ $ACTION == '--in-place' ]] || [[ $ACTION == '--diff' ]]
# change to root directory # change to root directory
cd "$(dirname "${0}")/../.." cd "$(dirname "${0}")/../.."
@ -33,4 +36,4 @@ PYTHON=${VIRTUALENV}/bin/python
"$PYTHON" -m pip install --upgrade futures "$PYTHON" -m pip install --upgrade futures
"$PYTHON" -m pip install yapf==0.28.0 "$PYTHON" -m pip install yapf==0.28.0
$PYTHON -m yapf --diff --recursive --style=setup.cfg "${DIRS[@]}" $PYTHON -m yapf $ACTION --recursive --style=setup.cfg "${DIRS[@]}"

@ -25,7 +25,7 @@
- script: tools/distrib/clang_tidy_code.sh - script: tools/distrib/clang_tidy_code.sh
- script: tools/distrib/pylint_code.sh - script: tools/distrib/pylint_code.sh
- script: tools/distrib/python/check_grpcio_tools.py - script: tools/distrib/python/check_grpcio_tools.py
- script: tools/distrib/yapf_code.sh - script: tools/distrib/yapf_code.sh --diff
cpu_cost: 1000 cpu_cost: 1000
- script: tools/distrib/check_protobuf_pod_version.sh - script: tools/distrib/check_protobuf_pod_version.sh
- script: tools/distrib/check_shadow_boringssl_symbol_list.sh - script: tools/distrib/check_shadow_boringssl_symbol_list.sh
