diff --git a/include/grpcpp/impl/codegen/client_callback.h b/include/grpcpp/impl/codegen/client_callback.h index 0b2631014a2..89629c079af 100644 --- a/include/grpcpp/impl/codegen/client_callback.h +++ b/include/grpcpp/impl/codegen/client_callback.h @@ -112,6 +112,8 @@ class ClientCallbackReaderWriter { virtual void Write(const Request* req, WriteOptions options) = 0; virtual void WritesDone() = 0; virtual void Read(Response* resp) = 0; + virtual void AddHold(int holds) = 0; + virtual void RemoveHold() = 0; protected: void BindReactor(ClientBidiReactor* reactor) { @@ -125,6 +127,8 @@ class ClientCallbackReader { virtual ~ClientCallbackReader() {} virtual void StartCall() = 0; virtual void Read(Response* resp) = 0; + virtual void AddHold(int holds) = 0; + virtual void RemoveHold() = 0; protected: void BindReactor(ClientReadReactor* reactor) { @@ -144,6 +148,9 @@ class ClientCallbackWriter { } virtual void WritesDone() = 0; + virtual void AddHold(int holds) = 0; + virtual void RemoveHold() = 0; + protected: void BindReactor(ClientWriteReactor* reactor) { reactor->BindWriter(this); @@ -174,6 +181,29 @@ class ClientBidiReactor { } void StartWritesDone() { stream_->WritesDone(); } + /// Holds are needed if (and only if) this stream has operations that take + /// place on it after StartCall but from outside one of the reactions + /// (OnReadDone, etc). This is _not_ a common use of the streaming API. + /// + /// Holds must be added before calling StartCall. If a stream still has a hold + /// in place, its resources will not be destroyed even if the status has + /// already come in from the wire and there are currently no active callbacks + /// outstanding. Similarly, the stream will not call OnDone if there are still + /// holds on it. + /// + /// For example, if a StartRead or StartWrite operation is going to be + /// initiated from elsewhere in the application, the application should call + /// AddHold or AddMultipleHolds before StartCall. If there is going to be, + /// for example, a read-flow and a write-flow taking place outside the + /// reactions, then call AddMultipleHolds(2) before StartCall. When the + /// application knows that it won't issue any more Read operations (such as + /// when a read comes back as not ok), it should issue a RemoveHold(). It + /// should also call RemoveHold() again after it does StartWriteLast or + /// StartWritesDone that indicates that there will be no more Write ops. + void AddHold() { AddMultipleHolds(1); } + void AddMultipleHolds(int holds) { stream_->AddHold(holds); } + void RemoveHold() { stream_->RemoveHold(); } + private: friend class ClientCallbackReaderWriter; void BindStream(ClientCallbackReaderWriter* stream) { @@ -193,6 +223,10 @@ class ClientReadReactor { void StartCall() { reader_->StartCall(); } void StartRead(Response* resp) { reader_->Read(resp); } + void AddHold() { AddMultipleHolds(1); } + void AddMultipleHolds(int holds) { reader_->AddHold(holds); } + void RemoveHold() { reader_->RemoveHold(); } + private: friend class ClientCallbackReader; void BindReader(ClientCallbackReader* reader) { reader_ = reader; } @@ -218,6 +252,10 @@ class ClientWriteReactor { } void StartWritesDone() { writer_->WritesDone(); } + void AddHold() { AddMultipleHolds(1); } + void AddMultipleHolds(int holds) { writer_->AddHold(holds); } + void RemoveHold() { writer_->RemoveHold(); } + private: friend class ClientCallbackWriter; void BindWriter(ClientCallbackWriter* writer) { writer_ = writer; } @@ -374,6 +412,9 @@ class ClientCallbackReaderWriterImpl } } + virtual void AddHold(int holds) override { callbacks_outstanding_ += holds; } + virtual void RemoveHold() override { MaybeFinish(); } + private: friend class ClientCallbackReaderWriterFactory; @@ -509,6 +550,9 @@ class ClientCallbackReaderImpl } } + virtual void AddHold(int holds) override { callbacks_outstanding_ += holds; } + virtual void RemoveHold() override { MaybeFinish(); } + private: friend class ClientCallbackReaderFactory; @@ -677,6 +721,9 @@ class ClientCallbackWriterImpl } } + virtual void AddHold(int holds) override { callbacks_outstanding_ += holds; } + virtual void RemoveHold() override { MaybeFinish(); } + private: friend class ClientCallbackWriterFactory; diff --git a/test/cpp/end2end/client_callback_end2end_test.cc b/test/cpp/end2end/client_callback_end2end_test.cc index 3845c4c0b2a..e1e898275b3 100644 --- a/test/cpp/end2end/client_callback_end2end_test.cc +++ b/test/cpp/end2end/client_callback_end2end_test.cc @@ -1117,6 +1117,82 @@ TEST_P(ClientCallbackEnd2endTest, UnimplementedRpc) { } } +TEST_P(ClientCallbackEnd2endTest, + ResponseStreamExtraReactionFlowReadsUntilDone) { + MAYBE_SKIP_TEST; + ResetStub(); + class ReadAllIncomingDataClient + : public grpc::experimental::ClientReadReactor { + public: + ReadAllIncomingDataClient(grpc::testing::EchoTestService::Stub* stub) { + request_.set_message("Hello client "); + stub->experimental_async()->ResponseStream(&context_, &request_, this); + } + bool WaitForReadDone() { + std::unique_lock l(mu_); + while (!read_done_) { + read_cv_.wait(l); + } + read_done_ = false; + return read_ok_; + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + done_cv_.wait(l); + } + } + const Status& status() { + std::unique_lock l(mu_); + return status_; + } + + private: + void OnReadDone(bool ok) override { + std::unique_lock l(mu_); + read_ok_ = ok; + read_done_ = true; + read_cv_.notify_one(); + } + void OnDone(const Status& s) override { + std::unique_lock l(mu_); + done_ = true; + status_ = s; + done_cv_.notify_one(); + } + + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + bool read_ok_ = false; + bool read_done_ = false; + std::mutex mu_; + std::condition_variable read_cv_; + std::condition_variable done_cv_; + bool done_ = false; + Status status_; + } client{stub_.get()}; + + int reads_complete = 0; + client.AddHold(); + client.StartCall(); + + EchoResponse response; + bool read_ok = true; + while (read_ok) { + client.StartRead(&response); + read_ok = client.WaitForReadDone(); + if (read_ok) { + ++reads_complete; + } + } + client.RemoveHold(); + client.Await(); + + EXPECT_EQ(kServerDefaultResponseStreamsToSend, reads_complete); + EXPECT_EQ(client.status().error_code(), grpc::StatusCode::OK); +} + std::vector CreateTestScenarios(bool test_insecure) { std::vector scenarios; std::vector credentials_types{