From 52e5b64c5b74aae91baa9dce114e50f43857caf7 Mon Sep 17 00:00:00 2001 From: Ta-Wei Tu Date: Sat, 11 Sep 2021 17:39:17 +0800 Subject: [PATCH] [binder] Handle outbound flow control (#27243) --- .../binder/transport/binder_stream.h | 3 - .../binder/transport/binder_transport.cc | 32 ++-- .../ext/transport/binder/wire_format/BUILD | 1 + .../binder/wire_format/transaction.h | 6 +- .../binder/wire_format/wire_reader_impl.cc | 12 +- .../binder/wire_format/wire_writer.cc | 167 ++++++++++++++---- .../binder/wire_format/wire_writer.h | 39 +++- .../transport/binder/binder_transport_test.cc | 62 ------- .../end2end/end2end_binder_transport_test.cc | 19 ++ test/core/transport/binder/mock_objects.h | 3 +- .../core/transport/binder/wire_writer_test.cc | 95 ++++++++-- 11 files changed, 288 insertions(+), 151 deletions(-) diff --git a/src/core/ext/transport/binder/transport/binder_stream.h b/src/core/ext/transport/binder/transport/binder_stream.h index 176d1f59b32..3bdd5d76148 100644 --- a/src/core/ext/transport/binder/transport/binder_stream.h +++ b/src/core/ext/transport/binder/transport/binder_stream.h @@ -51,7 +51,6 @@ struct grpc_binder_stream { : t(t), refcount(refcount), arena(arena), - seq(0), tx_code(tx_code), is_client(is_client) { // TODO(waynetu): Should this be protected? @@ -74,13 +73,11 @@ struct grpc_binder_stream { } int GetTxCode() const { return tx_code; } - int GetThenIncSeq() { return seq++; } grpc_binder_transport* t; grpc_stream_refcount* refcount; grpc_core::Arena* arena; grpc_core::ManualConstructor sbs; - int seq; int tx_code; bool is_client; bool is_closed = false; diff --git a/src/core/ext/transport/binder/transport/binder_transport.cc b/src/core/ext/transport/binder/transport/binder_transport.cc index 3327f42e160..dbde0377ba8 100644 --- a/src/core/ext/transport/binder/transport/binder_transport.cc +++ b/src/core/ext/transport/binder/transport/binder_transport.cc @@ -324,8 +324,7 @@ static void perform_stream_op_locked(void* stream_op, if (!gbs->is_client) { // Send trailing metadata to inform the other end about the cancellation, // regardless if we'd already done that or not. - grpc_binder::Transaction cancel_tx(gbs->GetTxCode(), gbs->GetThenIncSeq(), - gbt->is_client); + grpc_binder::Transaction cancel_tx(gbs->GetTxCode(), gbt->is_client); cancel_tx.SetSuffix(grpc_binder::Metadata{}); cancel_tx.SetStatus(1); absl::Status status = gbt->wire_writer->RpcCall(cancel_tx); @@ -368,20 +367,13 @@ static void perform_stream_op_locked(void* stream_op, return; } - std::unique_ptr tx; int tx_code = gbs->tx_code; + grpc_binder::Transaction tx(tx_code, gbt->is_client); - if (op->send_initial_metadata || op->send_message || - op->send_trailing_metadata) { - // Only increment sequence number when there's a send operation. - tx = absl::make_unique( - /*tx_code=*/tx_code, /*seq_num=*/gbs->GetThenIncSeq(), gbt->is_client); - } if (op->send_initial_metadata) { gpr_log(GPR_INFO, "send_initial_metadata"); grpc_binder::Metadata init_md; auto batch = op->payload->send_initial_metadata.send_initial_metadata; - GPR_ASSERT(tx); for (grpc_linked_mdelem* md = batch->list.head; md != nullptr; md = md->next) { @@ -399,12 +391,12 @@ static void perform_stream_op_locked(void* stream_op, // Only client send method ref. GPR_ASSERT(gbt->is_client); - tx->SetMethodRef(path); + tx.SetMethodRef(path); } else { init_md.emplace_back(std::string(key), std::string(value)); } } - tx->SetPrefix(init_md); + tx.SetPrefix(init_md); } if (op->send_message) { gpr_log(GPR_INFO, "send_message"); @@ -426,8 +418,7 @@ static void perform_stream_op_locked(void* stream_op, grpc_slice_unref_internal(message_slice); } gpr_log(GPR_INFO, "message_data = %s", message_data.c_str()); - GPR_ASSERT(tx); - tx->SetData(message_data); + tx.SetData(message_data); // TODO(b/192369787): Are we supposed to reset here to avoid // use-after-free issue in call.cc? op->payload->send_message.send_message.reset(); @@ -437,7 +428,6 @@ static void perform_stream_op_locked(void* stream_op, gpr_log(GPR_INFO, "send_trailing_metadata"); auto batch = op->payload->send_trailing_metadata.send_trailing_metadata; grpc_binder::Metadata trailing_metadata; - GPR_ASSERT(tx); for (grpc_linked_mdelem* md = batch->list.head; md != nullptr; md = md->next) { @@ -447,7 +437,7 @@ static void perform_stream_op_locked(void* stream_op, if (grpc_slice_eq(GRPC_MDKEY(md->md), GRPC_MDSTR_GRPC_STATUS)) { int status = grpc_get_status_code_from_metadata(md->md); gpr_log(GPR_INFO, "send trailing metadata status = %d", status); - tx->SetStatus(status); + tx.SetStatus(status); } else { absl::string_view key = grpc_core::StringViewFromSlice(GRPC_MDKEY(md->md)); @@ -460,7 +450,7 @@ static void perform_stream_op_locked(void* stream_op, } // TODO(mingcl): Will we ever has key-value pair here? According to // wireformat client suffix data is always empty. - tx->SetSuffix(trailing_metadata); + tx.SetSuffix(trailing_metadata); } if (op->recv_initial_metadata) { gpr_log(GPR_INFO, "recv_initial_metadata"); @@ -540,8 +530,12 @@ static void perform_stream_op_locked(void* stream_op, } // Only send transaction when there's a send op presented. absl::Status status = absl::OkStatus(); - if (tx) { - status = gbt->wire_writer->RpcCall(*tx); + if (op->send_initial_metadata || op->send_message || + op->send_trailing_metadata) { + // TODO(waynetu): RpcCall() is doing a lot of work (including waiting for + // acknowledgements from the other side). Consider delaying this operation + // with combiner. + status = gbt->wire_writer->RpcCall(tx); if (!gbs->is_client && op->send_trailing_metadata) { gbs->trailing_metadata_sent = true; // According to transport explaineer - "Server extra: This op shouldn't diff --git a/src/core/ext/transport/binder/wire_format/BUILD b/src/core/ext/transport/binder/wire_format/BUILD index 6e0a3d961f3..dd57094a78a 100644 --- a/src/core/ext/transport/binder/wire_format/BUILD +++ b/src/core/ext/transport/binder/wire_format/BUILD @@ -86,6 +86,7 @@ grpc_cc_library( srcs = ["wire_writer.cc"], hdrs = ["wire_writer.h"], external_deps = [ + "absl/container:flat_hash_map", "absl/strings", ], deps = [ diff --git a/src/core/ext/transport/binder/wire_format/transaction.h b/src/core/ext/transport/binder/wire_format/transaction.h index 48eed73d37b..33e02a2582c 100644 --- a/src/core/ext/transport/binder/wire_format/transaction.h +++ b/src/core/ext/transport/binder/wire_format/transaction.h @@ -39,8 +39,8 @@ using Metadata = std::vector>; class Transaction { public: - Transaction(int tx_code, int seq_num, bool is_client) - : tx_code_(tx_code), seq_num_(seq_num), is_client_(is_client) {} + Transaction(int tx_code, bool is_client) + : tx_code_(tx_code), is_client_(is_client) {} // TODO(mingcl): Consider using string_view void SetPrefix(Metadata prefix_metadata) { prefix_metadata_ = prefix_metadata; @@ -77,7 +77,6 @@ class Transaction { bool IsClient() const { return is_client_; } bool IsServer() const { return !is_client_; } int GetTxCode() const { return tx_code_; } - int GetSeqNum() const { return seq_num_; } int GetFlags() const { return flags_; } absl::string_view GetMethodRef() const { return method_ref_; } @@ -88,7 +87,6 @@ class Transaction { private: int tx_code_; - int seq_num_; bool is_client_; Metadata prefix_metadata_; Metadata suffix_metadata_; diff --git a/src/core/ext/transport/binder/wire_format/wire_reader_impl.cc b/src/core/ext/transport/binder/wire_format/wire_reader_impl.cc index 62f250726c2..6107d5c0fe3 100644 --- a/src/core/ext/transport/binder/wire_format/wire_reader_impl.cc +++ b/src/core/ext/transport/binder/wire_format/wire_reader_impl.cc @@ -202,9 +202,11 @@ absl::Status WireReaderImpl::ProcessTransaction(transaction_code_t code, return absl::UnimplementedError("SHUTDOWN_TRANSPORT"); } case BinderTransportTxCode::ACKNOWLEDGE_BYTES: { - int num_bytes = -1; - RETURN_IF_ERROR(parcel->ReadInt32(&num_bytes)); - gpr_log(GPR_INFO, "received acknowledge bytes = %d", num_bytes); + int64_t num_bytes = -1; + RETURN_IF_ERROR(parcel->ReadInt64(&num_bytes)); + gpr_log(GPR_INFO, "received acknowledge bytes = %lld", + static_cast(num_bytes)); + wire_writer_->OnAckReceived(num_bytes); break; } case BinderTransportTxCode::PING: { @@ -259,7 +261,8 @@ absl::Status WireReaderImpl::ProcessStreamingTransaction( } } if ((num_incoming_bytes_ - num_acknowledged_bytes_) >= kFlowControlAckBytes) { - absl::Status ack_status = wire_writer_->Ack(num_incoming_bytes_); + GPR_ASSERT(wire_writer_); + absl::Status ack_status = wire_writer_->SendAck(num_incoming_bytes_); if (status.ok()) { status = ack_status; } @@ -339,6 +342,7 @@ absl::Status WireReaderImpl::ProcessStreamingTransactionImpl( } gpr_log(GPR_INFO, "msg_data = %s", msg_data.c_str()); message_buffer_[code] += msg_data; + // TODO(waynetu): This should be parcel->GetDataSize(). num_incoming_bytes_ += count; if ((flags & kFlagMessageDataIsPartial) == 0) { std::string s = std::move(message_buffer_[code]); diff --git a/src/core/ext/transport/binder/wire_format/wire_writer.cc b/src/core/ext/transport/binder/wire_format/wire_writer.cc index defb702aef2..2b70dff6c46 100644 --- a/src/core/ext/transport/binder/wire_format/wire_writer.cc +++ b/src/core/ext/transport/binder/wire_format/wire_writer.cc @@ -30,55 +30,142 @@ namespace grpc_binder { WireWriterImpl::WireWriterImpl(std::unique_ptr binder) : binder_(std::move(binder)) {} +absl::Status WireWriterImpl::WriteInitialMetadata(const Transaction& tx, + WritableParcel* parcel) { + if (tx.IsClient()) { + // Only client sends method ref. + RETURN_IF_ERROR(parcel->WriteString(tx.GetMethodRef())); + } + RETURN_IF_ERROR(parcel->WriteInt32(tx.GetPrefixMetadata().size())); + for (const auto& md : tx.GetPrefixMetadata()) { + RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.first)); + RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.second)); + } + return absl::OkStatus(); +} + +absl::Status WireWriterImpl::WriteTrailingMetadata(const Transaction& tx, + WritableParcel* parcel) { + if (tx.IsServer()) { + if (tx.GetFlags() & kFlagStatusDescription) { + RETURN_IF_ERROR(parcel->WriteString(tx.GetStatusDesc())); + } + RETURN_IF_ERROR(parcel->WriteInt32(tx.GetSuffixMetadata().size())); + for (const auto& md : tx.GetSuffixMetadata()) { + RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.first)); + RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.second)); + } + } else { + // client suffix currently is always empty according to the wireformat + if (!tx.GetSuffixMetadata().empty()) { + gpr_log(GPR_ERROR, "Got non-empty suffix metadata from client."); + } + } + return absl::OkStatus(); +} + +const int64_t WireWriterImpl::kBlockSize = 16 * 1024; +const int64_t WireWriterImpl::kFlowControlWindowSize = 128 * 1024; + +bool WireWriterImpl::CanBeSentInOneTransaction(const Transaction& tx) const { + return (tx.GetFlags() & kFlagMessageData) == 0 || + tx.GetMessageData().size() <= kBlockSize; +} + +absl::Status WireWriterImpl::RpcCallFastPath(const Transaction& tx) { + int& seq = seq_num_[tx.GetTxCode()]; + // Fast path: send data in one transaction. + RETURN_IF_ERROR(binder_->PrepareTransaction()); + WritableParcel* parcel = binder_->GetWritableParcel(); + RETURN_IF_ERROR(parcel->WriteInt32(tx.GetFlags())); + RETURN_IF_ERROR(parcel->WriteInt32(seq++)); + if (tx.GetFlags() & kFlagPrefix) { + RETURN_IF_ERROR(WriteInitialMetadata(tx, parcel)); + } + if (tx.GetFlags() & kFlagMessageData) { + RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(tx.GetMessageData())); + } + if (tx.GetFlags() & kFlagSuffix) { + RETURN_IF_ERROR(WriteTrailingMetadata(tx, parcel)); + } + // FIXME(waynetu): Construct BinderTransportTxCode from an arbitrary integer + // is an undefined behavior. + return binder_->Transact(BinderTransportTxCode(tx.GetTxCode())); +} + +bool WireWriterImpl::WaitForAcknowledgement() { + if (num_outgoing_bytes_ < num_acknowledged_bytes_ + kFlowControlWindowSize) { + return true; + } + absl::Time deadline = absl::Now() + absl::Seconds(1); + do { + if (cv_.WaitWithDeadline(&mu_, deadline)) { + return false; + } + if (absl::Now() >= deadline) { + return false; + } + } while (num_outgoing_bytes_ >= + num_acknowledged_bytes_ + kFlowControlWindowSize); + return true; +} + absl::Status WireWriterImpl::RpcCall(const Transaction& tx) { // TODO(mingcl): check tx_code <= last call id grpc_core::MutexLock lock(&mu_); GPR_ASSERT(tx.GetTxCode() >= kFirstCallId); - RETURN_IF_ERROR(binder_->PrepareTransaction()); - WritableParcel* parcel = binder_->GetWritableParcel(); - { - // fill parcel - RETURN_IF_ERROR(parcel->WriteInt32(tx.GetFlags())); - RETURN_IF_ERROR(parcel->WriteInt32(tx.GetSeqNum())); - if (tx.GetFlags() & kFlagPrefix) { - // prefix set - if (tx.IsClient()) { - // Only client sends method ref. - RETURN_IF_ERROR(parcel->WriteString(tx.GetMethodRef())); + if (CanBeSentInOneTransaction(tx)) { + return RpcCallFastPath(tx); + } + // Slow path: the message data is too large to fit in one transaction. + int& seq = seq_num_[tx.GetTxCode()]; + int original_flags = tx.GetFlags(); + GPR_ASSERT(original_flags & kFlagMessageData); + absl::string_view data = tx.GetMessageData(); + size_t bytes_sent = 0; + while (bytes_sent < data.size()) { + if (!WaitForAcknowledgement()) { + return absl::InternalError("Timeout waiting for acknowledgement"); + } + RETURN_IF_ERROR(binder_->PrepareTransaction()); + WritableParcel* parcel = binder_->GetWritableParcel(); + size_t size = + std::min(static_cast(kBlockSize), data.size() - bytes_sent); + int flags = kFlagMessageData; + if (bytes_sent == 0) { + // This is the first transaction. Include initial metadata if there's any. + if (original_flags & kFlagPrefix) { + flags |= kFlagPrefix; } - RETURN_IF_ERROR(parcel->WriteInt32(tx.GetPrefixMetadata().size())); - for (const auto& md : tx.GetPrefixMetadata()) { - RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.first)); - RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.second)); + } + if (bytes_sent + kBlockSize >= data.size()) { + // This is the last transaction. Include trailing metadata if there's any. + if (original_flags & kFlagSuffix) { + flags |= kFlagSuffix; } + } else { + // There are more messages to send. + flags |= kFlagMessageDataIsPartial; } - if (tx.GetFlags() & kFlagMessageData) { - RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(tx.GetMessageData())); + RETURN_IF_ERROR(parcel->WriteInt32(flags)); + RETURN_IF_ERROR(parcel->WriteInt32(seq++)); + if (flags & kFlagPrefix) { + RETURN_IF_ERROR(WriteInitialMetadata(tx, parcel)); } - if (tx.GetFlags() & kFlagSuffix) { - if (tx.IsServer()) { - if (tx.GetFlags() & kFlagStatusDescription) { - RETURN_IF_ERROR(parcel->WriteString(tx.GetStatusDesc())); - } - RETURN_IF_ERROR(parcel->WriteInt32(tx.GetSuffixMetadata().size())); - for (const auto& md : tx.GetSuffixMetadata()) { - RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.first)); - RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.second)); - } - } else { - // client suffix currently is always empty according to the wireformat - if (!tx.GetSuffixMetadata().empty()) { - gpr_log(GPR_ERROR, "Got non-empty suffix metadata from client."); - } - } + RETURN_IF_ERROR( + parcel->WriteByteArrayWithLength(data.substr(bytes_sent, size))); + if (flags & kFlagSuffix) { + RETURN_IF_ERROR(WriteTrailingMetadata(tx, parcel)); } + RETURN_IF_ERROR(binder_->Transact(BinderTransportTxCode(tx.GetTxCode()))); + bytes_sent += size; + // TODO(waynetu): This should be parcel->GetDataSize(). + num_outgoing_bytes_ += size; } - // FIXME(waynetu): Construct BinderTransportTxCode from an arbitrary integer - // is an undefined behavior. - return binder_->Transact(BinderTransportTxCode(tx.GetTxCode())); + return absl::OkStatus(); } -absl::Status WireWriterImpl::Ack(int64_t num_bytes) { +absl::Status WireWriterImpl::SendAck(int64_t num_bytes) { grpc_core::MutexLock lock(&mu_); RETURN_IF_ERROR(binder_->PrepareTransaction()); WritableParcel* parcel = binder_->GetWritableParcel(); @@ -86,4 +173,10 @@ absl::Status WireWriterImpl::Ack(int64_t num_bytes) { return binder_->Transact(BinderTransportTxCode::ACKNOWLEDGE_BYTES); } +void WireWriterImpl::OnAckReceived(int64_t num_bytes) { + grpc_core::MutexLock lock(&mu_); + num_acknowledged_bytes_ = std::max(num_acknowledged_bytes_, num_bytes); + cv_.Signal(); +} + } // namespace grpc_binder diff --git a/src/core/ext/transport/binder/wire_format/wire_writer.h b/src/core/ext/transport/binder/wire_format/wire_writer.h index 578d0ae414c..765c6e69073 100644 --- a/src/core/ext/transport/binder/wire_format/wire_writer.h +++ b/src/core/ext/transport/binder/wire_format/wire_writer.h @@ -20,7 +20,7 @@ #include #include -#include +#include "absl/container/flat_hash_map.h" #include "src/core/ext/transport/binder/wire_format/binder.h" #include "src/core/ext/transport/binder/wire_format/transaction.h" @@ -32,18 +32,51 @@ class WireWriter { public: virtual ~WireWriter() = default; virtual absl::Status RpcCall(const Transaction& call) = 0; - virtual absl::Status Ack(int64_t num_bytes) = 0; + virtual absl::Status SendAck(int64_t num_bytes) = 0; + virtual void OnAckReceived(int64_t num_bytes) = 0; }; class WireWriterImpl : public WireWriter { public: explicit WireWriterImpl(std::unique_ptr binder); absl::Status RpcCall(const Transaction& tx) override; - absl::Status Ack(int64_t num_bytes) override; + absl::Status SendAck(int64_t num_bytes) override; + void OnAckReceived(int64_t num_bytes) override; + + // Split long message into chunks of size 16k. This doesn't necessarily have + // to be the same as the flow control acknowledgement size, but it should not + // exceed 128k. + static const int64_t kBlockSize; + // Flow control allows sending at most 128k between acknowledgements. + static const int64_t kFlowControlWindowSize; private: + absl::Status WriteInitialMetadata(const Transaction& tx, + WritableParcel* parcel) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status WriteTrailingMetadata(const Transaction& tx, + WritableParcel* parcel) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + bool CanBeSentInOneTransaction(const Transaction& tx) const; + absl::Status RpcCallFastPath(const Transaction& tx) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Wait for acknowledgement from the other side for a while (the timeout is + // currently set to 10ms for debugability). Returns true if we are able to + // proceed, and false otherwise. + // + // TODO(waynetu): Currently, RpcCall() will fail if we are blocked for 10ms. + // In the future, we should queue the transactions and release them later when + // acknowledgement comes. + bool WaitForAcknowledgement() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + grpc_core::Mutex mu_; + grpc_core::CondVar cv_; std::unique_ptr binder_ ABSL_GUARDED_BY(mu_); + absl::flat_hash_map seq_num_ ABSL_GUARDED_BY(mu_); + int64_t num_outgoing_bytes_ ABSL_GUARDED_BY(mu_) = 0; + int64_t num_acknowledged_bytes_ ABSL_GUARDED_BY(mu_) = 0; }; } // namespace grpc_binder diff --git a/test/core/transport/binder/binder_transport_test.cc b/test/core/transport/binder/binder_transport_test.cc index 748725cd243..56084bc8d1e 100644 --- a/test/core/transport/binder/binder_transport_test.cc +++ b/test/core/transport/binder/binder_transport_test.cc @@ -332,74 +332,12 @@ TEST_F(BinderTransportTest, TransactionIdIncrement) { grpc_binder_stream* gbs0 = InitNewBinderStream(); EXPECT_EQ(gbs0->t, GetBinderTransport()); EXPECT_EQ(gbs0->tx_code, kFirstCallId); - EXPECT_EQ(gbs0->seq, 0); grpc_binder_stream* gbs1 = InitNewBinderStream(); EXPECT_EQ(gbs1->t, GetBinderTransport()); EXPECT_EQ(gbs1->tx_code, kFirstCallId + 1); - EXPECT_EQ(gbs1->seq, 0); grpc_binder_stream* gbs2 = InitNewBinderStream(); EXPECT_EQ(gbs2->t, GetBinderTransport()); EXPECT_EQ(gbs2->tx_code, kFirstCallId + 2); - EXPECT_EQ(gbs2->seq, 0); -} - -TEST_F(BinderTransportTest, SeqNumIncrement) { - grpc_core::ExecCtx exec_ctx; - grpc_binder_stream* gbs = InitNewBinderStream(); - EXPECT_EQ(gbs->t, GetBinderTransport()); - EXPECT_EQ(gbs->tx_code, kFirstCallId); - // A simple batch that contains only "send_initial_metadata" - grpc_transport_stream_op_batch op{}; - grpc_transport_stream_op_batch_payload payload(nullptr); - op.payload = &payload; - MakeSendInitialMetadata send_initial_metadata(kDefaultMetadata, "", &op); - EXPECT_EQ(gbs->seq, 0); - PerformStreamOp(gbs, &op); - grpc_core::ExecCtx::Get()->Flush(); - EXPECT_EQ(gbs->tx_code, kFirstCallId); - EXPECT_EQ(gbs->seq, 1); - PerformStreamOp(gbs, &op); - grpc_core::ExecCtx::Get()->Flush(); - EXPECT_EQ(gbs->tx_code, kFirstCallId); - EXPECT_EQ(gbs->seq, 2); -} - -TEST_F(BinderTransportTest, SeqNumNotIncrementWithoutSend) { - grpc_core::ExecCtx exec_ctx; - { - grpc_binder_stream* gbs = InitNewBinderStream(); - EXPECT_EQ(gbs->t, GetBinderTransport()); - EXPECT_EQ(gbs->tx_code, kFirstCallId); - // No-op batch. - grpc_transport_stream_op_batch op{}; - EXPECT_EQ(gbs->seq, 0); - PerformStreamOp(gbs, &op); - grpc_core::ExecCtx::Get()->Flush(); - EXPECT_EQ(gbs->tx_code, kFirstCallId); - EXPECT_EQ(gbs->seq, 0); - } - { - grpc_binder_stream* gbs = InitNewBinderStream(); - EXPECT_EQ(gbs->t, GetBinderTransport()); - EXPECT_EQ(gbs->tx_code, kFirstCallId + 1); - // Batch with only receiving operations. - grpc_transport_stream_op_batch op{}; - grpc_transport_stream_op_batch_payload payload(nullptr); - op.payload = &payload; - MakeRecvInitialMetadata recv_initial_metadata(&op); - EXPECT_EQ(gbs->seq, 0); - PerformStreamOp(gbs, &op); - EXPECT_EQ(gbs->tx_code, kFirstCallId + 1); - EXPECT_EQ(gbs->seq, 0); - - // Just to trigger the callback. - auto* gbt = reinterpret_cast(transport_); - gbt->transport_stream_receiver->NotifyRecvInitialMetadata(gbs->tx_code, - kDefaultMetadata); - PerformStreamOp(gbs, &op); - grpc_core::ExecCtx::Get()->Flush(); - recv_initial_metadata.notification.WaitForNotification(); - } } TEST_F(BinderTransportTest, PerformSendInitialMetadata) { diff --git a/test/core/transport/binder/end2end/end2end_binder_transport_test.cc b/test/core/transport/binder/end2end/end2end_binder_transport_test.cc index 5a21ec6355d..5958abc01ed 100644 --- a/test/core/transport/binder/end2end/end2end_binder_transport_test.cc +++ b/test/core/transport/binder/end2end/end2end_binder_transport_test.cc @@ -293,6 +293,25 @@ TEST_P(End2EndBinderTransportTest, BiDirStreamingCallThroughFakeBinderChannel) { server->Shutdown(); } +TEST_P(End2EndBinderTransportTest, LargeMessage) { + grpc::ChannelArguments args; + grpc::ServerBuilder builder; + end2end_testing::EchoServer service; + builder.RegisterService(&service); + std::unique_ptr server = builder.BuildAndStart(); + std::shared_ptr channel = BinderChannel(server.get(), args); + std::unique_ptr stub = EchoService::NewStub(channel); + grpc::ClientContext context; + EchoRequest request; + EchoResponse response; + request.set_text(std::string(1000000, 'a')); + grpc::Status status = stub->EchoUnaryCall(&context, request, &response); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(response.text(), std::string(1000000, 'a')); + + server->Shutdown(); +} + INSTANTIATE_TEST_SUITE_P( End2EndBinderTransportTestWithDifferentDelayTimes, End2EndBinderTransportTest, diff --git a/test/core/transport/binder/mock_objects.h b/test/core/transport/binder/mock_objects.h index d749257954f..bd8cfa708bf 100644 --- a/test/core/transport/binder/mock_objects.h +++ b/test/core/transport/binder/mock_objects.h @@ -88,7 +88,8 @@ class MockTransactionReceiver : public TransactionReceiver { class MockWireWriter : public WireWriter { public: MOCK_METHOD(absl::Status, RpcCall, (const Transaction&), (override)); - MOCK_METHOD(absl::Status, Ack, (int64_t), (override)); + MOCK_METHOD(absl::Status, SendAck, (int64_t), (override)); + MOCK_METHOD(void, OnAckReceived, (int64_t), (override)); }; class MockTransportStreamReceiver : public TransportStreamReceiver { diff --git a/test/core/transport/binder/wire_writer_test.cc b/test/core/transport/binder/wire_writer_test.cc index 939a713669b..0438273b1c9 100644 --- a/test/core/transport/binder/wire_writer_test.cc +++ b/test/core/transport/binder/wire_writer_test.cc @@ -30,7 +30,8 @@ using ::testing::Return; using ::testing::StrictMock; MATCHER_P(StrEqInt8Ptr, target, "") { - return std::string(reinterpret_cast(arg)) == target; + return std::string(reinterpret_cast(arg), target.size()) == + target; } TEST(WireWriterTest, RpcCall) { @@ -53,7 +54,6 @@ TEST(WireWriterTest, RpcCall) { ::testing::InSequence sequence; int sequence_number = 0; - int tx_code = kFirstCallId; { // flag @@ -61,18 +61,18 @@ TEST(WireWriterTest, RpcCall) { // sequence number EXPECT_CALL(mock_writable_parcel, WriteInt32(sequence_number)); - EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(tx_code))); + EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(kFirstCallId))); - Transaction tx(tx_code, sequence_number, /*is_client=*/true); + Transaction tx(kFirstCallId, /*is_client=*/true); EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); sequence_number++; - tx_code++; } { // flag EXPECT_CALL(mock_writable_parcel, WriteInt32(kFlagPrefix)); - // sequence number - EXPECT_CALL(mock_writable_parcel, WriteInt32(sequence_number)); + // sequence number. This is another stream so the sequence number starts + // with 0. + EXPECT_CALL(mock_writable_parcel, WriteInt32(0)); EXPECT_CALL(mock_writable_parcel, WriteString(absl::string_view("/example/method/ref"))); @@ -96,12 +96,10 @@ TEST(WireWriterTest, RpcCall) { EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(kFirstCallId + 1))); - Transaction tx(kFirstCallId + 1, 1, /*is_client=*/true); + Transaction tx(kFirstCallId + 1, /*is_client=*/true); tx.SetPrefix(kMetadata); tx.SetMethodRef("/example/method/ref"); EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); - sequence_number++; - tx_code++; } { // flag @@ -110,11 +108,12 @@ TEST(WireWriterTest, RpcCall) { EXPECT_CALL(mock_writable_parcel, WriteInt32(sequence_number)); ExpectWriteByteArray("data"); - EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(tx_code))); + EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(kFirstCallId))); - Transaction tx(tx_code, sequence_number, /*is_client=*/true); + Transaction tx(kFirstCallId, /*is_client=*/true); tx.SetData("data"); EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); + sequence_number++; } { // flag @@ -122,13 +121,12 @@ TEST(WireWriterTest, RpcCall) { // sequence number EXPECT_CALL(mock_writable_parcel, WriteInt32(sequence_number)); - EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(tx_code))); + EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(kFirstCallId))); - Transaction tx(tx_code, sequence_number, /*is_client=*/true); + Transaction tx(kFirstCallId, /*is_client=*/true); tx.SetSuffix({}); EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); sequence_number++; - tx_code++; } { // flag @@ -159,9 +157,9 @@ TEST(WireWriterTest, RpcCall) { // Empty message data ExpectWriteByteArray(""); - EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(tx_code))); + EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(kFirstCallId))); - Transaction tx(tx_code, sequence_number, /*is_client=*/true); + Transaction tx(kFirstCallId, /*is_client=*/true); // TODO(waynetu): Implement a helper function that automatically creates // EXPECT_CALL based on the tx object. tx.SetPrefix(kMetadata); @@ -170,7 +168,68 @@ TEST(WireWriterTest, RpcCall) { tx.SetSuffix({}); EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); sequence_number++; - tx_code++; + } + + // Really large message + { + EXPECT_CALL(mock_writable_parcel, + WriteInt32(kFlagMessageData | kFlagMessageDataIsPartial)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(0)); + ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a')); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 2))); + + EXPECT_CALL(mock_writable_parcel, + WriteInt32(kFlagMessageData | kFlagMessageDataIsPartial)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(1)); + ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a')); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 2))); + + EXPECT_CALL(mock_writable_parcel, WriteInt32(kFlagMessageData)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(2)); + ExpectWriteByteArray("a"); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 2))); + + // Use a new stream. + Transaction tx(kFirstCallId + 2, /*is_client=*/true); + tx.SetData(std::string(2 * WireWriterImpl::kBlockSize + 1, 'a')); + EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); + } + // Really large message with metadata + { + EXPECT_CALL( + mock_writable_parcel, + WriteInt32(kFlagPrefix | kFlagMessageData | kFlagMessageDataIsPartial)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(0)); + EXPECT_CALL(mock_writable_parcel, WriteString(absl::string_view("123"))); + EXPECT_CALL(mock_writable_parcel, WriteInt32(0)); + ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a')); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 3))); + + EXPECT_CALL(mock_writable_parcel, + WriteInt32(kFlagMessageData | kFlagMessageDataIsPartial)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(1)); + ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a')); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 3))); + + EXPECT_CALL(mock_writable_parcel, + WriteInt32(kFlagMessageData | kFlagSuffix)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(2)); + ExpectWriteByteArray("a"); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 3))); + + // Use a new stream. + Transaction tx(kFirstCallId + 3, /*is_client=*/true); + tx.SetPrefix({}); + tx.SetMethodRef("123"); + tx.SetData(std::string(2 * WireWriterImpl::kBlockSize + 1, 'a')); + tx.SetSuffix({}); + EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); } }