diff --git a/include/grpcpp/impl/codegen/server_callback.h b/include/grpcpp/impl/codegen/server_callback.h index 60c308b22e7..274a00e0556 100644 --- a/include/grpcpp/impl/codegen/server_callback.h +++ b/include/grpcpp/impl/codegen/server_callback.h @@ -69,6 +69,31 @@ class ServerCallbackRpcController { // Allow the method handler to push out the initial metadata before // the response and status are ready virtual void SendInitialMetadata(std::function) = 0; + + /// SetCancelCallback passes in a callback to be called when the RPC is + /// canceled for whatever reason (streaming calls have OnCancel instead). This + /// is an advanced and uncommon use with several important restrictions. + /// + /// If code calls SetCancelCallback on an RPC, it must also call + /// ClearCancelCallback before calling Finish on the RPC controller. + /// + /// The callback should generally be lightweight and nonblocking and primarily + /// concerned with clearing application state related to the RPC or causing + /// operations (such as cancellations) to happen on dependent RPCs. + /// + /// If the RPC is already canceled at the time that SetCancelCallback is + /// called, the callback is invoked immediately. + /// + /// The cancellation callback may be executed concurrently with the method + /// handler that invokes it but will certainly not issue or execute after the + /// return of ClearCancelCallback. + /// + /// The callback is called under a lock that is also used for + /// ClearCancelCallback and ServerContext::IsCancelled, so the callback CANNOT + /// call either of those operations on this RPC or any other function that + /// causes those operations to be called before the callback completes. + virtual void SetCancelCallback(std::function callback) = 0; + virtual void ClearCancelCallback() = 0; }; // NOTE: The actual streaming object classes are provided @@ -349,6 +374,15 @@ class CallbackUnaryHandler : public MethodHandler { call_.PerformOps(&meta_ops_); } + // Neither SetCancelCallback nor ClearCancelCallback should affect the + // callbacks_outstanding_ count since they are paired and both must precede + // the invocation of Finish (if they are used at all) + void SetCancelCallback(std::function callback) override { + ctx_->SetCancelCallback(std::move(callback)); + } + + void ClearCancelCallback() override { ctx_->ClearCancelCallback(); } + private: friend class CallbackUnaryHandler; diff --git a/include/grpcpp/impl/codegen/server_context.h b/include/grpcpp/impl/codegen/server_context.h index fb82186d69e..591a9ff9549 100644 --- a/include/grpcpp/impl/codegen/server_context.h +++ b/include/grpcpp/impl/codegen/server_context.h @@ -329,6 +329,9 @@ class ServerContext { uint32_t initial_metadata_flags() const { return 0; } + void SetCancelCallback(std::function callback); + void ClearCancelCallback(); + experimental::ServerRpcInfo* set_server_rpc_info( const char* method, internal::RpcMethod::RpcType type, const std::vector< diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index d38b46822ae..73fd6a62c48 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -95,6 +95,22 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { tag_ = tag; } + void SetCancelCallback(std::function callback) { + std::lock_guard lock(mu_); + + if (finalized_ && (cancelled_ != 0)) { + callback(); + return; + } + + cancel_callback_ = std::move(callback); + } + + void ClearCancelCallback() { + std::lock_guard g(mu_); + cancel_callback_ = nullptr; + } + void set_core_cq_tag(void* core_cq_tag) { core_cq_tag_ = core_cq_tag; } void* core_cq_tag() override { return core_cq_tag_; } @@ -141,6 +157,7 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { std::mutex mu_; bool finalized_; int cancelled_; // This is an int (not bool) because it is passed to core + std::function cancel_callback_; bool done_intercepting_; internal::InterceptorBatchMethodsImpl interceptor_methods_; }; @@ -191,11 +208,17 @@ bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { // Decide whether to call the cancel callback before releasing the lock bool call_cancel = (cancelled_ != 0); + // If it's a unary cancel callback, call it under the lock so that it doesn't + // race with ClearCancelCallback + if (cancel_callback_) { + cancel_callback_(); + } + // Release the lock since we are going to be calling a callback and // interceptors now lock.unlock(); - if (call_cancel && (reactor_ != nullptr)) { + if (call_cancel && reactor_ != nullptr) { reactor_->OnCancel(); } @@ -315,6 +338,14 @@ void ServerContext::TryCancel() const { } } +void ServerContext::SetCancelCallback(std::function callback) { + completion_op_->SetCancelCallback(std::move(callback)); +} + +void ServerContext::ClearCancelCallback() { + completion_op_->ClearCancelCallback(); +} + bool ServerContext::IsCancelled() const { if (completion_tag_) { // When using callback API, this result is always valid. diff --git a/test/cpp/end2end/BUILD b/test/cpp/end2end/BUILD index de7725d163d..a51f833a3e0 100644 --- a/test/cpp/end2end/BUILD +++ b/test/cpp/end2end/BUILD @@ -89,6 +89,7 @@ grpc_cc_test( external_deps = [ "gtest", ], + tags = ["no_windows"], deps = [ ":test_service_impl", "//:gpr", @@ -245,6 +246,9 @@ grpc_cc_test( size = "large", deps = [ ":end2end_test_lib", + # DO NOT REMOVE THE grpc++ dependence below since the internal build + # system uses it to specialize targets + "//:grpc++", ], ) @@ -620,6 +624,7 @@ grpc_cc_test( external_deps = [ "gtest", ], + tags = ["no_windows"], deps = [ "//:gpr", "//:grpc", diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index f58a472bfaf..1726a7b189a 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -1381,6 +1381,61 @@ TEST_P(End2endTest, ExpectErrorTest) { } } +TEST_P(End2endTest, DelayedRpcCanceledUsingCancelCallback) { + MAYBE_SKIP_TEST; + // This test case is only relevant with callback server. + // Additionally, using interceptors makes this test subject to + // timing-dependent failures if the interceptors take too long to run. + if (!GetParam().callback_server || GetParam().use_interceptors) { + return; + } + + ResetStub(); + ClientContext context; + context.AddMetadata(kServerUseCancelCallback, + grpc::to_string(MAYBE_USE_CALLBACK_CANCEL)); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + request.mutable_param()->set_skip_cancelled_check(true); + // Let server sleep for 40 ms first to give the cancellation a chance. + // 40 ms might seem a bit extreme but the timer manager would have been just + // initialized (when ResetStub() was called) and there are some warmup costs + // i.e the timer thread many not have even started. There might also be + // other delays in the timer manager thread (in acquiring locks, timer data + // structure manipulations, starting backup timer threads) that add to the + // delays. 40ms is still not enough in some cases but this significantly + // reduces the test flakes + request.mutable_param()->set_server_sleep_us(40 * 1000); + + std::thread echo_thread{[this, &context, &request, &response] { + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(StatusCode::CANCELLED, s.error_code()); + }}; + std::this_thread::sleep_for(std::chrono::microseconds(500)); + context.TryCancel(); + echo_thread.join(); +} + +TEST_P(End2endTest, DelayedRpcNonCanceledUsingCancelCallback) { + MAYBE_SKIP_TEST; + if (!GetParam().callback_server) { + return; + } + + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + context.AddMetadata(kServerUseCancelCallback, + grpc::to_string(MAYBE_USE_CALLBACK_NO_CANCEL)); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_TRUE(s.ok()); +} + ////////////////////////////////////////////////////////////////////////// // Test with and without a proxy. class ProxyEnd2endTest : public End2endTest { @@ -2015,7 +2070,7 @@ INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P( ProxyEnd2end, ProxyEnd2endTest, - ::testing::ValuesIn(CreateTestScenarios(true, true, true, true, false))); + ::testing::ValuesIn(CreateTestScenarios(true, true, true, true, true))); INSTANTIATE_TEST_CASE_P( SecureEnd2end, SecureEnd2endTest, @@ -2023,7 +2078,7 @@ INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P( ResourceQuotaEnd2end, ResourceQuotaEnd2endTest, - ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, false))); + ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true))); } // namespace } // namespace testing diff --git a/test/cpp/end2end/test_service_impl.cc b/test/cpp/end2end/test_service_impl.cc index 159ea33c2bc..afc0cb0d8fd 100644 --- a/test/cpp/end2end/test_service_impl.cc +++ b/test/cpp/end2end/test_service_impl.cc @@ -126,13 +126,14 @@ void ServerTryCancelNonblocking(ServerContext* context) { } void LoopUntilCancelled(Alarm* alarm, ServerContext* context, - experimental::ServerCallbackRpcController* controller) { + experimental::ServerCallbackRpcController* controller, + int loop_delay_us) { if (!context->IsCancelled()) { alarm->experimental().Set( gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), - gpr_time_from_micros(1000, GPR_TIMESPAN)), - [alarm, context, controller](bool) { - LoopUntilCancelled(alarm, context, controller); + gpr_time_from_micros(loop_delay_us, GPR_TIMESPAN)), + [alarm, context, controller, loop_delay_us](bool) { + LoopUntilCancelled(alarm, context, controller, loop_delay_us); }); } else { controller->Finish(Status::CANCELLED); @@ -249,6 +250,16 @@ Status TestServiceImpl::CheckClientInitialMetadata(ServerContext* context, void CallbackTestServiceImpl::Echo( ServerContext* context, const EchoRequest* request, EchoResponse* response, experimental::ServerCallbackRpcController* controller) { + CancelState* cancel_state = new CancelState; + int server_use_cancel_callback = + GetIntValueFromMetadata(kServerUseCancelCallback, + context->client_metadata(), DO_NOT_USE_CALLBACK); + if (server_use_cancel_callback != DO_NOT_USE_CALLBACK) { + controller->SetCancelCallback([cancel_state] { + EXPECT_FALSE(cancel_state->callback_invoked.exchange( + true, std::memory_order_relaxed)); + }); + } // A bit of sleep to make sure that short deadline tests fail if (request->has_param() && request->param().server_sleep_us() > 0) { // Set an alarm for that much time @@ -256,11 +267,11 @@ void CallbackTestServiceImpl::Echo( gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), gpr_time_from_micros(request->param().server_sleep_us(), GPR_TIMESPAN)), - [this, context, request, response, controller](bool) { - EchoNonDelayed(context, request, response, controller); + [this, context, request, response, controller, cancel_state](bool) { + EchoNonDelayed(context, request, response, controller, cancel_state); }); } else { - EchoNonDelayed(context, request, response, controller); + EchoNonDelayed(context, request, response, controller, cancel_state); } } @@ -279,7 +290,25 @@ void CallbackTestServiceImpl::CheckClientInitialMetadata( void CallbackTestServiceImpl::EchoNonDelayed( ServerContext* context, const EchoRequest* request, EchoResponse* response, - experimental::ServerCallbackRpcController* controller) { + experimental::ServerCallbackRpcController* controller, + CancelState* cancel_state) { + int server_use_cancel_callback = + GetIntValueFromMetadata(kServerUseCancelCallback, + context->client_metadata(), DO_NOT_USE_CALLBACK); + + // Safe to clear cancel callback even if it wasn't set + controller->ClearCancelCallback(); + if (server_use_cancel_callback == MAYBE_USE_CALLBACK_CANCEL) { + EXPECT_TRUE(context->IsCancelled()); + EXPECT_TRUE(cancel_state->callback_invoked.load(std::memory_order_relaxed)); + delete cancel_state; + controller->Finish(Status::CANCELLED); + return; + } + + EXPECT_FALSE(cancel_state->callback_invoked.load(std::memory_order_relaxed)); + delete cancel_state; + if (request->has_param() && request->param().server_die()) { gpr_log(GPR_ERROR, "The request should not reach application handler."); GPR_ASSERT(0); @@ -301,9 +330,11 @@ void CallbackTestServiceImpl::EchoNonDelayed( EXPECT_FALSE(context->IsCancelled()); context->TryCancel(); gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request"); - // Now wait until it's really canceled - LoopUntilCancelled(&alarm_, context, controller); + if (server_use_cancel_callback == DO_NOT_USE_CALLBACK) { + // Now wait until it's really canceled + LoopUntilCancelled(&alarm_, context, controller, 1000); + } return; } @@ -318,20 +349,11 @@ void CallbackTestServiceImpl::EchoNonDelayed( std::unique_lock lock(mu_); signal_client_ = true; } - std::function recurrence = [this, context, request, controller, - &recurrence](bool) { - if (!context->IsCancelled()) { - alarm_.experimental().Set( - gpr_time_add( - gpr_now(GPR_CLOCK_REALTIME), - gpr_time_from_micros(request->param().client_cancel_after_us(), - GPR_TIMESPAN)), - recurrence); - } else { - controller->Finish(Status::CANCELLED); - } - }; - recurrence(true); + if (server_use_cancel_callback == DO_NOT_USE_CALLBACK) { + // Now wait until it's really canceled + LoopUntilCancelled(&alarm_, context, controller, + request->param().client_cancel_after_us()); + } return; } else if (request->has_param() && request->param().server_cancel_after_us()) { diff --git a/test/cpp/end2end/test_service_impl.h b/test/cpp/end2end/test_service_impl.h index e36423d44e4..9a52bed1ea7 100644 --- a/test/cpp/end2end/test_service_impl.h +++ b/test/cpp/end2end/test_service_impl.h @@ -33,6 +33,7 @@ namespace testing { const int kServerDefaultResponseStreamsToSend = 3; const char* const kServerResponseStreamsToSend = "server_responses_to_send"; const char* const kServerTryCancelRequest = "server_try_cancel"; +const char* const kServerUseCancelCallback = "server_use_cancel_callback"; const char* const kDebugInfoTrailerKey = "debug-info-bin"; const char* const kServerFinishAfterNReads = "server_finish_after_n_reads"; const char* const kServerUseCoalescingApi = "server_use_coalescing_api"; @@ -46,6 +47,12 @@ typedef enum { CANCEL_AFTER_PROCESSING } ServerTryCancelRequestPhase; +typedef enum { + DO_NOT_USE_CALLBACK = 0, + MAYBE_USE_CALLBACK_CANCEL, + MAYBE_USE_CALLBACK_NO_CANCEL, +} ServerUseCancelCallback; + class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { public: TestServiceImpl() : signal_client_(false), host_() {} @@ -115,9 +122,13 @@ class CallbackTestServiceImpl } private: + struct CancelState { + std::atomic_bool callback_invoked{false}; + }; void EchoNonDelayed(ServerContext* context, const EchoRequest* request, EchoResponse* response, - experimental::ServerCallbackRpcController* controller); + experimental::ServerCallbackRpcController* controller, + CancelState* cancel_state); Alarm alarm_; bool signal_client_;