diff --git a/test/cpp/qps/client_async.cc b/test/cpp/qps/client_async.cc index 057e5a0d6be..a0705673bdc 100644 --- a/test/cpp/qps/client_async.cc +++ b/test/cpp/qps/client_async.cc @@ -174,6 +174,7 @@ class AsyncClient : public ClientImpl { for (int i = 0; i < num_async_threads_; i++) { cli_cqs_.emplace_back(new CompletionQueue); next_issuers_.emplace_back(NextIssuer(i)); + shutdown_state_.emplace_back(new PerThreadShutdownState()); } using namespace std::placeholders; @@ -189,7 +190,21 @@ class AsyncClient : public ClientImpl { } } virtual ~AsyncClient() { - FinalShutdownCQs(); + for (auto ss = shutdown_state_.begin(); ss != shutdown_state_.end(); ++ss) { + std::lock_guard lock((*ss)->mutex); + (*ss)->shutdown = true; + } + for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) { + (*cq)->Shutdown(); + } + this->EndThreads(); // Need "this->" for resolution + for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) { + void* got_tag; + bool ok; + while ((*cq)->Next(&got_tag, &ok)) { + delete ClientRpcContext::detag(got_tag); + } + } } bool ThreadFunc(HistogramEntry* entry, @@ -200,7 +215,12 @@ class AsyncClient : public ClientImpl { if (cli_cqs_[thread_idx]->Next(&got_tag, &ok)) { // Got a regular event, so process it ClientRpcContext* ctx = ClientRpcContext::detag(got_tag); - if (!ctx->RunNextState(ok, entry)) { + // Proceed while holding a lock to make sure that + // this thread isn't supposed to shut down + std::lock_guard l(shutdown_state_[thread_idx]->mutex); + if (shutdown_state_[thread_idx]->shutdown) { + return true; + } else if (!ctx->RunNextState(ok, entry)) { // The RPC and callback are done, so clone the ctx // and kickstart the new one auto clone = ctx->StartNewClone(); @@ -217,22 +237,13 @@ class AsyncClient : public ClientImpl { protected: const int num_async_threads_; - void ShutdownCQs() { - for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) { - (*cq)->Shutdown(); - } - } - void FinalShutdownCQs() { - for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) { - void* got_tag; - bool ok; - while ((*cq)->Next(&got_tag, &ok)) { - delete ClientRpcContext::detag(got_tag); - } - } - } - private: + struct PerThreadShutdownState { + mutable std::mutex mutex; + bool shutdown; + PerThreadShutdownState() : shutdown(false) {} + }; + int NumThreads(const ClientConfig& config) { int num_threads = config.async_client_threads(); if (num_threads <= 0) { // Use dynamic sizing @@ -241,9 +252,9 @@ class AsyncClient : public ClientImpl { } return num_threads; } - std::vector> cli_cqs_; std::vector> next_issuers_; + std::vector> shutdown_state_; }; static std::unique_ptr BenchmarkStubCreator( @@ -259,10 +270,7 @@ class AsyncUnaryClient GRPC_FINAL config, SetupCtx, BenchmarkStubCreator) { StartThreads(num_async_threads_); } - ~AsyncUnaryClient() GRPC_OVERRIDE { - ShutdownCQs(); - EndThreads(); - } + ~AsyncUnaryClient() GRPC_OVERRIDE {} private: static void CheckDone(grpc::Status s, SimpleResponse* response) {} @@ -391,10 +399,7 @@ class AsyncStreamingClient GRPC_FINAL StartThreads(num_async_threads_); } - ~AsyncStreamingClient() GRPC_OVERRIDE { - ShutdownCQs(); - EndThreads(); - } + ~AsyncStreamingClient() GRPC_OVERRIDE {} private: static void CheckDone(grpc::Status s, SimpleResponse* response) {} @@ -530,10 +535,7 @@ class GenericAsyncStreamingClient GRPC_FINAL StartThreads(num_async_threads_); } - ~GenericAsyncStreamingClient() GRPC_OVERRIDE { - ShutdownCQs(); - EndThreads(); - } + ~GenericAsyncStreamingClient() GRPC_OVERRIDE {} private: static void CheckDone(grpc::Status s, ByteBuffer* response) {} diff --git a/test/cpp/qps/server_async.cc b/test/cpp/qps/server_async.cc index c9954d0d02d..85acefa00b5 100644 --- a/test/cpp/qps/server_async.cc +++ b/test/cpp/qps/server_async.cc @@ -123,21 +123,22 @@ class AsyncQpsServerTest : public Server { for (int i = 0; i < num_threads; i++) { shutdown_state_.emplace_back(new PerThreadShutdownState()); - } - for (int i = 0; i < num_threads; i++) { threads_.emplace_back(&AsyncQpsServerTest::ThreadFunc, this, i); } } ~AsyncQpsServerTest() { for (auto ss = shutdown_state_.begin(); ss != shutdown_state_.end(); ++ss) { - (*ss)->set_shutdown(); + std::lock_guard lock((*ss)->mutex); + (*ss)->shutdown = true; } server_->Shutdown(); + for (auto cq = srv_cqs_.begin(); cq != srv_cqs_.end(); ++cq) { + (*cq)->Shutdown(); + } for (auto thr = threads_.begin(); thr != threads_.end(); thr++) { thr->join(); } for (auto cq = srv_cqs_.begin(); cq != srv_cqs_.end(); ++cq) { - (*cq)->Shutdown(); bool ok; void *got_tag; while ((*cq)->Next(&got_tag, &ok)) @@ -150,21 +151,21 @@ class AsyncQpsServerTest : public Server { } private: - void ThreadFunc(int rank) { + void ThreadFunc(int thread_idx) { // Wait until work is available or we are shutting down bool ok; void *got_tag; - while (srv_cqs_[rank]->Next(&got_tag, &ok)) { + while (srv_cqs_[thread_idx]->Next(&got_tag, &ok)) { ServerRpcContext *ctx = detag(got_tag); // The tag is a pointer to an RPC context to invoke + // Proceed while holding a lock to make sure that + // this thread isn't supposed to shut down + std::lock_guard l(shutdown_state_[thread_idx]->mutex); + if (shutdown_state_[thread_idx]->shutdown) { return; } const bool still_going = ctx->RunNextState(ok); - if (!shutdown_state_[rank]->shutdown()) { - // this RPC context is done, so refresh it - if (!still_going) { - ctx->Reset(); - } - } else { - return; + // if this RPC context is done, refresh it + if (!still_going) { + ctx->Reset(); } } return; @@ -333,24 +334,12 @@ class AsyncQpsServerTest : public Server { ServiceType async_service_; std::forward_list contexts_; - class PerThreadShutdownState { - public: - PerThreadShutdownState() : shutdown_(false) {} - - bool shutdown() const { - std::lock_guard lock(mutex_); - return shutdown_; - } - - void set_shutdown() { - std::lock_guard lock(mutex_); - shutdown_ = true; - } - - private: - mutable std::mutex mutex_; - bool shutdown_; + struct PerThreadShutdownState { + mutable std::mutex mutex; + bool shutdown; + PerThreadShutdownState() : shutdown(false) {} }; + std::vector> shutdown_state_; };