diff --git a/src/core/ext/transport/binder/transport/binder_transport.cc b/src/core/ext/transport/binder/transport/binder_transport.cc index a1a2665e03a..293ec5c7ca7 100644 --- a/src/core/ext/transport/binder/transport/binder_transport.cc +++ b/src/core/ext/transport/binder/transport/binder_transport.cc @@ -149,7 +149,26 @@ static void cancel_stream_locked(grpc_binder_transport* gbt, grpc_binder_stream* gbs, grpc_error_handle error) { gpr_log(GPR_INFO, "cancel_stream_locked"); + if (!gbs->is_closed) { + if (gbt->is_client) { + // Always do an out-of-band close to play it safe. Due the design of gRPC, + // a bidi-streaming call can be "closed" from the client side, but a + // server-streaming call can only be "cancelled". An out-of-band close + // effectively "cancels" both bidi-streaming and server-streaming call. + // + // TODO(littlecvr): Investigate how BinderTransport work in Java and align + // with their behavior. + auto cancel_tx = std::make_unique( + gbs->GetTxCode(), gbt->is_client); + cancel_tx->SetOutOfBandClose(); + cancel_tx->SetStatus(GRPC_STATUS_CANCELLED); + gpr_log(GPR_INFO, + "Sending out-of-band close, gbt = %p, gbs = %p, is_client = %d", + gbt, gbs, gbt->is_client); + gbt->wire_writer->RpcCall(std::move(cancel_tx)).IgnoreError(); + } + GPR_ASSERT(gbs->cancel_self_error.ok()); gbs->is_closed = true; gbs->cancel_self_error = error; diff --git a/src/core/ext/transport/binder/wire_format/transaction.cc b/src/core/ext/transport/binder/wire_format/transaction.cc index 9c3bb0225b6..2b0c35c7b88 100644 --- a/src/core/ext/transport/binder/wire_format/transaction.cc +++ b/src/core/ext/transport/binder/wire_format/transaction.cc @@ -28,6 +28,7 @@ ABSL_CONST_INIT const int kFlagExpectSingleMessage = 0x10; ABSL_CONST_INIT const int kFlagStatusDescription = 0x20; ABSL_CONST_INIT const int kFlagMessageDataIsParcelable = 0x40; ABSL_CONST_INIT const int kFlagMessageDataIsPartial = 0x80; +ABSL_CONST_INIT const int kStatusCodeShift = 16; } // namespace grpc_binder #endif diff --git a/src/core/ext/transport/binder/wire_format/transaction.h b/src/core/ext/transport/binder/wire_format/transaction.h index 1bd61377c73..5b5c3f6ede5 100644 --- a/src/core/ext/transport/binder/wire_format/transaction.h +++ b/src/core/ext/transport/binder/wire_format/transaction.h @@ -34,6 +34,7 @@ ABSL_CONST_INIT extern const int kFlagExpectSingleMessage; ABSL_CONST_INIT extern const int kFlagStatusDescription; ABSL_CONST_INIT extern const int kFlagMessageDataIsParcelable; ABSL_CONST_INIT extern const int kFlagMessageDataIsPartial; +ABSL_CONST_INIT extern const int kStatusCodeShift; using Metadata = std::vector>; @@ -67,11 +68,11 @@ class Transaction { GPR_ASSERT((flags_ & kFlagStatusDescription) == 0); status_desc_ = status_desc; } + void SetOutOfBandClose() { flags_ |= kFlagOutOfBandClose; } void SetStatus(int status) { - GPR_ASSERT(!is_client_); - GPR_ASSERT((flags_ >> 16) == 0); - GPR_ASSERT(status < (1 << 16)); - flags_ |= (status << 16); + GPR_ASSERT((flags_ >> kStatusCodeShift) == 0); + GPR_ASSERT(status < (1 << kStatusCodeShift)); + flags_ |= (status << kStatusCodeShift); } bool IsClient() const { return is_client_; } diff --git a/test/core/transport/binder/binder_transport_test.cc b/test/core/transport/binder/binder_transport_test.cc index e2da466ba0c..0bbb34f6717 100644 --- a/test/core/transport/binder/binder_transport_test.cc +++ b/test/core/transport/binder/binder_transport_test.cc @@ -44,6 +44,59 @@ using ::testing::Expectation; using ::testing::NiceMock; using ::testing::Return; +std::string MetadataString(const Metadata& a) { + return absl::StrCat( + "{", + absl::StrJoin( + a, ", ", + [](std::string* out, const std::pair& kv) { + out->append( + absl::StrCat("\"", kv.first, "\": \"", kv.second, "\"")); + }), + "}"); +} + +bool MetadataEquivalent(Metadata a, Metadata b) { + std::sort(a.begin(), a.end()); + std::sort(b.begin(), b.end()); + return a == b; +} + +// Matches with transactions having the desired flag, method_ref, +// initial_metadata, and message_data. +MATCHER_P4(TransactionMatches, flag, method_ref, initial_metadata, message_data, + "") { + if (arg->GetFlags() != flag) return false; + if (flag & kFlagPrefix) { + if (arg->GetMethodRef() != method_ref) { + printf("METHOD REF NOT EQ: %s %s\n", + std::string(arg->GetMethodRef()).c_str(), + std::string(method_ref).c_str()); + return false; + } + if (!MetadataEquivalent(arg->GetPrefixMetadata(), initial_metadata)) { + printf("METADATA NOT EQUIVALENT: %s %s\n", + MetadataString(arg->GetPrefixMetadata()).c_str(), + MetadataString(initial_metadata).c_str()); + return false; + } + } + if (flag & kFlagMessageData) { + if (arg->GetMessageData() != message_data) { + printf("MESSAGE NOT EQUIVALENT: %s %s\n", + std::string(arg->GetMessageData()).c_str(), + std::string(message_data).c_str()); + return false; + } + } + return true; +} + +// Matches with grpc_error having error message containing |msg|. +MATCHER_P(GrpcErrorMessageContains, msg, "") { + return absl::StrContains(grpc_core::StatusToString(arg), msg); +} + class BinderTransportTest : public ::testing::Test { public: BinderTransportTest() @@ -91,6 +144,14 @@ class BinderTransportTest : public ::testing::Test { GetBinderTransport()->wire_writer.get()); } + void ExpectOutOfBandCloseCalled() { + EXPECT_CALL( + GetWireWriter(), + RpcCall(TransactionMatches( + kFlagOutOfBandClose | (GRPC_STATUS_CANCELLED << kStatusCodeShift), + "", Metadata{}, ""))); + } + static void SetUpTestSuite() { grpc_init(); } static void TearDownTestSuite() { grpc_shutdown(); } @@ -132,55 +193,6 @@ void MockCallback(void* arg, grpc_error_handle error) { } } -std::string MetadataString(const Metadata& a) { - return absl::StrCat( - "{", - absl::StrJoin( - a, ", ", - [](std::string* out, const std::pair& kv) { - out->append( - absl::StrCat("\"", kv.first, "\": \"", kv.second, "\"")); - }), - "}"); -} - -bool MetadataEquivalent(Metadata a, Metadata b) { - std::sort(a.begin(), a.end()); - std::sort(b.begin(), b.end()); - return a == b; -} - -// Matches with transactions having the desired flag, method_ref, -// initial_metadata, and message_data. -MATCHER_P4(TransactionMatches, flag, method_ref, initial_metadata, message_data, - "") { - if (arg->GetFlags() != flag) return false; - if (flag & kFlagPrefix) { - if (arg->GetMethodRef() != method_ref) { - printf("METHOD REF NOT EQ: %s %s\n", - std::string(arg->GetMethodRef()).c_str(), - std::string(method_ref).c_str()); - return false; - } - if (!MetadataEquivalent(arg->GetPrefixMetadata(), initial_metadata)) { - printf("METADATA NOT EQUIVALENT: %s %s\n", - MetadataString(arg->GetPrefixMetadata()).c_str(), - MetadataString(initial_metadata).c_str()); - return false; - } - } - if (flag & kFlagMessageData) { - if (arg->GetMessageData() != message_data) return false; - } - return true; -} - -// Matches with grpc_error having error message containing |msg|. -MATCHER_P(GrpcErrorMessageContains, msg, "") { - return absl::StrContains(grpc_core::StatusToString(arg), msg); -} - -namespace { class MetadataEncoder { public: void Encode(const grpc_core::Slice& key, const grpc_core::Slice& value) { @@ -201,7 +213,6 @@ class MetadataEncoder { private: Metadata metadata_; }; -} // namespace // Verify that the lower-level metadata has the same content as the gRPC // metadata. @@ -407,6 +418,7 @@ TEST_F(BinderTransportTest, PerformSendInitialMetadata) { EXPECT_CALL(GetWireWriter(), RpcCall(TransactionMatches( kFlagPrefix, "", kInitialMetadata, ""))); EXPECT_CALL(mock_on_complete, Callback); + ExpectOutOfBandCloseCalled(); PerformStreamOp(gbs, &op); grpc_core::ExecCtx::Get()->Flush(); @@ -430,6 +442,7 @@ TEST_F(BinderTransportTest, PerformSendInitialMetadataMethodRef) { RpcCall(TransactionMatches(kFlagPrefix, kMethodRef.substr(1), kInitialMetadata, ""))); EXPECT_CALL(mock_on_complete, Callback); + ExpectOutOfBandCloseCalled(); PerformStreamOp(gbs, &op); grpc_core::ExecCtx::Get()->Flush(); @@ -452,6 +465,7 @@ TEST_F(BinderTransportTest, PerformSendMessage) { GetWireWriter(), RpcCall(TransactionMatches(kFlagMessageData, "", Metadata{}, kMessage))); EXPECT_CALL(mock_on_complete, Callback); + ExpectOutOfBandCloseCalled(); PerformStreamOp(gbs, &op); grpc_core::ExecCtx::Get()->Flush(); @@ -475,6 +489,7 @@ TEST_F(BinderTransportTest, PerformSendTrailingMetadata) { EXPECT_CALL(GetWireWriter(), RpcCall(TransactionMatches( kFlagSuffix, "", kTrailingMetadata, ""))); EXPECT_CALL(mock_on_complete, Callback); + ExpectOutOfBandCloseCalled(); PerformStreamOp(gbs, &op); grpc_core::ExecCtx::Get()->Flush(); @@ -510,6 +525,7 @@ TEST_F(BinderTransportTest, PerformSendAll) { kFlagPrefix | kFlagMessageData | kFlagSuffix, kMethodRef.substr(1), kInitialMetadata, kMessage))); EXPECT_CALL(mock_on_complete, Callback); + ExpectOutOfBandCloseCalled(); PerformStreamOp(gbs, &op); grpc_core::ExecCtx::Get()->Flush(); @@ -667,6 +683,7 @@ TEST_F(BinderTransportTest, PerformAllOps) { RpcCall(TransactionMatches( kFlagPrefix | kFlagMessageData | kFlagSuffix, kMethodRef.substr(1), kSendInitialMetadata, kSendMessage))); + ExpectOutOfBandCloseCalled(); Expectation on_complete = EXPECT_CALL(mock_on_complete, Callback); // Recv callbacks can happen after the on_complete callback. @@ -720,6 +737,7 @@ TEST_F(BinderTransportTest, WireWriterRpcCallErrorPropagates) { EXPECT_CALL(mock_on_complete1, Callback(absl::OkStatus())); EXPECT_CALL(mock_on_complete2, Callback(GrpcErrorMessageContains("WireWriter::RpcCall failed"))); + ExpectOutOfBandCloseCalled(); const Metadata kInitialMetadata = {}; grpc_transport_stream_op_batch op1{};