diff --git a/test/cpp/end2end/thread_stress_test.cc b/test/cpp/end2end/thread_stress_test.cc index 4e8860e8432..4c7caa9b878 100644 --- a/test/cpp/end2end/thread_stress_test.cc +++ b/test/cpp/end2end/thread_stress_test.cc @@ -45,6 +45,7 @@ #include #include +#include "src/core/surface/api_trace.h" #include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" #include "src/proto/grpc/testing/echo.grpc.pb.h" #include "test/core/util/port.h" @@ -54,6 +55,9 @@ using grpc::testing::EchoRequest; using grpc::testing::EchoResponse; using std::chrono::system_clock; +const int kNumThreads = 100; // Number of threads +const int kNumRpcs = 1000; // Number of RPCs per thread + namespace grpc { namespace testing { @@ -84,7 +88,7 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { MaybeEchoDeadline(context, request, response); if (request->has_param() && request->param().client_cancel_after_us()) { { - std::unique_lock lock(mu_); + unique_lock lock(mu_); signal_client_ = true; } while (!context->IsCancelled()) { @@ -149,13 +153,13 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { } bool signal_client() { - std::unique_lock lock(mu_); + unique_lock lock(mu_); return signal_client_; } private: bool signal_client_; - std::mutex mu_; + mutex mu_; }; class TestServiceImplDupPkg @@ -168,11 +172,10 @@ class TestServiceImplDupPkg } }; -class End2endTest : public ::testing::Test { - protected: - End2endTest() : kMaxMessageSize_(8192) {} - - void SetUp() GRPC_OVERRIDE { +class CommonStressTest { + public: + CommonStressTest() : kMaxMessageSize_(8192) {} + void SetUp() { int port = grpc_pick_unused_port_or_die(); server_address_ << "localhost:" << port; // Setup server @@ -185,15 +188,15 @@ class End2endTest : public ::testing::Test { builder.RegisterService(&dup_pkg_service_); server_ = builder.BuildAndStart(); } - - void TearDown() GRPC_OVERRIDE { server_->Shutdown(); } - + void TearDown() { server_->Shutdown(); } void ResetStub() { std::shared_ptr channel = CreateChannel(server_address_.str(), InsecureChannelCredentials()); stub_ = grpc::testing::EchoTestService::NewStub(channel); } + grpc::testing::EchoTestService::Stub* GetStub() { return stub_.get(); } + private: std::unique_ptr stub_; std::unique_ptr server_; std::ostringstream server_address_; @@ -202,6 +205,16 @@ class End2endTest : public ::testing::Test { TestServiceImplDupPkg dup_pkg_service_; }; +class End2endTest : public ::testing::Test { + protected: + End2endTest() {} + void SetUp() GRPC_OVERRIDE { common_.SetUp(); } + void TearDown() GRPC_OVERRIDE { common_.TearDown(); } + void ResetStub() { common_.ResetStub(); } + + CommonStressTest common_; +}; + static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs) { EchoRequest request; EchoResponse response; @@ -216,17 +229,113 @@ static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs) { } TEST_F(End2endTest, ThreadStress) { - ResetStub(); + common_.ResetStub(); std::vector threads; - for (int i = 0; i < 100; ++i) { - threads.push_back(new std::thread(SendRpc, stub_.get(), 1000)); + for (int i = 0; i < kNumThreads; ++i) { + threads.push_back(new std::thread(SendRpc, common_.GetStub(), kNumRpcs)); } - for (int i = 0; i < 100; ++i) { + for (int i = 0; i < kNumThreads; ++i) { threads[i]->join(); delete threads[i]; } } +class AsyncClientEnd2endTest : public ::testing::Test { + protected: + AsyncClientEnd2endTest() : rpcs_outstanding_(0) {} + + void SetUp() GRPC_OVERRIDE { common_.SetUp(); } + void TearDown() GRPC_OVERRIDE { + void* ignored_tag; + bool ignored_ok; + while (cq_.Next(&ignored_tag, &ignored_ok)) + ; + common_.TearDown(); + } + + void Wait() { + unique_lock l(mu_); + while (rpcs_outstanding_ != 0) { + cv_.wait(l); + } + + cq_.Shutdown(); + } + + struct AsyncClientCall { + EchoResponse response; + ClientContext context; + Status status; + std::unique_ptr> response_reader; + }; + + void AsyncSendRpc(int num_rpcs) { + for (int i = 0; i < num_rpcs; ++i) { + AsyncClientCall* call = new AsyncClientCall; + EchoRequest request; + request.set_message("Hello"); + call->response_reader = + common_.GetStub()->AsyncEcho(&call->context, request, &cq_); + call->response_reader->Finish(&call->response, &call->status, + (void*)call); + + unique_lock l(mu_); + rpcs_outstanding_++; + } + } + + void AsyncCompleteRpc() { + while (true) { + void* got_tag; + bool ok = false; + if (!cq_.Next(&got_tag, &ok)) break; + AsyncClientCall* call = static_cast(got_tag); + GPR_ASSERT(ok); + delete call; + + bool notify; + { + unique_lock l(mu_); + rpcs_outstanding_--; + notify = (rpcs_outstanding_ == 0); + } + if (notify) { + cv_.notify_all(); + } + } + } + + CommonStressTest common_; + CompletionQueue cq_; + mutex mu_; + condition_variable cv_; + int rpcs_outstanding_; +}; + +TEST_F(AsyncClientEnd2endTest, ThreadStress) { + common_.ResetStub(); + std::vector send_threads, completion_threads; + for (int i = 0; i < kNumThreads / 2; ++i) { + completion_threads.push_back(new std::thread( + &AsyncClientEnd2endTest_ThreadStress_Test::AsyncCompleteRpc, this)); + } + for (int i = 0; i < kNumThreads / 2; ++i) { + send_threads.push_back( + new std::thread(&AsyncClientEnd2endTest_ThreadStress_Test::AsyncSendRpc, + this, kNumRpcs)); + } + for (int i = 0; i < kNumThreads / 2; ++i) { + send_threads[i]->join(); + delete send_threads[i]; + } + + Wait(); + for (int i = 0; i < kNumThreads / 2; ++i) { + completion_threads[i]->join(); + delete completion_threads[i]; + } +} + } // namespace testing } // namespace grpc