[binder] Fix server-side recv_trailing_metadata (#27184)

According to the [transport explainer](https://grpc.github.io/grpc/core/md_doc_core_transport_explainer.html), the server-side `recv_trailing_metadata` should not be completed before sending trailing metadata to the client.
pull/27300/head
Ta-Wei Tu 4 years ago committed by GitHub
parent 72171a3326
commit fa2d21716b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      src/core/ext/transport/binder/transport/binder_stream.h
  2. 78
      src/core/ext/transport/binder/transport/binder_transport.cc
  3. 9
      src/core/ext/transport/binder/utils/transport_stream_receiver.h
  4. 84
      src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc
  5. 16
      src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h
  6. 2
      test/core/transport/binder/mock_objects.h

@ -108,6 +108,9 @@ struct grpc_binder_stream {
bool* call_failed_before_recv_message = nullptr;
grpc_metadata_batch* recv_trailing_metadata;
grpc_closure* recv_trailing_metadata_finished = nullptr;
bool trailing_metadata_sent = false;
bool need_to_call_trailing_metadata_callback = false;
};
#endif // GRPC_CORE_EXT_TRANSPORT_BINDER_TRANSPORT_BINDER_STREAM_H

@ -157,6 +157,7 @@ static void cancel_stream_locked(grpc_binder_transport* gbt,
grpc_core::ExecCtx::Run(DEBUG_LOCATION, gbs->recv_message_ready,
GRPC_ERROR_REF(error));
gbs->recv_message_ready = nullptr;
gbs->recv_message->reset();
gbs->recv_message = nullptr;
gbs->call_failed_before_recv_message = nullptr;
}
@ -173,11 +174,13 @@ static void cancel_stream_locked(grpc_binder_transport* gbt,
static void recv_initial_metadata_locked(void* arg,
grpc_error_handle /*error*/) {
gpr_log(GPR_INFO, "recv_initial_metadata_locked");
RecvInitialMetadataArgs* args = static_cast<RecvInitialMetadataArgs*>(arg);
grpc_binder_stream* gbs = args->gbs;
gpr_log(GPR_INFO,
"recv_initial_metadata_locked is_client = %d is_closed = %d",
gbs->is_client, gbs->is_closed);
if (!gbs->is_closed) {
grpc_error_handle error = [&] {
GPR_ASSERT(gbs->recv_initial_metadata);
@ -200,11 +203,12 @@ static void recv_initial_metadata_locked(void* arg,
}
static void recv_message_locked(void* arg, grpc_error_handle /*error*/) {
gpr_log(GPR_INFO, "recv_message_locked");
RecvMessageArgs* args = static_cast<RecvMessageArgs*>(arg);
grpc_binder_stream* gbs = args->gbs;
gpr_log(GPR_INFO, "recv_message_locked is_client = %d is_closed = %d",
gbs->is_client, gbs->is_closed);
if (!gbs->is_closed) {
grpc_error_handle error = [&] {
GPR_ASSERT(gbs->recv_message);
@ -246,11 +250,13 @@ static void recv_message_locked(void* arg, grpc_error_handle /*error*/) {
static void recv_trailing_metadata_locked(void* arg,
grpc_error_handle /*error*/) {
gpr_log(GPR_INFO, "recv_trailing_metadata_locked");
RecvTrailingMetadataArgs* args = static_cast<RecvTrailingMetadataArgs*>(arg);
grpc_binder_stream* gbs = args->gbs;
gpr_log(GPR_INFO,
"recv_trailing_metadata_locked is_client = %d is_closed = %d",
gbs->is_client, gbs->is_closed);
if (!gbs->is_closed) {
grpc_error_handle error = [&] {
GPR_ASSERT(gbs->recv_trailing_metadata);
@ -284,10 +290,20 @@ static void recv_trailing_metadata_locked(void* arg,
return GRPC_ERROR_NONE;
}();
grpc_closure* cb = gbs->recv_trailing_metadata_finished;
gbs->recv_trailing_metadata_finished = nullptr;
gbs->recv_trailing_metadata = nullptr;
grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error);
if (gbs->is_client || gbs->trailing_metadata_sent) {
grpc_closure* cb = gbs->recv_trailing_metadata_finished;
gbs->recv_trailing_metadata_finished = nullptr;
gbs->recv_trailing_metadata = nullptr;
grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error);
} else {
// According to transport explaineer - "Server extra: This op shouldn't
// actually be considered complete until the server has also sent trailing
// metadata to provide the other side with final status"
//
// We haven't sent trailing metadata yet, so we have to delay completing
// the recv_trailing_metadata callback.
gbs->need_to_call_trailing_metadata_callback = true;
}
}
GRPC_BINDER_STREAM_UNREF(gbs, "recv_trailing_metadata");
}
@ -304,23 +320,29 @@ static void perform_stream_op_locked(void* stream_op,
GPR_ASSERT(!op->send_initial_metadata && !op->send_message &&
!op->send_trailing_metadata && !op->recv_initial_metadata &&
!op->recv_message && !op->recv_trailing_metadata);
gpr_log(GPR_INFO, "cancel_stream");
// Send trailing metadata to inform the other end about the cancellation,
// regardless if we'd already done that or not.
grpc_binder::Transaction cancel_tx(gbs->GetTxCode(), gbs->GetThenIncSeq(),
gbt->is_client);
cancel_tx.SetSuffix(grpc_binder::Metadata{});
absl::Status status = gbt->wire_writer->RpcCall(cancel_tx);
gpr_log(GPR_INFO, "cancel_stream is_client = %d", gbs->is_client);
if (!gbs->is_client) {
// Send trailing metadata to inform the other end about the cancellation,
// regardless if we'd already done that or not.
grpc_binder::Transaction cancel_tx(gbs->GetTxCode(), gbs->GetThenIncSeq(),
gbt->is_client);
cancel_tx.SetSuffix(grpc_binder::Metadata{});
cancel_tx.SetStatus(1);
absl::Status status = gbt->wire_writer->RpcCall(cancel_tx);
}
cancel_stream_locked(gbt, gbs, op->payload->cancel_stream.cancel_error);
if (op->on_complete != nullptr) {
grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete,
absl_status_to_grpc_error(status));
grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, GRPC_ERROR_NONE);
}
GRPC_BINDER_STREAM_UNREF(gbs, "perform_stream_op");
return;
}
if (gbs->is_closed) {
if (op->send_message) {
// Reset the send_message payload to prevent memory leaks.
op->payload->send_message.send_message.reset();
}
if (op->recv_initial_metadata) {
grpc_core::ExecCtx::Run(
DEBUG_LOCATION,
@ -520,6 +542,21 @@ static void perform_stream_op_locked(void* stream_op,
absl::Status status = absl::OkStatus();
if (tx) {
status = gbt->wire_writer->RpcCall(*tx);
if (!gbs->is_client && op->send_trailing_metadata) {
gbs->trailing_metadata_sent = true;
// According to transport explaineer - "Server extra: This op shouldn't
// actually be considered complete until the server has also sent trailing
// metadata to provide the other side with final status"
//
// Because we've done sending trailing metadata here, we can safely
// complete the recv_trailing_metadata callback here.
if (gbs->need_to_call_trailing_metadata_callback) {
grpc_closure* cb = gbs->recv_trailing_metadata_finished;
gbs->recv_trailing_metadata_finished = nullptr;
grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_NONE);
gbs->need_to_call_trailing_metadata_callback = false;
}
}
}
// Note that this should only be scheduled when all non-recv ops are
// completed
@ -534,9 +571,10 @@ static void perform_stream_op_locked(void* stream_op,
static void perform_stream_op(grpc_transport* gt, grpc_stream* gs,
grpc_transport_stream_op_batch* op) {
GPR_TIMER_SCOPE("perform_stream_op", 0);
gpr_log(GPR_INFO, "%s = %p %p %p", __func__, gt, gs, op);
grpc_binder_transport* gbt = reinterpret_cast<grpc_binder_transport*>(gt);
grpc_binder_stream* gbs = reinterpret_cast<grpc_binder_stream*>(gs);
gpr_log(GPR_INFO, "%s = %p %p %p is_client = %d", __func__, gt, gs, op,
gbs->is_client);
GRPC_BINDER_STREAM_REF(gbs, "perform_stream_op");
op->handler_private.extra_arg = gbs;
gbt->combiner->Run(GRPC_CLOSURE_INIT(&op->handler_private.closure,

@ -60,15 +60,6 @@ class TransportStreamReceiver {
virtual void NotifyRecvTrailingMetadata(
StreamIdentifier id, absl::StatusOr<Metadata> trailing_metadata,
int status) = 0;
// Trailing metadata marks the end of one-side of the stream. Thus, after
// receiving trailing metadata from the other-end, we know that there will
// never be in-coming message data anymore, and all recv_message callbacks
// registered will never be satisfied. This function cancels all such
// callbacks gracefully (with GRPC_ERROR_NONE) to avoid being blocked waiting
// for them.
virtual void CancelRecvMessageCallbacksDueToTrailingMetadata(
StreamIdentifier id) = 0;
// Remove all entries associated with stream number `id`.
virtual void CancelStream(StreamIdentifier id) = 0;

@ -37,7 +37,11 @@ void TransportStreamReceiverImpl::RegisterRecvInitialMetadata(
grpc_core::MutexLock l(&m_);
auto iter = pending_initial_metadata_.find(id);
if (iter == pending_initial_metadata_.end()) {
initial_metadata_cbs_[id] = std::move(cb);
if (trailing_metadata_recvd_.count(id)) {
cb(absl::CancelledError(""));
} else {
initial_metadata_cbs_[id] = std::move(cb);
}
cb = nullptr;
} else {
initial_metadata = std::move(iter->second.front());
@ -63,7 +67,7 @@ void TransportStreamReceiverImpl::RegisterRecvMessage(
if (iter == pending_message_.end()) {
// If we'd already received trailing-metadata and there's no pending
// messages, cancel the callback.
if (recv_message_cancelled_.count(id)) {
if (trailing_metadata_recvd_.count(id)) {
cb(absl::CancelledError(
TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully));
} else {
@ -157,7 +161,7 @@ void TransportStreamReceiverImpl::NotifyRecvTrailingMetadata(
// parsed after message data, we can safely cancel all upcoming callbacks of
// recv_message.
gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
CancelRecvMessageCallbacksDueToTrailingMetadata(id);
OnRecvTrailingMetadata(id);
TrailingMetadataCallbackType cb;
{
grpc_core::MutexLock l(&m_);
@ -174,51 +178,73 @@ void TransportStreamReceiverImpl::NotifyRecvTrailingMetadata(
cb(std::move(trailing_metadata), status);
}
void TransportStreamReceiverImpl::
CancelRecvMessageCallbacksDueToTrailingMetadata(StreamIdentifier id) {
gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
MessageDataCallbackType cb = nullptr;
void TransportStreamReceiverImpl::CancelInitialMetadataCallback(
StreamIdentifier id, absl::Status error) {
InitialMetadataCallbackType callback = nullptr;
{
grpc_core::MutexLock l(&m_);
auto iter = message_cbs_.find(id);
if (iter != message_cbs_.end()) {
cb = std::move(iter->second);
message_cbs_.erase(iter);
}
recv_message_cancelled_.insert(id);
}
if (cb != nullptr) {
// The registered callback will never be satisfied. Cancel it.
cb(absl::CancelledError(
TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully));
}
}
void TransportStreamReceiverImpl::CancelStream(StreamIdentifier id) {
gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
grpc_core::MutexLock l(&m_);
{
auto iter = initial_metadata_cbs_.find(id);
if (iter != initial_metadata_cbs_.end()) {
iter->second(absl::CancelledError("Stream cancelled"));
callback = std::move(iter->second);
initial_metadata_cbs_.erase(iter);
}
}
if (callback != nullptr) {
std::move(callback)(error);
}
}
void TransportStreamReceiverImpl::CancelMessageCallback(StreamIdentifier id,
absl::Status error) {
MessageDataCallbackType callback = nullptr;
{
grpc_core::MutexLock l(&m_);
auto iter = message_cbs_.find(id);
if (iter != message_cbs_.end()) {
iter->second(absl::CancelledError("Stream cancelled"));
callback = std::move(iter->second);
message_cbs_.erase(iter);
}
}
if (callback != nullptr) {
std::move(callback)(error);
}
}
void TransportStreamReceiverImpl::CancelTrailingMetadataCallback(
StreamIdentifier id, absl::Status error) {
TrailingMetadataCallbackType callback = nullptr;
{
grpc_core::MutexLock l(&m_);
auto iter = trailing_metadata_cbs_.find(id);
if (iter != trailing_metadata_cbs_.end()) {
iter->second(absl::CancelledError("Stream cancelled"), 0);
callback = std::move(iter->second);
trailing_metadata_cbs_.erase(iter);
}
}
recv_message_cancelled_.erase(id);
if (callback != nullptr) {
std::move(callback)(error, 0);
}
}
void TransportStreamReceiverImpl::OnRecvTrailingMetadata(StreamIdentifier id) {
gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
m_.Lock();
trailing_metadata_recvd_.insert(id);
m_.Unlock();
CancelInitialMetadataCallback(id, absl::CancelledError(""));
CancelMessageCallback(
id,
absl::CancelledError(
TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully));
}
void TransportStreamReceiverImpl::CancelStream(StreamIdentifier id) {
gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
CancelInitialMetadataCallback(id, absl::CancelledError("Stream cancelled"));
CancelMessageCallback(id, absl::CancelledError("Stream cancelled"));
CancelTrailingMetadataCallback(id, absl::CancelledError("Stream cancelled"));
grpc_core::MutexLock l(&m_);
trailing_metadata_recvd_.erase(id);
pending_initial_metadata_.erase(id);
pending_message_.erase(id);
pending_trailing_metadata_.erase(id);

@ -50,11 +50,21 @@ class TransportStreamReceiverImpl : public TransportStreamReceiver {
absl::StatusOr<Metadata> trailing_metadata,
int status) override;
void CancelRecvMessageCallbacksDueToTrailingMetadata(
StreamIdentifier id) override;
void CancelStream(StreamIdentifier id) override;
private:
// Trailing metadata marks the end of one-side of the stream. Thus, after
// receiving trailing metadata from the other-end, we know that there will
// never be in-coming message data anymore, and all recv_message callbacks
// (as well as recv_initial_metadata callback, if there's any) registered will
// never be satisfied. This function cancels all such callbacks gracefully
// (with GRPC_ERROR_NONE) to avoid being blocked waiting for them.
void OnRecvTrailingMetadata(StreamIdentifier id);
void CancelInitialMetadataCallback(StreamIdentifier id, absl::Status error);
void CancelMessageCallback(StreamIdentifier id, absl::Status error);
void CancelTrailingMetadataCallback(StreamIdentifier id, absl::Status error);
std::map<StreamIdentifier, InitialMetadataCallbackType> initial_metadata_cbs_;
std::map<StreamIdentifier, MessageDataCallbackType> message_cbs_;
std::map<StreamIdentifier, TrailingMetadataCallbackType>
@ -90,7 +100,7 @@ class TransportStreamReceiverImpl : public TransportStreamReceiver {
// when RegisterRecvMessage() gets called, we should check whether
// recv_message_cancelled_ contains the corresponding stream ID, and if so,
// directly cancel the callback gracefully without pending it.
std::set<StreamIdentifier> recv_message_cancelled_ ABSL_GUARDED_BY(m_);
std::set<StreamIdentifier> trailing_metadata_recvd_ ABSL_GUARDED_BY(m_);
bool is_client_;
// Called when receiving initial metadata to inform the server about a new

@ -105,8 +105,6 @@ class MockTransportStreamReceiver : public TransportStreamReceiver {
(StreamIdentifier, absl::StatusOr<std::string>), (override));
MOCK_METHOD(void, NotifyRecvTrailingMetadata,
(StreamIdentifier, absl::StatusOr<Metadata>, int), (override));
MOCK_METHOD(void, CancelRecvMessageCallbacksDueToTrailingMetadata,
(StreamIdentifier), (override));
MOCK_METHOD(void, CancelStream, (StreamIdentifier), (override));
};

Loading…
Cancel
Save