Fix race at server shutdown between actual shutdown and MatchOrQueue (#25541)

* Fix race at server shutdown between actual shutdown and MatchOrQueue

* Address reviewer comments

* Add thread safety annotations

* Address reviewer comments
pull/25512/head
Vijay Pai 4 years ago committed by GitHub
parent fe37853055
commit 37bd0a0cbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 69
      src/core/lib/surface/server.cc
  2. 64
      src/core/lib/surface/server.h
  3. 65
      src/cpp/server/server_cc.cc

@ -318,7 +318,8 @@ class Server::RealRequestMatcher : public RequestMatcherInterface {
// advance or queue up any incoming RPC for later match. Instead, MatchOrQueue // advance or queue up any incoming RPC for later match. Instead, MatchOrQueue
// will call out to an allocation function passed in at the construction of the // will call out to an allocation function passed in at the construction of the
// object. These request matchers are designed for the C++ callback API, so they // object. These request matchers are designed for the C++ callback API, so they
// only support 1 completion queue (passed in at the constructor). // only support 1 completion queue (passed in at the constructor). They are also
// used for the sync API.
class Server::AllocatingRequestMatcherBase : public RequestMatcherInterface { class Server::AllocatingRequestMatcherBase : public RequestMatcherInterface {
public: public:
AllocatingRequestMatcherBase(Server* server, grpc_completion_queue* cq) AllocatingRequestMatcherBase(Server* server, grpc_completion_queue* cq)
@ -370,15 +371,20 @@ class Server::AllocatingRequestMatcherBatch
void MatchOrQueue(size_t /*start_request_queue_index*/, void MatchOrQueue(size_t /*start_request_queue_index*/,
CallData* calld) override { CallData* calld) override {
BatchCallAllocation call_info = allocator_(); if (server()->ShutdownRefOnRequest()) {
GPR_ASSERT(server()->ValidateServerRequest( BatchCallAllocation call_info = allocator_();
cq(), static_cast<void*>(call_info.tag), nullptr, nullptr) == GPR_ASSERT(server()->ValidateServerRequest(
GRPC_CALL_OK); cq(), static_cast<void*>(call_info.tag), nullptr,
RequestedCall* rc = new RequestedCall( nullptr) == GRPC_CALL_OK);
static_cast<void*>(call_info.tag), call_info.cq, call_info.call, RequestedCall* rc = new RequestedCall(
call_info.initial_metadata, call_info.details); static_cast<void*>(call_info.tag), call_info.cq, call_info.call,
calld->SetState(CallData::CallState::ACTIVATED); call_info.initial_metadata, call_info.details);
calld->Publish(cq_idx(), rc); calld->SetState(CallData::CallState::ACTIVATED);
calld->Publish(cq_idx(), rc);
} else {
calld->FailCallCreation();
}
server()->ShutdownUnrefOnRequest();
} }
private: private:
@ -398,15 +404,21 @@ class Server::AllocatingRequestMatcherRegistered
void MatchOrQueue(size_t /*start_request_queue_index*/, void MatchOrQueue(size_t /*start_request_queue_index*/,
CallData* calld) override { CallData* calld) override {
RegisteredCallAllocation call_info = allocator_(); if (server()->ShutdownRefOnRequest()) {
GPR_ASSERT(server()->ValidateServerRequest( RegisteredCallAllocation call_info = allocator_();
cq(), call_info.tag, call_info.optional_payload, GPR_ASSERT(server()->ValidateServerRequest(
registered_method_) == GRPC_CALL_OK); cq(), call_info.tag, call_info.optional_payload,
RequestedCall* rc = new RequestedCall( registered_method_) == GRPC_CALL_OK);
call_info.tag, call_info.cq, call_info.call, call_info.initial_metadata, RequestedCall* rc =
registered_method_, call_info.deadline, call_info.optional_payload); new RequestedCall(call_info.tag, call_info.cq, call_info.call,
calld->SetState(CallData::CallState::ACTIVATED); call_info.initial_metadata, registered_method_,
calld->Publish(cq_idx(), rc); call_info.deadline, call_info.optional_payload);
calld->SetState(CallData::CallState::ACTIVATED);
calld->Publish(cq_idx(), rc);
} else {
calld->FailCallCreation();
}
server()->ShutdownUnrefOnRequest();
} }
private: private:
@ -709,7 +721,7 @@ void Server::FailCall(size_t cq_idx, RequestedCall* rc, grpc_error* error) {
// Before calling MaybeFinishShutdown(), we must hold mu_global_ and not // Before calling MaybeFinishShutdown(), we must hold mu_global_ and not
// hold mu_call_. // hold mu_call_.
void Server::MaybeFinishShutdown() { void Server::MaybeFinishShutdown() {
if (!shutdown_flag_.load(std::memory_order_acquire) || shutdown_published_) { if (!ShutdownReady() || shutdown_published_) {
return; return;
} }
{ {
@ -803,19 +815,18 @@ void Server::ShutdownAndNotify(grpc_completion_queue* cq, void* tag) {
return; return;
} }
shutdown_tags_.emplace_back(tag, cq); shutdown_tags_.emplace_back(tag, cq);
if (shutdown_flag_.load(std::memory_order_acquire)) { if (ShutdownCalled()) {
return; return;
} }
last_shutdown_message_time_ = gpr_now(GPR_CLOCK_REALTIME); last_shutdown_message_time_ = gpr_now(GPR_CLOCK_REALTIME);
broadcaster.FillChannelsLocked(GetChannelsLocked()); broadcaster.FillChannelsLocked(GetChannelsLocked());
shutdown_flag_.store(true, std::memory_order_release);
// Collect all unregistered then registered calls. // Collect all unregistered then registered calls.
{ {
MutexLock lock(&mu_call_); MutexLock lock(&mu_call_);
KillPendingWorkLocked( KillPendingWorkLocked(
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server Shutdown")); GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server Shutdown"));
} }
MaybeFinishShutdown(); ShutdownUnrefOnShutdownCall();
} }
// Shutdown listeners. // Shutdown listeners.
for (auto& listener : listeners_) { for (auto& listener : listeners_) {
@ -847,8 +858,7 @@ void Server::CancelAllCalls() {
void Server::Orphan() { void Server::Orphan() {
{ {
MutexLock lock(&mu_global_); MutexLock lock(&mu_global_);
GPR_ASSERT(shutdown_flag_.load(std::memory_order_acquire) || GPR_ASSERT(ShutdownCalled() || listeners_.empty());
listeners_.empty());
GPR_ASSERT(listeners_destroyed_ == listeners_.size()); GPR_ASSERT(listeners_destroyed_ == listeners_.size());
} }
if (default_resource_user_ != nullptr) { if (default_resource_user_ != nullptr) {
@ -895,7 +905,7 @@ grpc_call_error Server::ValidateServerRequestAndCq(
} }
grpc_call_error Server::QueueRequestedCall(size_t cq_idx, RequestedCall* rc) { grpc_call_error Server::QueueRequestedCall(size_t cq_idx, RequestedCall* rc) {
if (shutdown_flag_.load(std::memory_order_acquire)) { if (ShutdownCalled()) {
FailCall(cq_idx, rc, FailCall(cq_idx, rc,
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server Shutdown")); GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server Shutdown"));
return GRPC_CALL_OK; return GRPC_CALL_OK;
@ -1064,7 +1074,7 @@ void Server::ChannelData::InitTransport(RefCountedPtr<Server> server,
op->set_accept_stream_fn = AcceptStream; op->set_accept_stream_fn = AcceptStream;
op->set_accept_stream_user_data = this; op->set_accept_stream_user_data = this;
op->start_connectivity_watch = MakeOrphanable<ConnectivityWatcher>(this); op->start_connectivity_watch = MakeOrphanable<ConnectivityWatcher>(this);
if (server_->shutdown_flag_.load(std::memory_order_acquire)) { if (server_->ShutdownCalled()) {
op->disconnect_with_error = op->disconnect_with_error =
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server shutdown"); GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server shutdown");
} }
@ -1280,8 +1290,7 @@ void Server::CallData::PublishNewRpc(void* arg, grpc_error* error) {
auto* chand = static_cast<Server::ChannelData*>(call_elem->channel_data); auto* chand = static_cast<Server::ChannelData*>(call_elem->channel_data);
RequestMatcherInterface* rm = calld->matcher_; RequestMatcherInterface* rm = calld->matcher_;
Server* server = rm->server(); Server* server = rm->server();
if (error != GRPC_ERROR_NONE || if (error != GRPC_ERROR_NONE || server->ShutdownCalled()) {
server->shutdown_flag_.load(std::memory_order_acquire)) {
calld->state_.Store(CallState::ZOMBIED, MemoryOrder::RELAXED); calld->state_.Store(CallState::ZOMBIED, MemoryOrder::RELAXED);
calld->KillZombie(); calld->KillZombie();
return; return;
@ -1305,7 +1314,7 @@ void Server::CallData::KillZombie() {
void Server::CallData::StartNewRpc(grpc_call_element* elem) { void Server::CallData::StartNewRpc(grpc_call_element* elem) {
auto* chand = static_cast<ChannelData*>(elem->channel_data); auto* chand = static_cast<ChannelData*>(elem->channel_data);
if (server_->shutdown_flag_.load(std::memory_order_acquire)) { if (server_->ShutdownCalled()) {
state_.Store(CallState::ZOMBIED, MemoryOrder::RELAXED); state_.Store(CallState::ZOMBIED, MemoryOrder::RELAXED);
KillZombie(); KillZombie();
return; return;

@ -92,7 +92,7 @@ class Server : public InternallyRefCounted<Server> {
explicit Server(const grpc_channel_args* args); explicit Server(const grpc_channel_args* args);
~Server() override; ~Server() override;
void Orphan() override; void Orphan() ABSL_LOCKS_EXCLUDED(mu_global_) override;
const grpc_channel_args* channel_args() const { return channel_args_; } const grpc_channel_args* channel_args() const { return channel_args_; }
grpc_resource_user* default_resource_user() const { grpc_resource_user* default_resource_user() const {
@ -114,7 +114,7 @@ class Server : public InternallyRefCounted<Server> {
config_fetcher_ = std::move(config_fetcher); config_fetcher_ = std::move(config_fetcher);
} }
bool HasOpenConnections(); bool HasOpenConnections() ABSL_LOCKS_EXCLUDED(mu_global_);
// Adds a listener to the server. When the server starts, it will call // Adds a listener to the server. When the server starts, it will call
// the listener's Start() method, and when it shuts down, it will orphan // the listener's Start() method, and when it shuts down, it will orphan
@ -122,7 +122,7 @@ class Server : public InternallyRefCounted<Server> {
void AddListener(OrphanablePtr<ListenerInterface> listener); void AddListener(OrphanablePtr<ListenerInterface> listener);
// Starts listening for connections. // Starts listening for connections.
void Start(); void Start() ABSL_LOCKS_EXCLUDED(mu_global_);
// Sets up a transport. Creates a channel stack and binds the transport to // Sets up a transport. Creates a channel stack and binds the transport to
// the server. Called from the listener when a new connection is accepted. // the server. Called from the listener when a new connection is accepted.
@ -160,9 +160,10 @@ class Server : public InternallyRefCounted<Server> {
grpc_completion_queue* cq_bound_to_call, grpc_completion_queue* cq_bound_to_call,
grpc_completion_queue* cq_for_notification, void* tag_new); grpc_completion_queue* cq_for_notification, void* tag_new);
void ShutdownAndNotify(grpc_completion_queue* cq, void* tag); void ShutdownAndNotify(grpc_completion_queue* cq, void* tag)
ABSL_LOCKS_EXCLUDED(mu_global_, mu_call_);
void CancelAllCalls(); void CancelAllCalls() ABSL_LOCKS_EXCLUDED(mu_global_);
private: private:
struct RequestedCall; struct RequestedCall;
@ -209,7 +210,7 @@ class Server : public InternallyRefCounted<Server> {
static void AcceptStream(void* arg, grpc_transport* /*transport*/, static void AcceptStream(void* arg, grpc_transport* /*transport*/,
const void* transport_server_data); const void* transport_server_data);
void Destroy(); void Destroy() ABSL_EXCLUSIVE_LOCKS_REQUIRED(server_->mu_global_);
static void FinishDestroy(void* arg, grpc_error* error); static void FinishDestroy(void* arg, grpc_error* error);
@ -345,9 +346,11 @@ class Server : public InternallyRefCounted<Server> {
void FailCall(size_t cq_idx, RequestedCall* rc, grpc_error* error); void FailCall(size_t cq_idx, RequestedCall* rc, grpc_error* error);
grpc_call_error QueueRequestedCall(size_t cq_idx, RequestedCall* rc); grpc_call_error QueueRequestedCall(size_t cq_idx, RequestedCall* rc);
void MaybeFinishShutdown(); void MaybeFinishShutdown() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_global_)
ABSL_LOCKS_EXCLUDED(mu_call_);
void KillPendingWorkLocked(grpc_error* error); void KillPendingWorkLocked(grpc_error* error)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_call_);
static grpc_call_error ValidateServerRequest( static grpc_call_error ValidateServerRequest(
grpc_completion_queue* cq_for_notification, void* tag, grpc_completion_queue* cq_for_notification, void* tag,
@ -358,6 +361,39 @@ class Server : public InternallyRefCounted<Server> {
std::vector<grpc_channel*> GetChannelsLocked() const; std::vector<grpc_channel*> GetChannelsLocked() const;
// Take a shutdown ref for a request (increment by 2) and return if shutdown
// has already been called.
bool ShutdownRefOnRequest() {
int old_value = shutdown_refs_.FetchAdd(2, MemoryOrder::ACQ_REL);
return (old_value & 1) != 0;
}
// Decrement the shutdown ref counter by either 1 (for shutdown call) or 2
// (for in-flight request) and possibly call MaybeFinishShutdown if
// appropriate.
void ShutdownUnrefOnRequest() ABSL_LOCKS_EXCLUDED(mu_global_) {
if (shutdown_refs_.FetchSub(2, MemoryOrder::ACQ_REL) == 2) {
MutexLock lock(&mu_global_);
MaybeFinishShutdown();
}
}
void ShutdownUnrefOnShutdownCall() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_global_) {
if (shutdown_refs_.FetchSub(1, MemoryOrder::ACQ_REL) == 1) {
MaybeFinishShutdown();
}
}
bool ShutdownCalled() const {
return (shutdown_refs_.Load(MemoryOrder::ACQUIRE) & 1) == 0;
}
// Returns whether there are no more shutdown refs, which means that shutdown
// has been called and all accepted requests have been published if using an
// AllocatingRequestMatcher.
bool ShutdownReady() const {
return shutdown_refs_.Load(MemoryOrder::ACQUIRE) == 0;
}
grpc_channel_args* const channel_args_; grpc_channel_args* const channel_args_;
grpc_resource_user* default_resource_user_ = nullptr; grpc_resource_user* default_resource_user_ = nullptr;
RefCountedPtr<channelz::ServerNode> channelz_node_; RefCountedPtr<channelz::ServerNode> channelz_node_;
@ -387,9 +423,15 @@ class Server : public InternallyRefCounted<Server> {
// Request matcher for unregistered methods. // Request matcher for unregistered methods.
std::unique_ptr<RequestMatcherInterface> unregistered_request_matcher_; std::unique_ptr<RequestMatcherInterface> unregistered_request_matcher_;
std::atomic_bool shutdown_flag_{false}; // The shutdown refs counter tracks whether or not shutdown has been called
bool shutdown_published_ = false; // and whether there are any AllocatingRequestMatcher requests that have been
std::vector<ShutdownTag> shutdown_tags_; // accepted but not yet started (+2 on each one). If shutdown has been called,
// the lowest bit will be 0 (defaults to 1) and the counter will be even. The
// server should not notify on shutdown until the counter is 0 (shutdown is
// called and there are no requests that are accepted but not started).
Atomic<int> shutdown_refs_{1};
bool shutdown_published_ ABSL_GUARDED_BY(mu_global_) = false;
std::vector<ShutdownTag> shutdown_tags_ ABSL_GUARDED_BY(mu_global_);
std::list<ChannelData*> channels_; std::list<ChannelData*> channels_;

@ -356,17 +356,18 @@ class Server::SyncRequest final : public grpc::internal::CompletionQueueTag {
} }
~SyncRequest() override { ~SyncRequest() override {
// The destructor should only cleanup those objects created in the
// constructor, since some paths may or may not actually go through the
// Run stage where other objects are allocated.
if (has_request_payload_ && request_payload_) { if (has_request_payload_ && request_payload_) {
grpc_byte_buffer_destroy(request_payload_); grpc_byte_buffer_destroy(request_payload_);
} }
wrapped_call_.Destroy();
ctx_.Destroy();
if (call_details_ != nullptr) { if (call_details_ != nullptr) {
grpc_call_details_destroy(call_details_); grpc_call_details_destroy(call_details_);
delete call_details_; delete call_details_;
} }
grpc_metadata_array_destroy(&request_metadata_); grpc_metadata_array_destroy(&request_metadata_);
server_->UnrefWithPossibleNotify();
} }
bool FinalizeResult(void** /*tag*/, bool* status) override { bool FinalizeResult(void** /*tag*/, bool* status) override {
@ -424,26 +425,35 @@ class Server::SyncRequest final : public grpc::internal::CompletionQueueTag {
} }
void ContinueRunAfterInterception() { void ContinueRunAfterInterception() {
{ ctx_->ctx.BeginCompletionOp(&*wrapped_call_, nullptr, nullptr);
ctx_->ctx.BeginCompletionOp(&*wrapped_call_, nullptr, nullptr); global_callbacks_->PreSynchronousRequest(&ctx_->ctx);
global_callbacks_->PreSynchronousRequest(&ctx_->ctx); auto* handler = resources_ ? method_->handler()
auto* handler = resources_ ? method_->handler() : server_->resource_exhausted_handler_.get();
: server_->resource_exhausted_handler_.get(); handler->RunHandler(grpc::internal::MethodHandler::HandlerParameter(
handler->RunHandler(grpc::internal::MethodHandler::HandlerParameter( &*wrapped_call_, &ctx_->ctx, deserialized_request_, request_status_,
&*wrapped_call_, &ctx_->ctx, deserialized_request_, request_status_, nullptr, nullptr));
nullptr, nullptr)); global_callbacks_->PostSynchronousRequest(&ctx_->ctx);
global_callbacks_->PostSynchronousRequest(&ctx_->ctx);
cq_.Shutdown(); cq_.Shutdown();
grpc::internal::CompletionQueueTag* op_tag = grpc::internal::CompletionQueueTag* op_tag = ctx_->ctx.GetCompletionOpTag();
ctx_->ctx.GetCompletionOpTag(); cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME));
cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME));
/* Ensure the cq_ is shutdown */ // Ensure the cq_ is shutdown
grpc::PhonyTag ignored_tag; grpc::PhonyTag ignored_tag;
GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); GPR_ASSERT(cq_.Pluck(&ignored_tag) == false);
}
// Cleanup structures allocated during Run/ContinueRunAfterInterception
wrapped_call_.Destroy();
ctx_.Destroy();
delete this;
}
// For requests that must be only cleaned up but not actually Run
void Cleanup() {
cq_.Shutdown();
grpc_call_unref(call_);
delete this; delete this;
} }
@ -459,6 +469,7 @@ class Server::SyncRequest final : public grpc::internal::CompletionQueueTag {
template <class CallAllocation> template <class CallAllocation>
void CommonSetup(CallAllocation* data) { void CommonSetup(CallAllocation* data) {
server_->Ref();
grpc_metadata_array_init(&request_metadata_); grpc_metadata_array_init(&request_metadata_);
data->tag = static_cast<void*>(this); data->tag = static_cast<void*>(this);
data->call = &call_; data->call = &call_;
@ -473,7 +484,7 @@ class Server::SyncRequest final : public grpc::internal::CompletionQueueTag {
grpc_call_details* call_details_ = nullptr; grpc_call_details* call_details_ = nullptr;
gpr_timespec deadline_; gpr_timespec deadline_;
grpc_metadata_array request_metadata_; grpc_metadata_array request_metadata_;
grpc_byte_buffer* request_payload_; grpc_byte_buffer* request_payload_ = nullptr;
grpc::CompletionQueue cq_; grpc::CompletionQueue cq_;
grpc::Status request_status_; grpc::Status request_status_;
std::shared_ptr<GlobalCallbacks> global_callbacks_; std::shared_ptr<GlobalCallbacks> global_callbacks_;
@ -812,9 +823,9 @@ class Server::SyncRequestThreadManager : public grpc::ThreadManager {
void* tag; void* tag;
bool ok; bool ok;
while (server_cq_->Next(&tag, &ok)) { while (server_cq_->Next(&tag, &ok)) {
// Drain the item and don't do any work on it. It is possible to see this // This problem can arise if the server CQ gets a request queued to it
// if there is an explicit call to Wait that is not part of the actual // before it gets shutdown but then pulls it after shutdown.
// Shutdown. static_cast<SyncRequest*>(tag)->Cleanup();
} }
} }
@ -1228,6 +1239,9 @@ void Server::ShutdownInternal(gpr_timespec deadline) {
// Else in case of SHUTDOWN or GOT_EVENT, it means that the server has // Else in case of SHUTDOWN or GOT_EVENT, it means that the server has
// successfully shutdown // successfully shutdown
// Drop the shutdown ref and wait for all other refs to drop as well.
UnrefAndWaitLocked();
// Shutdown all ThreadManagers. This will try to gracefully stop all the // Shutdown all ThreadManagers. This will try to gracefully stop all the
// threads in the ThreadManagers (once they process any inflight requests) // threads in the ThreadManagers (once they process any inflight requests)
for (const auto& value : sync_req_mgrs_) { for (const auto& value : sync_req_mgrs_) {
@ -1239,9 +1253,6 @@ void Server::ShutdownInternal(gpr_timespec deadline) {
value->Wait(); value->Wait();
} }
// Drop the shutdown ref and wait for all other refs to drop as well.
UnrefAndWaitLocked();
// Shutdown the callback CQ. The CQ is owned by its own shutdown tag, so it // Shutdown the callback CQ. The CQ is owned by its own shutdown tag, so it
// will delete itself at true shutdown. // will delete itself at true shutdown.
if (callback_cq_ != nullptr) { if (callback_cq_ != nullptr) {

Loading…
Cancel
Save