diff --git a/test/cpp/qps/client.h b/test/cpp/qps/client.h index 668d9419169..0b9837660be 100644 --- a/test/cpp/qps/client.h +++ b/test/cpp/qps/client.h @@ -236,6 +236,21 @@ class Client { return 0; } + bool IsClosedLoop() { return closed_loop_; } + + gpr_timespec NextIssueTime(int thread_idx) { + const gpr_timespec result = next_time_[thread_idx]; + next_time_[thread_idx] = + gpr_time_add(next_time_[thread_idx], + gpr_time_from_nanos(interarrival_timer_.next(thread_idx), + GPR_TIMESPAN)); + return result; + } + + bool ThreadCompleted() { + return static_cast(gpr_atm_acq_load(&thread_pool_done_)); + } + protected: bool closed_loop_; gpr_atm thread_pool_done_; @@ -289,14 +304,6 @@ class Client { } } - gpr_timespec NextIssueTime(int thread_idx) { - const gpr_timespec result = next_time_[thread_idx]; - next_time_[thread_idx] = - gpr_time_add(next_time_[thread_idx], - gpr_time_from_nanos(interarrival_timer_.next(thread_idx), - GPR_TIMESPAN)); - return result; - } std::function NextIssuer(int thread_idx) { return closed_loop_ ? std::function() : std::bind(&Client::NextIssueTime, this, thread_idx); @@ -380,10 +387,6 @@ class Client { double interval_start_time_; }; - bool ThreadCompleted() { - return static_cast(gpr_atm_acq_load(&thread_pool_done_)); - } - virtual void ThreadFunc(size_t thread_idx, Client::Thread* t) = 0; std::vector> threads_; @@ -442,6 +445,7 @@ class ClientImpl : public Client { config.payload_config()); } virtual ~ClientImpl() {} + const RequestType* request() { return &request_; } protected: const int cores_; diff --git a/test/cpp/qps/client_callback.cc b/test/cpp/qps/client_callback.cc index 87889e36dc5..00d5853a8e8 100644 --- a/test/cpp/qps/client_callback.cc +++ b/test/cpp/qps/client_callback.cc @@ -73,6 +73,20 @@ class CallbackClient virtual ~CallbackClient() {} + /** + * The main thread of the benchmark will be waiting on DestroyMultithreading. + * Increment the rpcs_done_ variable to signify that the Callback RPC + * after thread completion is done. When the last outstanding rpc increments + * the counter it should also signal the main thread's conditional variable. + */ + void NotifyMainThreadOfThreadCompletion() { + std::lock_guard l(shutdown_mu_); + rpcs_done_++; + if (rpcs_done_ == total_outstanding_rpcs_) { + shutdown_cv_.notify_one(); + } + } + protected: size_t num_threads_; size_t total_outstanding_rpcs_; @@ -93,23 +107,6 @@ class CallbackClient ThreadFuncImpl(t, thread_idx); } - virtual void ScheduleRpc(Thread* t, size_t thread_idx, - size_t ctx_vector_idx) = 0; - - /** - * The main thread of the benchmark will be waiting on DestroyMultithreading. - * Increment the rpcs_done_ variable to signify that the Callback RPC - * after thread completion is done. When the last outstanding rpc increments - * the counter it should also signal the main thread's conditional variable. - */ - void NotifyMainThreadOfThreadCompletion() { - std::lock_guard l(shutdown_mu_); - rpcs_done_++; - if (rpcs_done_ == total_outstanding_rpcs_) { - shutdown_cv_.notify_one(); - } - } - private: int NumThreads(const ClientConfig& config) { int num_threads = config.async_client_threads(); @@ -157,7 +154,7 @@ class CallbackUnaryClient final : public CallbackClient { void InitThreadFuncImpl(size_t thread_idx) override { return; } private: - void ScheduleRpc(Thread* t, size_t thread_idx, size_t vector_idx) override { + void ScheduleRpc(Thread* t, size_t thread_idx, size_t vector_idx) { if (!closed_loop_) { gpr_timespec next_issue_time = NextIssueTime(thread_idx); // Start an alarm callback to run the internal callback after @@ -199,11 +196,154 @@ class CallbackUnaryClient final : public CallbackClient { } }; +class CallbackStreamingClient : public CallbackClient { + public: + CallbackStreamingClient(const ClientConfig& config) + : CallbackClient(config), + messages_per_stream_(config.messages_per_stream()) { + for (int ch = 0; ch < config.client_channels(); ch++) { + for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) { + ctx_.emplace_back( + new CallbackClientRpcContext(channels_[ch].get_stub())); + } + } + StartThreads(num_threads_); + } + ~CallbackStreamingClient() {} + + void AddHistogramEntry(double start_, bool ok, void* thread_ptr) { + // Update Histogram with data from the callback run + HistogramEntry entry; + if (ok) { + entry.set_value((UsageTimer::Now() - start_) * 1e9); + } + ((Client::Thread*)thread_ptr)->UpdateHistogram(&entry); + } + + int messages_per_stream() { return messages_per_stream_; } + + protected: + const int messages_per_stream_; +}; + +class CallbackStreamingPingPongClient : public CallbackStreamingClient { + public: + CallbackStreamingPingPongClient(const ClientConfig& config) + : CallbackStreamingClient(config) {} + ~CallbackStreamingPingPongClient() {} +}; + +class CallbackStreamingPingPongReactor final + : public grpc::experimental::ClientBidiReactor { + public: + CallbackStreamingPingPongReactor( + CallbackStreamingPingPongClient* client, + std::unique_ptr ctx) + : client_(client), ctx_(std::move(ctx)), messages_issued_(0) {} + + void StartNewRpc() { + if (client_->ThreadCompleted()) return; + start_ = UsageTimer::Now(); + ctx_->stub_->experimental_async()->StreamingCall(&(ctx_->context_), this); + StartWrite(client_->request()); + StartCall(); + } + + void OnWriteDone(bool ok) override { + if (!ok || client_->ThreadCompleted()) { + if (!ok) gpr_log(GPR_ERROR, "Error writing RPC"); + StartWritesDone(); + return; + } + StartRead(&ctx_->response_); + } + + void OnReadDone(bool ok) override { + client_->AddHistogramEntry(start_, ok, thread_ptr_); + + if (client_->ThreadCompleted() || !ok || + (client_->messages_per_stream() != 0 && + ++messages_issued_ >= client_->messages_per_stream())) { + if (!ok) { + gpr_log(GPR_ERROR, "Error reading RPC"); + } + StartWritesDone(); + return; + } + StartWrite(client_->request()); + } + + void OnDone(const Status& s) override { + if (client_->ThreadCompleted() || !s.ok()) { + client_->NotifyMainThreadOfThreadCompletion(); + return; + } + ctx_.reset(new CallbackClientRpcContext(ctx_->stub_)); + ScheduleRpc(); + } + + void ScheduleRpc() { + if (client_->ThreadCompleted()) return; + + if (!client_->IsClosedLoop()) { + gpr_timespec next_issue_time = client_->NextIssueTime(thread_idx_); + // Start an alarm callback to run the internal callback after + // next_issue_time + ctx_->alarm_.experimental().Set(next_issue_time, + [this](bool ok) { StartNewRpc(); }); + } else { + StartNewRpc(); + } + } + + void set_thread_ptr(void* ptr) { thread_ptr_ = ptr; } + void set_thread_idx(int thread_idx) { thread_idx_ = thread_idx; } + + CallbackStreamingPingPongClient* client_; + std::unique_ptr ctx_; + int thread_idx_; // Needed to update histogram entries + void* thread_ptr_; // Needed to update histogram entries + double start_; // Track message start time + int messages_issued_; // Messages issued by this stream +}; + +class CallbackStreamingPingPongClientImpl final + : public CallbackStreamingPingPongClient { + public: + CallbackStreamingPingPongClientImpl(const ClientConfig& config) + : CallbackStreamingPingPongClient(config) { + for (size_t i = 0; i < total_outstanding_rpcs_; i++) + reactor_.emplace_back( + new CallbackStreamingPingPongReactor(this, std::move(ctx_[i]))); + } + ~CallbackStreamingPingPongClientImpl() {} + + bool ThreadFuncImpl(Client::Thread* t, size_t thread_idx) override { + for (size_t vector_idx = thread_idx; vector_idx < total_outstanding_rpcs_; + vector_idx += num_threads_) { + reactor_[vector_idx]->set_thread_ptr(t); + reactor_[vector_idx]->set_thread_idx(thread_idx); + reactor_[vector_idx]->ScheduleRpc(); + } + return true; + } + + void InitThreadFuncImpl(size_t thread_idx) override {} + + private: + std::vector> reactor_; +}; + +// TODO(mhaidry) : Implement Streaming from client, server and both ways + std::unique_ptr CreateCallbackClient(const ClientConfig& config) { switch (config.rpc_type()) { case UNARY: return std::unique_ptr(new CallbackUnaryClient(config)); case STREAMING: + return std::unique_ptr( + new CallbackStreamingPingPongClientImpl(config)); case STREAMING_FROM_CLIENT: case STREAMING_FROM_SERVER: case STREAMING_BOTH_WAYS: