diff --git a/include/grpcpp/impl/codegen/client_callback.h b/include/grpcpp/impl/codegen/client_callback.h index 4d9579fd6ab..66cf9b7754c 100644 --- a/include/grpcpp/impl/codegen/client_callback.h +++ b/include/grpcpp/impl/codegen/client_callback.h @@ -255,10 +255,12 @@ class ClientCallbackReaderWriterImpl void MaybeFinish() { if (--callbacks_outstanding_ == 0) { - reactor_->OnDone(finish_status_); + Status s = std::move(finish_status_); + auto* reactor = reactor_; auto* call = call_.call(); this->~ClientCallbackReaderWriterImpl(); g_core_codegen_interface->grpc_call_unref(call); + reactor->OnDone(s); } } @@ -268,6 +270,7 @@ class ClientCallbackReaderWriterImpl // 2. Any read backlog // 3. Recv trailing metadata, on_completion callback // 4. Any write backlog + // 5. See if the call can finish (if other callbacks were triggered already) started_ = true; start_tag_.Set(call_.call(), @@ -318,6 +321,7 @@ class ClientCallbackReaderWriterImpl if (writes_done_ops_at_start_) { call_.PerformOps(&writes_done_ops_); } + MaybeFinish(); } void Read(Response* msg) override { @@ -410,8 +414,8 @@ class ClientCallbackReaderWriterImpl CallbackWithSuccessTag read_tag_; bool read_ops_at_start_{false}; - // Minimum of 2 outstanding callbacks to pre-register for start and finish - std::atomic_int callbacks_outstanding_{2}; + // Minimum of 3 callbacks to pre-register for StartCall, start, and finish + std::atomic_int callbacks_outstanding_{3}; bool started_{false}; }; @@ -450,10 +454,12 @@ class ClientCallbackReaderImpl void MaybeFinish() { if (--callbacks_outstanding_ == 0) { - reactor_->OnDone(finish_status_); + Status s = std::move(finish_status_); + auto* reactor = reactor_; auto* call = call_.call(); this->~ClientCallbackReaderImpl(); g_core_codegen_interface->grpc_call_unref(call); + reactor->OnDone(s); } } @@ -462,6 +468,7 @@ class ClientCallbackReaderImpl // 1. Send initial metadata (unless corked) + recv initial metadata // 2. Any backlog // 3. Recv trailing metadata, on_completion callback + // 4. See if the call can finish (if other callbacks were triggered already) started_ = true; start_tag_.Set(call_.call(), @@ -493,6 +500,8 @@ class ClientCallbackReaderImpl finish_ops_.ClientRecvStatus(context_, &finish_status_); finish_ops_.set_core_cq_tag(&finish_tag_); call_.PerformOps(&finish_ops_); + + MaybeFinish(); } void Read(Response* msg) override { @@ -536,8 +545,8 @@ class ClientCallbackReaderImpl CallbackWithSuccessTag read_tag_; bool read_ops_at_start_{false}; - // Minimum of 2 outstanding callbacks to pre-register for start and finish - std::atomic_int callbacks_outstanding_{2}; + // Minimum of 3 callbacks to pre-register for StartCall, start, and finish + std::atomic_int callbacks_outstanding_{3}; bool started_{false}; }; @@ -576,10 +585,12 @@ class ClientCallbackWriterImpl void MaybeFinish() { if (--callbacks_outstanding_ == 0) { - reactor_->OnDone(finish_status_); + Status s = std::move(finish_status_); + auto* reactor = reactor_; auto* call = call_.call(); this->~ClientCallbackWriterImpl(); g_core_codegen_interface->grpc_call_unref(call); + reactor->OnDone(s); } } @@ -588,6 +599,7 @@ class ClientCallbackWriterImpl // 1. Send initial metadata (unless corked) + recv initial metadata // 2. Recv trailing metadata, on_completion callback // 3. Any backlog + // 4. See if the call can finish (if other callbacks were triggered already) started_ = true; start_tag_.Set(call_.call(), @@ -627,6 +639,8 @@ class ClientCallbackWriterImpl if (writes_done_ops_at_start_) { call_.PerformOps(&writes_done_ops_); } + + MaybeFinish(); } void Write(const Request* msg, WriteOptions options) override { @@ -708,8 +722,8 @@ class ClientCallbackWriterImpl CallbackWithSuccessTag writes_done_tag_; bool writes_done_ops_at_start_{false}; - // Minimum of 2 outstanding callbacks to pre-register for start and finish - std::atomic_int callbacks_outstanding_{2}; + // Minimum of 3 callbacks to pre-register for StartCall, start, and finish + std::atomic_int callbacks_outstanding_{3}; bool started_{false}; }; diff --git a/test/cpp/end2end/client_callback_end2end_test.cc b/test/cpp/end2end/client_callback_end2end_test.cc index 65434bac6b2..a999321992f 100644 --- a/test/cpp/end2end/client_callback_end2end_test.cc +++ b/test/cpp/end2end/client_callback_end2end_test.cc @@ -182,7 +182,7 @@ class ClientCallbackEnd2endTest } } - void SendGenericEchoAsBidi(int num_rpcs) { + void SendGenericEchoAsBidi(int num_rpcs, int reuses) { const grpc::string kMethodName("/grpc.testing.EchoTestService/Echo"); grpc::string test_string(""); for (int i = 0; i < num_rpcs; i++) { @@ -191,14 +191,26 @@ class ClientCallbackEnd2endTest ByteBuffer> { public: Client(ClientCallbackEnd2endTest* test, const grpc::string& method_name, - const grpc::string& test_str) { - test->generic_stub_->experimental().PrepareBidiStreamingCall( - &cli_ctx_, method_name, this); - request_.set_message(test_str); - send_buf_ = SerializeToByteBuffer(&request_); - StartWrite(send_buf_.get()); - StartRead(&recv_buf_); - StartCall(); + const grpc::string& test_str, int reuses) + : reuses_remaining_(reuses) { + activate_ = [this, test, method_name, test_str] { + if (reuses_remaining_ > 0) { + cli_ctx_.reset(new ClientContext); + reuses_remaining_--; + test->generic_stub_->experimental().PrepareBidiStreamingCall( + cli_ctx_.get(), method_name, this); + request_.set_message(test_str); + send_buf_ = SerializeToByteBuffer(&request_); + StartWrite(send_buf_.get()); + StartRead(&recv_buf_); + StartCall(); + } else { + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + }; + activate_(); } void OnWriteDone(bool ok) override { StartWritesDone(); } void OnReadDone(bool ok) override { @@ -208,9 +220,7 @@ class ClientCallbackEnd2endTest }; void OnDone(const Status& s) override { EXPECT_TRUE(s.ok()); - std::unique_lock l(mu_); - done_ = true; - cv_.notify_one(); + activate_(); } void Await() { std::unique_lock l(mu_); @@ -222,11 +232,13 @@ class ClientCallbackEnd2endTest EchoRequest request_; std::unique_ptr send_buf_; ByteBuffer recv_buf_; - ClientContext cli_ctx_; + std::unique_ptr cli_ctx_; + int reuses_remaining_; + std::function activate_; std::mutex mu_; std::condition_variable cv_; bool done_ = false; - } rpc{this, kMethodName, test_string}; + } rpc{this, kMethodName, test_string, reuses}; rpc.Await(); } @@ -293,7 +305,12 @@ TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcs) { TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidi) { ResetStub(); - SendGenericEchoAsBidi(10); + SendGenericEchoAsBidi(10, 1); +} + +TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidiWithReactorReuse) { + ResetStub(); + SendGenericEchoAsBidi(10, 10); } #if GRPC_ALLOW_EXCEPTIONS