diff --git a/src/core/ext/transport/binder/transport/BUILD b/src/core/ext/transport/binder/transport/BUILD index c1b59037036..2f83a4e4671 100644 --- a/src/core/ext/transport/binder/transport/BUILD +++ b/src/core/ext/transport/binder/transport/BUILD @@ -33,8 +33,9 @@ grpc_cc_library( "binder_transport.h", ], external_deps = [ - "absl/strings", + "absl/container:flat_hash_map", "absl/memory", + "absl/strings", ], deps = [ "//:gpr_base", diff --git a/src/core/ext/transport/binder/transport/binder_stream.h b/src/core/ext/transport/binder/transport/binder_stream.h index b9cbb3e7742..0d27333fd29 100644 --- a/src/core/ext/transport/binder/transport/binder_stream.h +++ b/src/core/ext/transport/binder/transport/binder_stream.h @@ -19,33 +19,93 @@ #include "src/core/ext/transport/binder/transport/binder_transport.h" +struct RecvInitialMetadataArgs { + grpc_binder_stream* gbs; + grpc_binder_transport* gbt; + int tx_code; + absl::StatusOr initial_metadata; +}; + +struct RecvMessageArgs { + grpc_binder_stream* gbs; + grpc_binder_transport* gbt; + int tx_code; + absl::StatusOr message; +}; + +struct RecvTrailingMetadataArgs { + grpc_binder_stream* gbs; + grpc_binder_transport* gbt; + int tx_code; + absl::StatusOr trailing_metadata; + int status; +}; + // TODO(mingcl): Figure out if we want to use class instead of struct here struct grpc_binder_stream { // server_data will be null for client, and for server it will be whatever // passed in to the accept_stream_fn callback by client. - grpc_binder_stream(grpc_binder_transport* t, grpc_core::Arena* arena, - const void* /*server_data*/, int tx_code, bool is_client) - : t(t), arena(arena), seq(0), tx_code(tx_code), is_client(is_client) {} - ~grpc_binder_stream() = default; - int GetTxCode() { return tx_code; } + grpc_binder_stream(grpc_binder_transport* t, grpc_stream_refcount* refcount, + const void* /*server_data*/, grpc_core::Arena* arena, + int tx_code, bool is_client) + : t(t), + refcount(refcount), + arena(arena), + seq(0), + tx_code(tx_code), + is_client(is_client) { + // TODO(waynetu): Should this be protected? + t->registered_stream[tx_code] = this; + + recv_initial_metadata_args.gbs = this; + recv_initial_metadata_args.gbt = t; + recv_message_args.gbs = this; + recv_message_args.gbt = t; + recv_trailing_metadata_args.gbs = this; + recv_trailing_metadata_args.gbt = t; + } + + ~grpc_binder_stream() { + GRPC_ERROR_UNREF(cancel_self_error); + if (destroy_stream_then_closure != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, destroy_stream_then_closure, + GRPC_ERROR_NONE); + } + } + + int GetTxCode() const { return tx_code; } int GetThenIncSeq() { return seq++; } grpc_binder_transport* t; + grpc_stream_refcount* refcount; grpc_core::Arena* arena; grpc_core::ManualConstructor sbs; int seq; int tx_code; bool is_client; + bool is_closed = false; + + grpc_closure* destroy_stream_then_closure = nullptr; + grpc_closure destroy_stream; + + // The reason why this stream is cancelled and closed. + grpc_error_handle cancel_self_error = GRPC_ERROR_NONE; - // TODO(waynetu): This should be guarded by a mutex. - absl::Status cancellation_error = absl::OkStatus(); + grpc_closure recv_initial_metadata_closure; + RecvInitialMetadataArgs recv_initial_metadata_args; + grpc_closure recv_message_closure; + RecvMessageArgs recv_message_args; + grpc_closure recv_trailing_metadata_closure; + RecvTrailingMetadataArgs recv_trailing_metadata_args; // We store these fields passed from op batch, in order to access them through // grpc_binder_stream grpc_metadata_batch* recv_initial_metadata; grpc_closure* recv_initial_metadata_ready = nullptr; + bool* trailing_metadata_available = nullptr; grpc_core::OrphanablePtr* recv_message; grpc_closure* recv_message_ready = nullptr; + bool* call_failed_before_recv_message = nullptr; grpc_metadata_batch* recv_trailing_metadata; grpc_closure* recv_trailing_metadata_finished = nullptr; }; diff --git a/src/core/ext/transport/binder/transport/binder_transport.cc b/src/core/ext/transport/binder/transport/binder_transport.cc index ce3aca16761..75cad8c9682 100644 --- a/src/core/ext/transport/binder/transport/binder_transport.cc +++ b/src/core/ext/transport/binder/transport/binder_transport.cc @@ -42,6 +42,60 @@ #include "src/core/lib/transport/status_metadata.h" #include "src/core/lib/transport/transport.h" +#ifndef NDEBUG +static void grpc_binder_stream_ref(grpc_binder_stream* s, const char* reason) { + grpc_stream_ref(s->refcount, reason); +} +static void grpc_binder_stream_unref(grpc_binder_stream* s, + const char* reason) { + grpc_stream_unref(s->refcount, reason); +} +static void grpc_binder_ref_transport(grpc_binder_transport* t, + const char* reason, const char* file, + int line) { + t->refs.Ref(grpc_core::DebugLocation(file, line), reason); +} +static void grpc_binder_unref_transport(grpc_binder_transport* t, + const char* reason, const char* file, + int line) { + if (t->refs.Unref(grpc_core::DebugLocation(file, line), reason)) { + delete t; + } +} +#else +static void grpc_binder_stream_ref(grpc_binder_stream* s) { + grpc_stream_ref(s->refcount); +} +static void grpc_binder_stream_unref(grpc_binder_stream* s) { + grpc_stream_unref(s->refcount); +} +static void grpc_binder_ref_transport(grpc_binder_transport* t) { + t->refs.Ref(); +} +static void grpc_binder_unref_transport(grpc_binder_transport* t) { + if (t->refs.Unref()) { + delete t; + } +} +#endif + +#ifndef NDEBUG +#define GRPC_BINDER_STREAM_REF(stream, reason) \ + grpc_binder_stream_ref(stream, reason) +#define GRPC_BINDER_STREAM_UNREF(stream, reason) \ + grpc_binder_stream_unref(stream, reason) +#define GRPC_BINDER_REF_TRANSPORT(t, r) \ + grpc_binder_ref_transport(t, r, __FILE__, __LINE__) +#define GRPC_BINDER_UNREF_TRANSPORT(t, r) \ + grpc_binder_unref_transport(t, r, __FILE__, __LINE__) +#else +#define GRPC_BINDER_STREAM_REF(stream, reason) grpc_binder_stream_ref(stream) +#define GRPC_BINDER_STREAM_UNREF(stream, reason) \ + grpc_binder_stream_unref(stream) +#define GRPC_BINDER_REF_TRANSPORT(t, r) grpc_binder_ref_transport(t) +#define GRPC_BINDER_UNREF_TRANSPORT(t, r) grpc_binder_unref_transport(t) +#endif + static int init_stream(grpc_transport* gt, grpc_stream* gs, grpc_stream_refcount* refcount, const void* server_data, grpc_core::Arena* arena) { @@ -51,8 +105,8 @@ static int init_stream(grpc_transport* gt, grpc_stream* gs, grpc_binder_transport* t = reinterpret_cast(gt); // TODO(mingcl): Figure out if we need to worry about concurrent invocation // here - new (gs) grpc_binder_stream(t, arena, server_data, t->NewStreamTxCode(), - t->is_client); + new (gs) grpc_binder_stream(t, refcount, server_data, arena, + t->NewStreamTxCode(), t->is_client); return 0; } @@ -64,8 +118,8 @@ static void set_pollset_set(grpc_transport*, grpc_stream*, grpc_pollset_set*) { gpr_log(GPR_INFO, __func__); } -void AssignMetadata(grpc_metadata_batch* mb, grpc_core::Arena* arena, - const grpc_binder::Metadata& md) { +static void AssignMetadata(grpc_metadata_batch* mb, grpc_core::Arena* arena, + const grpc_binder::Metadata& md) { grpc_metadata_batch_init(mb); for (auto& p : md) { grpc_linked_mdelem* glm = static_cast( @@ -82,51 +136,226 @@ void AssignMetadata(grpc_metadata_batch* mb, grpc_core::Arena* arena, } } -static void perform_stream_op(grpc_transport* gt, grpc_stream* gs, - grpc_transport_stream_op_batch* op) { - GPR_TIMER_SCOPE("perform_stream_op", 0); - gpr_log(GPR_INFO, "%s = %p %p %p", __func__, gt, gs, op); - grpc_binder_transport* gbt = reinterpret_cast(gt); - grpc_binder_stream* gbs = reinterpret_cast(gs); +static void cancel_stream_locked(grpc_binder_transport* gbt, + grpc_binder_stream* gbs, + grpc_error_handle error) { + gpr_log(GPR_INFO, "cancel_stream_locked"); + if (!gbs->is_closed) { + GPR_ASSERT(gbs->cancel_self_error == GRPC_ERROR_NONE); + gbs->is_closed = true; + gbs->cancel_self_error = GRPC_ERROR_REF(error); + gbt->transport_stream_receiver->CancelStream(gbs->tx_code); + gbt->registered_stream.erase(gbs->tx_code); + if (gbs->recv_initial_metadata_ready != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, gbs->recv_initial_metadata_ready, + GRPC_ERROR_REF(error)); + gbs->recv_initial_metadata_ready = nullptr; + gbs->recv_initial_metadata = nullptr; + gbs->trailing_metadata_available = nullptr; + } + if (gbs->recv_message_ready != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, gbs->recv_message_ready, + GRPC_ERROR_REF(error)); + gbs->recv_message_ready = nullptr; + gbs->recv_message = nullptr; + gbs->call_failed_before_recv_message = nullptr; + } + if (gbs->recv_trailing_metadata_finished != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + gbs->recv_trailing_metadata_finished, + GRPC_ERROR_REF(error)); + gbs->recv_trailing_metadata_finished = nullptr; + gbs->recv_trailing_metadata = nullptr; + } + } + GRPC_ERROR_UNREF(error); +} + +static void recv_initial_metadata_locked(void* arg, + grpc_error_handle /*error*/) { + gpr_log(GPR_INFO, "recv_initial_metadata_locked"); + RecvInitialMetadataArgs* args = static_cast(arg); + + grpc_binder_stream* gbs = args->gbs; + + if (!gbs->is_closed) { + grpc_error_handle error = [&] { + GPR_ASSERT(gbs->recv_initial_metadata); + GPR_ASSERT(gbs->recv_initial_metadata_ready); + if (!args->initial_metadata.ok()) { + gpr_log(GPR_ERROR, "Failed to parse initial metadata"); + return absl_status_to_grpc_error(args->initial_metadata.status()); + } + AssignMetadata(gbs->recv_initial_metadata, gbs->arena, + *args->initial_metadata); + return GRPC_ERROR_NONE; + }(); + + grpc_closure* cb = gbs->recv_initial_metadata_ready; + gbs->recv_initial_metadata_ready = nullptr; + gbs->recv_initial_metadata = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); + } + GRPC_BINDER_STREAM_UNREF(gbs, "recv_initial_metadata"); +} + +static void recv_message_locked(void* arg, grpc_error_handle /*error*/) { + gpr_log(GPR_INFO, "recv_message_locked"); + RecvMessageArgs* args = static_cast(arg); + + grpc_binder_stream* gbs = args->gbs; + + if (!gbs->is_closed) { + grpc_error_handle error = [&] { + GPR_ASSERT(gbs->recv_message); + GPR_ASSERT(gbs->recv_message_ready); + if (!args->message.ok()) { + gpr_log(GPR_ERROR, "Failed to receive message"); + if (args->message.status().message() == + grpc_binder::TransportStreamReceiver:: + kGrpcBinderTransportCancelledGracefully) { + gpr_log(GPR_ERROR, "message cancelled gracefully"); + // Cancelled because we've already received trailing metadata. + // It's not an error in this case. + return GRPC_ERROR_NONE; + } else { + return absl_status_to_grpc_error(args->message.status()); + } + } + grpc_slice_buffer buf; + grpc_slice_buffer_init(&buf); + grpc_slice_buffer_add(&buf, grpc_slice_from_cpp_string(*args->message)); + + gbs->sbs.Init(&buf, 0); + gbs->recv_message->reset(gbs->sbs.get()); + return GRPC_ERROR_NONE; + }(); + + if (error != GRPC_ERROR_NONE && + gbs->call_failed_before_recv_message != nullptr) { + *gbs->call_failed_before_recv_message = true; + } + grpc_closure* cb = gbs->recv_message_ready; + gbs->recv_message_ready = nullptr; + gbs->recv_message = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); + } + + GRPC_BINDER_STREAM_UNREF(gbs, "recv_message"); +} +static void recv_trailing_metadata_locked(void* arg, + grpc_error_handle /*error*/) { + gpr_log(GPR_INFO, "recv_trailing_metadata_locked"); + RecvTrailingMetadataArgs* args = static_cast(arg); + + grpc_binder_stream* gbs = args->gbs; + + if (!gbs->is_closed) { + grpc_error_handle error = [&] { + GPR_ASSERT(gbs->recv_trailing_metadata); + GPR_ASSERT(gbs->recv_trailing_metadata_finished); + if (!args->trailing_metadata.ok()) { + gpr_log(GPR_ERROR, "Failed to receive trailing metadata"); + return absl_status_to_grpc_error(args->trailing_metadata.status()); + } + if (!gbs->is_client) { + // Client will not send non-empty trailing metadata. + if (!args->trailing_metadata.value().empty()) { + gpr_log(GPR_ERROR, "Server receives non-empty trailing metadata."); + return GRPC_ERROR_CANCELLED; + } + } else { + AssignMetadata(gbs->recv_trailing_metadata, gbs->arena, + *args->trailing_metadata); + // Append status to metadata + // TODO(b/192208695): See if we can avoid to manually put status + // code into the header + gpr_log(GPR_INFO, "status = %d", args->status); + grpc_linked_mdelem* glm = static_cast( + gbs->arena->Alloc(sizeof(grpc_linked_mdelem))); + glm->md = grpc_get_reffed_status_elem(args->status); + GPR_ASSERT(grpc_metadata_batch_link_tail(gbs->recv_trailing_metadata, + glm) == GRPC_ERROR_NONE); + gpr_log(GPR_INFO, "trailing_metadata = %p", + gbs->recv_trailing_metadata); + gpr_log(GPR_INFO, "glm = %p", glm); + } + return GRPC_ERROR_NONE; + }(); + + grpc_closure* cb = gbs->recv_trailing_metadata_finished; + gbs->recv_trailing_metadata_finished = nullptr; + gbs->recv_trailing_metadata = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); + } + GRPC_BINDER_STREAM_UNREF(gbs, "recv_trailing_metadata"); +} + +static void perform_stream_op_locked(void* stream_op, + grpc_error_handle /*error*/) { + grpc_transport_stream_op_batch* op = + static_cast(stream_op); + grpc_binder_stream* gbs = + static_cast(op->handler_private.extra_arg); + grpc_binder_transport* gbt = gbs->t; if (op->cancel_stream) { // TODO(waynetu): Is this true? GPR_ASSERT(!op->send_initial_metadata && !op->send_message && !op->send_trailing_metadata && !op->recv_initial_metadata && !op->recv_message && !op->recv_trailing_metadata); gpr_log(GPR_INFO, "cancel_stream"); - gpr_log( - GPR_INFO, "cancel_stream error = %s", - grpc_error_std_string(op->payload->cancel_stream.cancel_error).c_str()); - gbs->cancellation_error = - grpc_error_to_absl_status(op->payload->cancel_stream.cancel_error); // Send trailing metadata to inform the other end about the cancellation, // regardless if we'd already done that or not. grpc_binder::Transaction cancel_tx(gbs->GetTxCode(), gbs->GetThenIncSeq(), gbt->is_client); cancel_tx.SetSuffix(grpc_binder::Metadata{}); absl::Status status = gbt->wire_writer->RpcCall(cancel_tx); - gbt->transport_stream_receiver->CancelStream(gbs->tx_code, - gbs->cancellation_error); - GRPC_ERROR_UNREF(op->payload->cancel_stream.cancel_error); + cancel_stream_locked(gbt, gbs, op->payload->cancel_stream.cancel_error); if (op->on_complete != nullptr) { grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, absl_status_to_grpc_error(status)); - gpr_log(GPR_INFO, "on_complete closure schuduled"); } + GRPC_BINDER_STREAM_UNREF(gbs, "perform_stream_op"); + return; + } + + if (gbs->is_closed) { + if (op->recv_initial_metadata) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + op->payload->recv_initial_metadata.recv_initial_metadata_ready, + GRPC_ERROR_REF(gbs->cancel_self_error)); + } + if (op->recv_message) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + op->payload->recv_message.recv_message_ready, + GRPC_ERROR_REF(gbs->cancel_self_error)); + } + if (op->recv_trailing_metadata) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready, + GRPC_ERROR_REF(gbs->cancel_self_error)); + } + if (op->on_complete != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, + GRPC_ERROR_REF(gbs->cancel_self_error)); + } + GRPC_BINDER_STREAM_UNREF(gbs, "perform_stream_op"); return; } std::unique_ptr tx; + int tx_code = gbs->tx_code; if (op->send_initial_metadata || op->send_message || op->send_trailing_metadata) { // Only increment sequence number when there's a send operation. tx = absl::make_unique( - /*tx_code=*/gbs->GetTxCode(), /*seq_num=*/gbs->GetThenIncSeq(), - gbt->is_client); + /*tx_code=*/tx_code, /*seq_num=*/gbs->GetThenIncSeq(), gbt->is_client); } - if (op->send_initial_metadata && gbs->cancellation_error.ok()) { + if (op->send_initial_metadata) { gpr_log(GPR_INFO, "send_initial_metadata"); grpc_binder::Metadata init_md; auto batch = op->payload->send_initial_metadata.send_initial_metadata; @@ -155,7 +384,7 @@ static void perform_stream_op(grpc_transport* gt, grpc_stream* gs, } tx->SetPrefix(init_md); } - if (op->send_message && gbs->cancellation_error.ok()) { + if (op->send_message) { gpr_log(GPR_INFO, "send_message"); size_t remaining = op->payload->send_message.send_message->length(); std::string message_data; @@ -181,7 +410,8 @@ static void perform_stream_op(grpc_transport* gt, grpc_stream* gs, // use-after-free issue in call.cc? op->payload->send_message.send_message.reset(); } - if (op->send_trailing_metadata && gbs->cancellation_error.ok()) { + + if (op->send_trailing_metadata) { gpr_log(GPR_INFO, "send_trailing_metadata"); auto batch = op->payload->send_trailing_metadata.send_trailing_metadata; grpc_binder::Metadata trailing_metadata; @@ -212,137 +442,79 @@ static void perform_stream_op(grpc_transport* gt, grpc_stream* gs, } if (op->recv_initial_metadata) { gpr_log(GPR_INFO, "recv_initial_metadata"); - if (!gbs->cancellation_error.ok()) { - grpc_core::ExecCtx::Run( - DEBUG_LOCATION, - op->payload->recv_initial_metadata.recv_initial_metadata_ready, - absl_status_to_grpc_error(gbs->cancellation_error)); - } else { - gbs->recv_initial_metadata_ready = - op->payload->recv_initial_metadata.recv_initial_metadata_ready; - gbs->recv_initial_metadata = - op->payload->recv_initial_metadata.recv_initial_metadata; - gbt->transport_stream_receiver->RegisterRecvInitialMetadata( - gbs->tx_code, - [gbs](absl::StatusOr initial_metadata) { - grpc_core::ExecCtx exec_ctx; - GPR_ASSERT(gbs->recv_initial_metadata); - GPR_ASSERT(gbs->recv_initial_metadata_ready); - if (!initial_metadata.ok()) { - gpr_log(GPR_ERROR, "Failed to parse initial metadata"); - grpc_core::ExecCtx::Run( - DEBUG_LOCATION, gbs->recv_initial_metadata_ready, - absl_status_to_grpc_error(initial_metadata.status())); - return; - } - AssignMetadata(gbs->recv_initial_metadata, gbs->arena, - *initial_metadata); - grpc_core::ExecCtx::Run(DEBUG_LOCATION, - gbs->recv_initial_metadata_ready, - GRPC_ERROR_NONE); - }); - } + gbs->recv_initial_metadata_ready = + op->payload->recv_initial_metadata.recv_initial_metadata_ready; + gbs->recv_initial_metadata = + op->payload->recv_initial_metadata.recv_initial_metadata; + gbs->trailing_metadata_available = + op->payload->recv_initial_metadata.trailing_metadata_available; + GRPC_BINDER_STREAM_REF(gbs, "recv_initial_metadata"); + gbt->transport_stream_receiver->RegisterRecvInitialMetadata( + tx_code, [tx_code, gbs, + gbt](absl::StatusOr initial_metadata) { + grpc_core::ExecCtx exec_ctx; + if (gbs->is_closed) { + GRPC_BINDER_STREAM_UNREF(gbs, "recv_initial_metadata"); + return; + } + gbs->recv_initial_metadata_args.tx_code = tx_code; + gbs->recv_initial_metadata_args.initial_metadata = + std::move(initial_metadata); + gbt->combiner->Run( + GRPC_CLOSURE_INIT(&gbs->recv_initial_metadata_closure, + recv_initial_metadata_locked, + &gbs->recv_initial_metadata_args, nullptr), + GRPC_ERROR_NONE); + }); } if (op->recv_message) { gpr_log(GPR_INFO, "recv_message"); - if (!gbs->cancellation_error.ok()) { - grpc_core::ExecCtx::Run( - DEBUG_LOCATION, op->payload->recv_message.recv_message_ready, - absl_status_to_grpc_error(gbs->cancellation_error)); - } else { - gbs->recv_message_ready = op->payload->recv_message.recv_message_ready; - gbs->recv_message = op->payload->recv_message.recv_message; - gbt->transport_stream_receiver->RegisterRecvMessage( - gbs->tx_code, [gbs](absl::StatusOr message) { - grpc_core::ExecCtx exec_ctx; - GPR_ASSERT(gbs->recv_message); - GPR_ASSERT(gbs->recv_message_ready); - if (!message.ok()) { - gpr_log(GPR_ERROR, "Failed to receive message"); - if (message.status().message() == - grpc_binder::TransportStreamReceiver:: - kGrpcBinderTransportCancelledGracefully) { - gpr_log(GPR_ERROR, "message cancelled gracefully"); - // Cancelled because we've already received trailing metadata. - // It's not an error in this case. - grpc_core::ExecCtx::Run(DEBUG_LOCATION, gbs->recv_message_ready, - GRPC_ERROR_NONE); - } else { - grpc_core::ExecCtx::Run( - DEBUG_LOCATION, gbs->recv_message_ready, - absl_status_to_grpc_error(message.status())); - } - return; - } - grpc_slice_buffer buf; - grpc_slice_buffer_init(&buf); - grpc_slice_buffer_add(&buf, grpc_slice_from_cpp_string(*message)); - - gbs->sbs.Init(&buf, 0); - gbs->recv_message->reset(gbs->sbs.get()); - grpc_core::ExecCtx::Run(DEBUG_LOCATION, gbs->recv_message_ready, - GRPC_ERROR_NONE); - }); - } + gbs->recv_message_ready = op->payload->recv_message.recv_message_ready; + gbs->recv_message = op->payload->recv_message.recv_message; + gbs->call_failed_before_recv_message = + op->payload->recv_message.call_failed_before_recv_message; + GRPC_BINDER_STREAM_REF(gbs, "recv_message"); + gbt->transport_stream_receiver->RegisterRecvMessage( + tx_code, [tx_code, gbs, gbt](absl::StatusOr message) { + grpc_core::ExecCtx exec_ctx; + if (gbs->is_closed) { + GRPC_BINDER_STREAM_UNREF(gbs, "recv_message"); + return; + } + gbs->recv_message_args.tx_code = tx_code; + gbs->recv_message_args.message = std::move(message); + gbt->combiner->Run( + GRPC_CLOSURE_INIT(&gbs->recv_message_closure, recv_message_locked, + &gbs->recv_message_args, nullptr), + GRPC_ERROR_NONE); + }); } if (op->recv_trailing_metadata) { gpr_log(GPR_INFO, "recv_trailing_metadata"); - if (!gbs->cancellation_error.ok()) { - grpc_core::ExecCtx::Run( - DEBUG_LOCATION, - op->payload->recv_trailing_metadata.recv_trailing_metadata_ready, - absl_status_to_grpc_error(gbs->cancellation_error)); - } else { - gbs->recv_trailing_metadata_finished = - op->payload->recv_trailing_metadata.recv_trailing_metadata_ready; - gbs->recv_trailing_metadata = - op->payload->recv_trailing_metadata.recv_trailing_metadata; - gbt->transport_stream_receiver->RegisterRecvTrailingMetadata( - gbs->tx_code, - [gbs](absl::StatusOr trailing_metadata, - int status) { - grpc_core::ExecCtx exec_ctx; - GPR_ASSERT(gbs->recv_trailing_metadata); - GPR_ASSERT(gbs->recv_trailing_metadata_finished); - if (!trailing_metadata.ok()) { - gpr_log(GPR_ERROR, "Failed to receive trailing metadata"); - grpc_core::ExecCtx::Run( - DEBUG_LOCATION, gbs->recv_trailing_metadata_finished, - absl_status_to_grpc_error(trailing_metadata.status())); - return; - } - if (!gbs->is_client) { - // Client will not send non-empty trailing metadata. - if (!trailing_metadata.value().empty()) { - gpr_log(GPR_ERROR, - "Server receives non-empty trailing metadata."); - grpc_core::ExecCtx::Run(DEBUG_LOCATION, - gbs->recv_trailing_metadata_finished, - GRPC_ERROR_CANCELLED); - return; - } - } else { - AssignMetadata(gbs->recv_trailing_metadata, gbs->arena, - *trailing_metadata); - // Append status to metadata - // TODO(b/192208695): See if we can avoid to manually put status - // code into the header - gpr_log(GPR_INFO, "status = %d", status); - grpc_linked_mdelem* glm = static_cast( - gbs->arena->Alloc(sizeof(grpc_linked_mdelem))); - glm->md = grpc_get_reffed_status_elem(status); - GPR_ASSERT(grpc_metadata_batch_link_tail( - gbs->recv_trailing_metadata, glm) == - GRPC_ERROR_NONE); - gpr_log(GPR_INFO, "trailing_metadata = %p", - gbs->recv_trailing_metadata); - gpr_log(GPR_INFO, "glm = %p", glm); - } - grpc_core::ExecCtx::Run(DEBUG_LOCATION, - gbs->recv_trailing_metadata_finished, - GRPC_ERROR_NONE); - }); - } + gbs->recv_trailing_metadata_finished = + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + gbs->recv_trailing_metadata = + op->payload->recv_trailing_metadata.recv_trailing_metadata; + GRPC_BINDER_STREAM_REF(gbs, "recv_trailing_metadata"); + gbt->transport_stream_receiver->RegisterRecvTrailingMetadata( + tx_code, [tx_code, gbs, gbt]( + absl::StatusOr trailing_metadata, + int status) { + grpc_core::ExecCtx exec_ctx; + if (gbs->is_closed) { + GRPC_BINDER_STREAM_UNREF(gbs, "recv_trailing_metadata"); + return; + } + gbs->recv_trailing_metadata_args.tx_code = tx_code; + gbs->recv_trailing_metadata_args.trailing_metadata = + std::move(trailing_metadata); + gbs->recv_trailing_metadata_args.status = status; + gbt->combiner->Run( + GRPC_CLOSURE_INIT(&gbs->recv_trailing_metadata_closure, + recv_trailing_metadata_locked, + &gbs->recv_trailing_metadata_args, nullptr), + GRPC_ERROR_NONE); + }); } // Only send transaction when there's a send op presented. absl::Status status = absl::OkStatus(); @@ -356,12 +528,39 @@ static void perform_stream_op(grpc_transport* gt, grpc_stream* gs, absl_status_to_grpc_error(status)); gpr_log(GPR_INFO, "on_complete closure schuduled"); } + GRPC_BINDER_STREAM_UNREF(gbs, "perform_stream_op"); } -static void perform_transport_op(grpc_transport* gt, grpc_transport_op* op) { - gpr_log(GPR_INFO, __func__); +static void perform_stream_op(grpc_transport* gt, grpc_stream* gs, + grpc_transport_stream_op_batch* op) { + GPR_TIMER_SCOPE("perform_stream_op", 0); + gpr_log(GPR_INFO, "%s = %p %p %p", __func__, gt, gs, op); grpc_binder_transport* gbt = reinterpret_cast(gt); - grpc_core::MutexLock lock(&gbt->mu); + grpc_binder_stream* gbs = reinterpret_cast(gs); + GRPC_BINDER_STREAM_REF(gbs, "perform_stream_op"); + op->handler_private.extra_arg = gbs; + gbt->combiner->Run(GRPC_CLOSURE_INIT(&op->handler_private.closure, + perform_stream_op_locked, op, nullptr), + GRPC_ERROR_NONE); +} + +static void close_transport_locked(grpc_binder_transport* gbt) { + gbt->state_tracker.SetState(GRPC_CHANNEL_SHUTDOWN, absl::OkStatus(), + "transport closed due to disconnection/goaway"); + while (!gbt->registered_stream.empty()) { + cancel_stream_locked( + gbt, gbt->registered_stream.begin()->second, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("transport closed"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + } +} + +static void perform_transport_op_locked(void* transport_op, + grpc_error_handle /*error*/) { + grpc_transport_op* op = static_cast(transport_op); + grpc_binder_transport* gbt = + static_cast(op->handler_private.extra_arg); // TODO(waynetu): Should we lock here to avoid data race? if (op->start_connectivity_watch != nullptr) { gbt->state_tracker.AddWatcher(op->start_connectivity_watch_state, @@ -387,33 +586,58 @@ static void perform_transport_op(grpc_transport* gt, grpc_transport_op* op) { GRPC_ERROR_UNREF(op->goaway_error); } if (do_close) { - gbt->state_tracker.SetState(GRPC_CHANNEL_SHUTDOWN, absl::OkStatus(), - "transport closed due to disconnection/goaway"); + close_transport_locked(gbt); } + GRPC_BINDER_UNREF_TRANSPORT(gbt, "perform_transport_op"); } -static void destroy_stream(grpc_transport* gt, grpc_stream* gs, - grpc_closure* then_schedule_closure) { +static void perform_transport_op(grpc_transport* gt, grpc_transport_op* op) { gpr_log(GPR_INFO, __func__); grpc_binder_transport* gbt = reinterpret_cast(gt); - grpc_binder_stream* gbs = reinterpret_cast(gs); - gbt->transport_stream_receiver->Clear(gbs->tx_code); - // TODO(waynetu): Currently, there's nothing to be cleaned up. If additional - // fields are added to grpc_binder_stream in the future, we might need to use - // reference-counting to determine who does the actual cleaning. + op->handler_private.extra_arg = gbt; + GRPC_BINDER_REF_TRANSPORT(gbt, "perform_transport_op"); + gbt->combiner->Run( + GRPC_CLOSURE_INIT(&op->handler_private.closure, + perform_transport_op_locked, op, nullptr), + GRPC_ERROR_NONE); +} + +static void destroy_stream_locked(void* sp, grpc_error_handle /*error*/) { + grpc_binder_stream* gbs = static_cast(sp); + grpc_binder_transport* gbt = gbs->t; + cancel_stream_locked( + gbt, gbs, + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING("destroy stream"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); gbs->~grpc_binder_stream(); - grpc_core::ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, - GRPC_ERROR_NONE); } -static void destroy_transport(grpc_transport* gt) { +static void destroy_stream(grpc_transport* /*gt*/, grpc_stream* gs, + grpc_closure* then_schedule_closure) { gpr_log(GPR_INFO, __func__); - grpc_binder_transport* gbt = reinterpret_cast(gt); + grpc_binder_stream* gbs = reinterpret_cast(gs); + gbs->destroy_stream_then_closure = then_schedule_closure; + gbs->t->combiner->Run(GRPC_CLOSURE_INIT(&gbs->destroy_stream, + destroy_stream_locked, gbs, nullptr), + GRPC_ERROR_NONE); +} + +static void destroy_transport_locked(void* gt, grpc_error_handle /*error*/) { + grpc_binder_transport* gbt = static_cast(gt); + close_transport_locked(gbt); // Release the references held by the transport. gbt->wire_reader = nullptr; gbt->transport_stream_receiver = nullptr; gbt->wire_writer = nullptr; - gbt->Unref(); + GRPC_BINDER_UNREF_TRANSPORT(gbt, "transport destroyed"); +} + +static void destroy_transport(grpc_transport* gt) { + gpr_log(GPR_INFO, __func__); + grpc_binder_transport* gbt = reinterpret_cast(gt); + gbt->combiner->Run( + GRPC_CLOSURE_CREATE(destroy_transport_locked, gbt, nullptr), + GRPC_ERROR_NONE); } static grpc_endpoint* get_endpoint(grpc_transport*) { @@ -435,32 +659,46 @@ static const grpc_transport_vtable vtable = {sizeof(grpc_binder_stream), static const grpc_transport_vtable* get_vtable() { return &vtable; } +static void accept_stream_locked(void* gt, grpc_error_handle /*error*/) { + grpc_binder_transport* gbt = static_cast(gt); + if (gbt->accept_stream_fn) { + // must pass in a non-null value. + (*gbt->accept_stream_fn)(gbt->accept_stream_user_data, &gbt->base, gbt); + } +} + grpc_binder_transport::grpc_binder_transport( std::unique_ptr binder, bool is_client) : is_client(is_client), + combiner(grpc_combiner_create()), state_tracker(is_client ? "binder_transport_client" : "binder_transport_server"), refs(1, nullptr) { gpr_log(GPR_INFO, __func__); base.vtable = get_vtable(); + GRPC_CLOSURE_INIT(&accept_stream_closure, accept_stream_locked, this, + nullptr); transport_stream_receiver = std::make_shared( is_client, /*accept_stream_callback=*/[this] { grpc_core::ExecCtx exec_ctx; - grpc_core::MutexLock lock(&mu); - if (accept_stream_fn) { - // must pass in a non-null value. - (*accept_stream_fn)(accept_stream_user_data, &base, this); - } + combiner->Run(&accept_stream_closure, GRPC_ERROR_NONE); }); // WireReader holds a ref to grpc_binder_transport. - Ref(); + GRPC_BINDER_REF_TRANSPORT(this, "wire reader"); wire_reader = grpc_core::MakeOrphanable( transport_stream_receiver, is_client, - /*on_destruct_callback=*/[this] { Unref(); }); + /*on_destruct_callback=*/[this] { + // Unref transport when destructed. + GRPC_BINDER_UNREF_TRANSPORT(this, "wire reader"); + }); wire_writer = wire_reader->SetupTransport(std::move(binder)); } +grpc_binder_transport::~grpc_binder_transport() { + GRPC_COMBINER_UNREF(combiner, "binder_transport"); +} + grpc_transport* grpc_create_binder_transport_client( std::unique_ptr endpoint_binder) { gpr_log(GPR_INFO, __func__); diff --git a/src/core/ext/transport/binder/transport/binder_transport.h b/src/core/ext/transport/binder/transport/binder_transport.h index fd842d0ff72..d7e141b71f5 100644 --- a/src/core/ext/transport/binder/transport/binder_transport.h +++ b/src/core/ext/transport/binder/transport/binder_transport.h @@ -22,15 +22,20 @@ #include #include +#include "absl/container/flat_hash_map.h" + #include #include "src/core/ext/transport/binder/utils/transport_stream_receiver.h" #include "src/core/ext/transport/binder/wire_format/binder.h" #include "src/core/ext/transport/binder/wire_format/wire_reader.h" #include "src/core/ext/transport/binder/wire_format/wire_writer.h" +#include "src/core/lib/iomgr/combiner.h" #include "src/core/lib/transport/transport.h" #include "src/core/lib/transport/transport_impl.h" +struct grpc_binder_stream; + // TODO(mingcl): Consider putting the struct in a namespace (Eventually this // depends on what style we want to follow) // TODO(mingcl): Decide casing for this class name. Should we use C-style class @@ -38,6 +43,7 @@ struct grpc_binder_transport { explicit grpc_binder_transport(std::unique_ptr binder, bool is_client); + ~grpc_binder_transport(); int NewStreamTxCode() { // TODO(mingcl): Wrap around when all tx codes are used. "If we do detect a @@ -47,14 +53,6 @@ struct grpc_binder_transport { return next_free_tx_code++; } - void Ref() { refs.Ref(); } - - void Unref() { - if (refs.Unref()) { - delete this; - } - } - grpc_transport base; /* must be first */ std::shared_ptr @@ -63,7 +61,11 @@ struct grpc_binder_transport { std::shared_ptr wire_writer; bool is_client; - grpc_core::Mutex mu; + // A set of currently registered streams (the key is the stream ID). + absl::flat_hash_map registered_stream; + grpc_core::Combiner* combiner; + + grpc_closure accept_stream_closure; // The callback and the data for the callback when the stream is connected // between client and server. @@ -72,10 +74,10 @@ struct grpc_binder_transport { void* accept_stream_user_data = nullptr; grpc_core::ConnectivityStateTracker state_tracker; + grpc_core::RefCount refs; private: int next_free_tx_code = grpc_binder::kFirstCallId; - grpc_core::RefCount refs; }; grpc_transport* grpc_create_binder_transport_client( diff --git a/src/core/ext/transport/binder/utils/transport_stream_receiver.h b/src/core/ext/transport/binder/utils/transport_stream_receiver.h index 3e209bf36d5..fa1d4774ff9 100644 --- a/src/core/ext/transport/binder/utils/transport_stream_receiver.h +++ b/src/core/ext/transport/binder/utils/transport_stream_receiver.h @@ -70,8 +70,7 @@ class TransportStreamReceiver { virtual void CancelRecvMessageCallbacksDueToTrailingMetadata( StreamIdentifier id) = 0; // Remove all entries associated with stream number `id`. - virtual void Clear(StreamIdentifier id) = 0; - virtual void CancelStream(StreamIdentifier id, absl::Status error) = 0; + virtual void CancelStream(StreamIdentifier id) = 0; static const absl::string_view kGrpcBinderTransportCancelledGracefully; }; diff --git a/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc b/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc index 14e10112549..36241423e6f 100644 --- a/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc +++ b/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc @@ -194,46 +194,30 @@ void TransportStreamReceiverImpl:: } } -void TransportStreamReceiverImpl::CancelStream(StreamIdentifier id, - absl::Status error) { - InitialMetadataCallbackType initial_metadata_callback = nullptr; - MessageDataCallbackType message_data_callback = nullptr; - TrailingMetadataCallbackType trailing_metadata_callback = nullptr; +void TransportStreamReceiverImpl::CancelStream(StreamIdentifier id) { + gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); + grpc_core::MutexLock l(&m_); { - grpc_core::MutexLock l(&m_); - auto initial_metadata_iter = initial_metadata_cbs_.find(id); - if (initial_metadata_iter != initial_metadata_cbs_.end()) { - initial_metadata_callback = std::move(initial_metadata_iter->second); - initial_metadata_cbs_.erase(initial_metadata_iter); - } - auto message_data_iter = message_cbs_.find(id); - if (message_data_iter != message_cbs_.end()) { - message_data_callback = std::move(message_data_iter->second); - message_cbs_.erase(message_data_iter); - } - auto trailing_metadata_iter = trailing_metadata_cbs_.find(id); - if (trailing_metadata_iter != trailing_metadata_cbs_.end()) { - trailing_metadata_callback = std::move(trailing_metadata_iter->second); - trailing_metadata_cbs_.erase(trailing_metadata_iter); + auto iter = initial_metadata_cbs_.find(id); + if (iter != initial_metadata_cbs_.end()) { + iter->second(absl::CancelledError("Stream cancelled")); + initial_metadata_cbs_.erase(iter); } } - if (initial_metadata_callback != nullptr) { - initial_metadata_callback(error); - } - if (message_data_callback != nullptr) { - message_data_callback(error); + { + auto iter = message_cbs_.find(id); + if (iter != message_cbs_.end()) { + iter->second(absl::CancelledError("Stream cancelled")); + message_cbs_.erase(iter); + } } - if (trailing_metadata_callback != nullptr) { - trailing_metadata_callback(error, 0); + { + auto iter = trailing_metadata_cbs_.find(id); + if (iter != trailing_metadata_cbs_.end()) { + iter->second(absl::CancelledError("Stream cancelled"), 0); + trailing_metadata_cbs_.erase(iter); + } } -} - -void TransportStreamReceiverImpl::Clear(StreamIdentifier id) { - gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); - grpc_core::MutexLock l(&m_); - initial_metadata_cbs_.erase(id); - message_cbs_.erase(id); - trailing_metadata_cbs_.erase(id); recv_message_cancelled_.erase(id); pending_initial_metadata_.erase(id); pending_message_.erase(id); diff --git a/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h b/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h index 040b4d78388..154a0db04d3 100644 --- a/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h +++ b/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h @@ -52,8 +52,7 @@ class TransportStreamReceiverImpl : public TransportStreamReceiver { void CancelRecvMessageCallbacksDueToTrailingMetadata( StreamIdentifier id) override; - void Clear(StreamIdentifier id) override; - void CancelStream(StreamIdentifier, absl::Status error) override; + void CancelStream(StreamIdentifier id) override; private: std::map initial_metadata_cbs_; diff --git a/src/core/ext/transport/binder/wire_format/wire_reader_impl.cc b/src/core/ext/transport/binder/wire_format/wire_reader_impl.cc index 9a41d032e04..e39e0707563 100644 --- a/src/core/ext/transport/binder/wire_format/wire_reader_impl.cc +++ b/src/core/ext/transport/binder/wire_format/wire_reader_impl.cc @@ -60,7 +60,7 @@ absl::StatusOr parse_metadata(const ReadableParcel* reader) { std::string value{}; if (count > 0) RETURN_IF_ERROR(reader->ReadByteArray(&value)); gpr_log(GPR_INFO, "value = %s", value.c_str()); - ret.push_back({key, value}); + ret.emplace_back(key, value); } return ret; } diff --git a/test/core/transport/binder/binder_transport_test.cc b/test/core/transport/binder/binder_transport_test.cc index 4130ee8cfb2..748725cd243 100644 --- a/test/core/transport/binder/binder_transport_test.cc +++ b/test/core/transport/binder/binder_transport_test.cc @@ -26,6 +26,7 @@ #include "absl/memory/memory.h" #include "absl/strings/match.h" +#include "absl/synchronization/notification.h" #include @@ -52,9 +53,11 @@ class BinderTransportTest : public ::testing::Test { } ~BinderTransportTest() override { - auto* gbt = reinterpret_cast(transport_); - delete gbt; + grpc_core::ExecCtx exec_ctx; + grpc_transport_destroy(transport_); + grpc_core::ExecCtx::Get()->Flush(); for (grpc_binder_stream* gbs : stream_buffer_) { + gbs->~grpc_binder_stream(); gpr_free(gbs); } arena_->Destroy(); @@ -98,13 +101,16 @@ void MockCallback(void* arg, grpc_error_handle error); class MockGrpcClosure { public: - MockGrpcClosure() { + explicit MockGrpcClosure(absl::Notification* notification = nullptr) + : notification_(notification) { GRPC_CLOSURE_INIT(&closure_, MockCallback, this, nullptr); } grpc_closure* GetGrpcClosure() { return &closure_; } MOCK_METHOD(void, Callback, (grpc_error_handle), ()); + absl::Notification* notification_; + private: grpc_closure closure_; }; @@ -112,6 +118,9 @@ class MockGrpcClosure { void MockCallback(void* arg, grpc_error_handle error) { MockGrpcClosure* mock_closure = static_cast(arg); mock_closure->Callback(error); + if (mock_closure->notification_) { + mock_closure->notification_->Notify(); + } } // Matches with transactions having the desired flag, method_ref, @@ -221,7 +230,8 @@ struct MakeSendTrailingMetadata { struct MakeRecvInitialMetadata { explicit MakeRecvInitialMetadata(grpc_transport_stream_op_batch* op, - Expectation* call_before = nullptr) { + Expectation* call_before = nullptr) + : ready(¬ification) { grpc_metadata_batch_init(&grpc_initial_metadata); op->recv_initial_metadata = true; op->payload->recv_initial_metadata.recv_initial_metadata = @@ -241,11 +251,13 @@ struct MakeRecvInitialMetadata { MockGrpcClosure ready; grpc_metadata_batch grpc_initial_metadata; + absl::Notification notification; }; struct MakeRecvMessage { explicit MakeRecvMessage(grpc_transport_stream_op_batch* op, - Expectation* call_before = nullptr) { + Expectation* call_before = nullptr) + : ready(¬ification) { op->recv_message = true; op->payload->recv_message.recv_message = &grpc_message; op->payload->recv_message.recv_message_ready = ready.GetGrpcClosure(); @@ -257,12 +269,14 @@ struct MakeRecvMessage { } MockGrpcClosure ready; + absl::Notification notification; grpc_core::OrphanablePtr grpc_message; }; struct MakeRecvTrailingMetadata { explicit MakeRecvTrailingMetadata(grpc_transport_stream_op_batch* op, - Expectation* call_before = nullptr) { + Expectation* call_before = nullptr) + : ready(¬ification) { grpc_metadata_batch_init(&grpc_trailing_metadata); op->recv_trailing_metadata = true; op->payload->recv_trailing_metadata.recv_trailing_metadata = @@ -282,6 +296,7 @@ struct MakeRecvTrailingMetadata { MockGrpcClosure ready; grpc_metadata_batch grpc_trailing_metadata; + absl::Notification notification; }; const Metadata kDefaultMetadata = { @@ -329,6 +344,7 @@ TEST_F(BinderTransportTest, TransactionIdIncrement) { } TEST_F(BinderTransportTest, SeqNumIncrement) { + grpc_core::ExecCtx exec_ctx; grpc_binder_stream* gbs = InitNewBinderStream(); EXPECT_EQ(gbs->t, GetBinderTransport()); EXPECT_EQ(gbs->tx_code, kFirstCallId); @@ -339,14 +355,17 @@ TEST_F(BinderTransportTest, SeqNumIncrement) { MakeSendInitialMetadata send_initial_metadata(kDefaultMetadata, "", &op); EXPECT_EQ(gbs->seq, 0); PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); EXPECT_EQ(gbs->tx_code, kFirstCallId); EXPECT_EQ(gbs->seq, 1); PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); EXPECT_EQ(gbs->tx_code, kFirstCallId); EXPECT_EQ(gbs->seq, 2); } TEST_F(BinderTransportTest, SeqNumNotIncrementWithoutSend) { + grpc_core::ExecCtx exec_ctx; { grpc_binder_stream* gbs = InitNewBinderStream(); EXPECT_EQ(gbs->t, GetBinderTransport()); @@ -355,11 +374,11 @@ TEST_F(BinderTransportTest, SeqNumNotIncrementWithoutSend) { grpc_transport_stream_op_batch op{}; EXPECT_EQ(gbs->seq, 0); PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); EXPECT_EQ(gbs->tx_code, kFirstCallId); EXPECT_EQ(gbs->seq, 0); } { - grpc_core::ExecCtx exec_ctx; grpc_binder_stream* gbs = InitNewBinderStream(); EXPECT_EQ(gbs->t, GetBinderTransport()); EXPECT_EQ(gbs->tx_code, kFirstCallId + 1); @@ -378,7 +397,8 @@ TEST_F(BinderTransportTest, SeqNumNotIncrementWithoutSend) { gbt->transport_stream_receiver->NotifyRecvInitialMetadata(gbs->tx_code, kDefaultMetadata); PerformStreamOp(gbs, &op); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); + recv_initial_metadata.notification.WaitForNotification(); } } @@ -399,7 +419,7 @@ TEST_F(BinderTransportTest, PerformSendInitialMetadata) { EXPECT_CALL(mock_on_complete, Callback); PerformStreamOp(gbs, &op); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); } TEST_F(BinderTransportTest, PerformSendInitialMetadataMethodRef) { @@ -422,7 +442,7 @@ TEST_F(BinderTransportTest, PerformSendInitialMetadataMethodRef) { EXPECT_CALL(mock_on_complete, Callback); PerformStreamOp(gbs, &op); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); } TEST_F(BinderTransportTest, PerformSendMessage) { @@ -444,7 +464,7 @@ TEST_F(BinderTransportTest, PerformSendMessage) { EXPECT_CALL(mock_on_complete, Callback); PerformStreamOp(gbs, &op); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); } TEST_F(BinderTransportTest, PerformSendTrailingMetadata) { @@ -467,7 +487,7 @@ TEST_F(BinderTransportTest, PerformSendTrailingMetadata) { EXPECT_CALL(mock_on_complete, Callback); PerformStreamOp(gbs, &op); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); } TEST_F(BinderTransportTest, PerformSendAll) { @@ -502,7 +522,7 @@ TEST_F(BinderTransportTest, PerformSendAll) { EXPECT_CALL(mock_on_complete, Callback); PerformStreamOp(gbs, &op); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); } TEST_F(BinderTransportTest, PerformRecvInitialMetadata) { @@ -519,7 +539,8 @@ TEST_F(BinderTransportTest, PerformRecvInitialMetadata) { gbt->transport_stream_receiver->NotifyRecvInitialMetadata(gbs->tx_code, kInitialMetadata); PerformStreamOp(gbs, &op); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); + recv_initial_metadata.notification.WaitForNotification(); VerifyMetadataEqual(kInitialMetadata, recv_initial_metadata.grpc_initial_metadata); @@ -540,7 +561,8 @@ TEST_F(BinderTransportTest, PerformRecvInitialMetadataWithMethodRef) { gbt->transport_stream_receiver->NotifyRecvInitialMetadata( gbs->tx_code, kInitialMetadataWithMethodRef); PerformStreamOp(gbs, &op); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); + recv_initial_metadata.notification.WaitForNotification(); VerifyMetadataEqual(kInitialMetadataWithMethodRef, recv_initial_metadata.grpc_initial_metadata); @@ -560,7 +582,9 @@ TEST_F(BinderTransportTest, PerformRecvMessage) { gbt->transport_stream_receiver->NotifyRecvMessage(gbs->tx_code, kMessage); PerformStreamOp(gbs, &op); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); + recv_message.notification.WaitForNotification(); + EXPECT_TRUE(recv_message.grpc_message->Next(SIZE_MAX, nullptr)); grpc_slice slice; recv_message.grpc_message->Pull(&slice); @@ -585,7 +609,9 @@ TEST_F(BinderTransportTest, PerformRecvTrailingMetadata) { gbt->transport_stream_receiver->NotifyRecvTrailingMetadata( gbs->tx_code, kTrailingMetadata, kStatus); PerformStreamOp(gbs, &op); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); + recv_trailing_metadata.notification.WaitForNotification(); + VerifyMetadataEqual(AppendStatus(kTrailingMetadata, kStatus), recv_trailing_metadata.grpc_trailing_metadata); } @@ -615,7 +641,8 @@ TEST_F(BinderTransportTest, PerformRecvAll) { gbt->transport_stream_receiver->NotifyRecvTrailingMetadata( gbs->tx_code, trailing_metadata, kStatus); PerformStreamOp(gbs, &op); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); + recv_trailing_metadata.notification.WaitForNotification(); VerifyMetadataEqual(kInitialMetadataWithMethodRef, recv_initial_metadata.grpc_initial_metadata); @@ -675,7 +702,7 @@ TEST_F(BinderTransportTest, PerformAllOps) { // Flush the execution context to force on_complete to run before recv // callbacks get scheduled. - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); auto* gbt = reinterpret_cast(transport_); const Metadata kRecvInitialMetadata = @@ -689,7 +716,11 @@ TEST_F(BinderTransportTest, PerformAllOps) { gbt->transport_stream_receiver->NotifyRecvTrailingMetadata( gbs->tx_code, kRecvTrailingMetadata, kStatus); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); + recv_initial_metadata.notification.WaitForNotification(); + recv_message.notification.WaitForNotification(); + recv_trailing_metadata.notification.WaitForNotification(); + VerifyMetadataEqual(kRecvInitialMetadata, recv_initial_metadata.grpc_initial_metadata); VerifyMetadataEqual(AppendStatus(kRecvTrailingMetadata, kStatus), @@ -733,7 +764,7 @@ TEST_F(BinderTransportTest, WireWriterRpcCallErrorPropagates) { PerformStreamOp(gbs, &op1); PerformStreamOp(gbs, &op2); - exec_ctx.Flush(); + grpc_core::ExecCtx::Get()->Flush(); } } // namespace grpc_binder diff --git a/test/core/transport/binder/end2end/end2end_binder_transport_test.cc b/test/core/transport/binder/end2end/end2end_binder_transport_test.cc index d8bfc2141f5..5a21ec6355d 100644 --- a/test/core/transport/binder/end2end/end2end_binder_transport_test.cc +++ b/test/core/transport/binder/end2end/end2end_binder_transport_test.cc @@ -63,6 +63,7 @@ using end2end_testing::EchoService; } // namespace TEST_P(End2EndBinderTransportTest, SetupTransport) { + grpc_core::ExecCtx exec_ctx; grpc_transport *client_transport, *server_transport; std::tie(client_transport, server_transport) = end2end_testing::CreateClientServerBindersPairForTesting(); diff --git a/test/core/transport/binder/mock_objects.h b/test/core/transport/binder/mock_objects.h index 80a94e778bc..5a584c8eaa2 100644 --- a/test/core/transport/binder/mock_objects.h +++ b/test/core/transport/binder/mock_objects.h @@ -107,8 +107,7 @@ class MockTransportStreamReceiver : public TransportStreamReceiver { (StreamIdentifier, absl::StatusOr, int), (override)); MOCK_METHOD(void, CancelRecvMessageCallbacksDueToTrailingMetadata, (StreamIdentifier), (override)); - MOCK_METHOD(void, Clear, (StreamIdentifier), (override)); - MOCK_METHOD(void, CancelStream, (StreamIdentifier, absl::Status), (override)); + MOCK_METHOD(void, CancelStream, (StreamIdentifier), (override)); }; } // namespace grpc_binder