From 41c6cba9f55b8104a9f5068e5418f689e00e6f65 Mon Sep 17 00:00:00 2001 From: Vijay Pai Date: Mon, 8 Apr 2019 09:03:13 -0700 Subject: [PATCH] Make sure that OnCancel happens after OnStarted --- include/grpcpp/impl/codegen/server_callback.h | 38 +++++++++ src/cpp/server/server_context.cc | 11 ++- test/cpp/end2end/end2end_test.cc | 8 +- test/cpp/end2end/test_service_impl.cc | 79 +++++++++++-------- 4 files changed, 97 insertions(+), 39 deletions(-) diff --git a/include/grpcpp/impl/codegen/server_callback.h b/include/grpcpp/impl/codegen/server_callback.h index 4c6b189214e..ce27b156283 100644 --- a/include/grpcpp/impl/codegen/server_callback.h +++ b/include/grpcpp/impl/codegen/server_callback.h @@ -37,11 +37,43 @@ namespace grpc { // Declare base class of all reactors as internal namespace internal { +// Forward declarations +template +class CallbackClientStreamingHandler; +template +class CallbackServerStreamingHandler; +template +class CallbackBidiHandler; + class ServerReactor { public: virtual ~ServerReactor() = default; virtual void OnDone() = 0; virtual void OnCancel() = 0; + + private: + friend class ::grpc::ServerContext; + template + friend class CallbackClientStreamingHandler; + template + friend class CallbackServerStreamingHandler; + template + friend class CallbackBidiHandler; + + // The ServerReactor is responsible for tracking when it is safe to call + // OnCancel. This function should not be called until after OnStarted is done + // and the RPC has completed with a cancellation. This is tracked by counting + // how many of these conditions have been met and calling OnCancel when none + // remain unmet. + + void MaybeCallOnCancel() { + if (on_cancel_conditions_remaining_.fetch_sub( + 1, std::memory_order_acq_rel) == 1) { + OnCancel(); + } + } + + std::atomic_int on_cancel_conditions_remaining_{2}; }; } // namespace internal @@ -590,6 +622,8 @@ class CallbackClientStreamingHandler : public MethodHandler { reader->BindReactor(reactor); reactor->OnStarted(param.server_context, reader->response()); + // The earliest that OnCancel can be called is after OnStarted is done. + reactor->MaybeCallOnCancel(); reader->MaybeDone(); } @@ -732,6 +766,8 @@ class CallbackServerStreamingHandler : public MethodHandler { std::move(param.call_requester), reactor); writer->BindReactor(reactor); reactor->OnStarted(param.server_context, writer->request()); + // The earliest that OnCancel can be called is after OnStarted is done. + reactor->MaybeCallOnCancel(); writer->MaybeDone(); } @@ -908,6 +944,8 @@ class CallbackBidiHandler : public MethodHandler { stream->BindReactor(reactor); reactor->OnStarted(param.server_context); + // The earliest that OnCancel can be called is after OnStarted is done. + reactor->MaybeCallOnCancel(); stream->MaybeDone(); } diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index 70c5cd8861e..eced89d1a79 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -210,17 +210,20 @@ bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { bool call_cancel = (cancelled_ != 0); // If it's a unary cancel callback, call it under the lock so that it doesn't - // race with ClearCancelCallback + // race with ClearCancelCallback. Although we don't normally call callbacks + // under a lock, this is a special case since the user needs a guarantee that + // the callback won't issue or run after ClearCancelCallback has returned. + // This requirement imposes certain restrictions on the callback, documented + // in the API comments of SetCancelCallback. if (cancel_callback_) { cancel_callback_(); } - // Release the lock since we are going to be calling a callback and - // interceptors now + // Release the lock since we may call a callback and interceptors now. lock.Unlock(); if (call_cancel && reactor_ != nullptr) { - reactor_->OnCancel(); + reactor_->MaybeCallOnCancel(); } /* Add interception point and run through interceptors */ interceptor_methods_.AddInterceptionHookPoint( diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index 40023c72f62..fb951fd44e6 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -1420,18 +1420,18 @@ TEST_P(End2endTest, DelayedRpcLateCanceledUsingCancelCallback) { EchoResponse response; request.set_message("Hello"); request.mutable_param()->set_skip_cancelled_check(true); - // Let server sleep for 80 ms first to give the cancellation a chance. - // This is split into 40 ms to start the cancel and 40 ms extra time for + // Let server sleep for 200 ms first to give the cancellation a chance. + // This is split into 100 ms to start the cancel and 100 ms extra time for // it to make it to the server, to make it highly probable that the server // RPC would have already started by the time the cancellation is sent // and the server-side gets enough time to react to it. - request.mutable_param()->set_server_sleep_us(80 * 1000); + request.mutable_param()->set_server_sleep_us(200000); std::thread echo_thread{[this, &context, &request, &response] { Status s = stub_->Echo(&context, request, &response); EXPECT_EQ(StatusCode::CANCELLED, s.error_code()); }}; - std::this_thread::sleep_for(std::chrono::microseconds(40000)); + std::this_thread::sleep_for(std::chrono::microseconds(100000)); context.TryCancel(); echo_thread.join(); } diff --git a/test/cpp/end2end/test_service_impl.cc b/test/cpp/end2end/test_service_impl.cc index 1cbbc703076..048715300ad 100644 --- a/test/cpp/end2end/test_service_impl.cc +++ b/test/cpp/end2end/test_service_impl.cc @@ -589,8 +589,9 @@ CallbackTestServiceImpl::RequestStream() { public: Reactor() {} void OnStarted(ServerContext* context, EchoResponse* response) override { - ctx_ = context; - response_ = response; + // Assign ctx_ and response_ as late as possible to increase likelihood of + // catching any races + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by // the server by calling ServerContext::TryCancel() depending on the // value: @@ -602,22 +603,26 @@ CallbackTestServiceImpl::RequestStream() { server_try_cancel_ = GetIntValueFromMetadata( kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); - response_->set_message(""); + response->set_message(""); if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) { - ServerTryCancelNonblocking(ctx_); - return; - } - - if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { - ctx_->TryCancel(); - // Don't wait for it here + ServerTryCancelNonblocking(context); + ctx_ = context; + } else { + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + context->TryCancel(); + // Don't wait for it here + } + ctx_ = context; + response_ = response; + StartRead(&request_); } - StartRead(&request_); + on_started_done_ = true; } void OnDone() override { delete this; } void OnCancel() override { + EXPECT_TRUE(on_started_done_); EXPECT_TRUE(ctx_->IsCancelled()); FinishOnce(Status::CANCELLED); } @@ -657,6 +662,7 @@ CallbackTestServiceImpl::RequestStream() { int server_try_cancel_; std::mutex finish_mu_; bool finished_{false}; + bool on_started_done_{false}; }; return new Reactor; @@ -673,8 +679,9 @@ CallbackTestServiceImpl::ResponseStream() { Reactor() {} void OnStarted(ServerContext* context, const EchoRequest* request) override { - ctx_ = context; - request_ = request; + // Assign ctx_ and request_ as late as possible to increase likelihood of + // catching any races + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by // the server by calling ServerContext::TryCancel() depending on the // value: @@ -691,19 +698,23 @@ CallbackTestServiceImpl::ResponseStream() { kServerResponseStreamsToSend, context->client_metadata(), kServerDefaultResponseStreamsToSend); if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) { - ServerTryCancelNonblocking(ctx_); - return; - } - - if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { - ctx_->TryCancel(); - } - if (num_msgs_sent_ < server_responses_to_send_) { - NextWrite(); + ServerTryCancelNonblocking(context); + ctx_ = context; + } else { + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + context->TryCancel(); + } + ctx_ = context; + request_ = request; + if (num_msgs_sent_ < server_responses_to_send_) { + NextWrite(); + } } + on_started_done_ = true; } void OnDone() override { delete this; } void OnCancel() override { + EXPECT_TRUE(on_started_done_); EXPECT_TRUE(ctx_->IsCancelled()); FinishOnce(Status::CANCELLED); } @@ -753,6 +764,7 @@ CallbackTestServiceImpl::ResponseStream() { int server_responses_to_send_; std::mutex finish_mu_; bool finished_{false}; + bool on_started_done_{false}; }; return new Reactor; } @@ -764,7 +776,9 @@ CallbackTestServiceImpl::BidiStream() { public: Reactor() {} void OnStarted(ServerContext* context) override { - ctx_ = context; + // Assign ctx_ as late as possible to increase likelihood of catching any + // races + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by // the server by calling ServerContext::TryCancel() depending on the // value: @@ -778,18 +792,20 @@ CallbackTestServiceImpl::BidiStream() { server_write_last_ = GetIntValueFromMetadata( kServerFinishAfterNReads, context->client_metadata(), 0); if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) { - ServerTryCancelNonblocking(ctx_); - return; - } - - if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { - ctx_->TryCancel(); + ServerTryCancelNonblocking(context); + ctx_ = context; + } else { + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + context->TryCancel(); + } + ctx_ = context; + StartRead(&request_); } - - StartRead(&request_); + on_started_done_ = true; } void OnDone() override { delete this; } void OnCancel() override { + EXPECT_TRUE(on_started_done_); EXPECT_TRUE(ctx_->IsCancelled()); FinishOnce(Status::CANCELLED); } @@ -839,6 +855,7 @@ CallbackTestServiceImpl::BidiStream() { int server_write_last_; std::mutex finish_mu_; bool finished_{false}; + bool on_started_done_{false}; }; return new Reactor;