diff --git a/test/cpp/qps/client.h b/test/cpp/qps/client.h index 221fb30fc59..cae7f44537c 100644 --- a/test/cpp/qps/client.h +++ b/test/cpp/qps/client.h @@ -115,12 +115,12 @@ class Client { impl_([this, idx, client]() { for (;;) { // run the loop body - client->ThreadFunc(&histogram_, idx); + client->ThreadFunc(&histogram_, idx); // lock, see if we're done std::lock_guard g(mu_); - if (done_) return; - // also check if we're marking, and swap out the histogram if so - if (new_) { + if (done_) {return;} + // check if we're marking, swap out the histogram if so + if (new_) { new_->Swap(&histogram_); new_ = nullptr; cv_.notify_one(); @@ -164,8 +164,12 @@ class Client { std::unique_ptr timer_; }; -std::unique_ptr CreateSynchronousClient(const ClientConfig& args); -std::unique_ptr CreateAsyncClient(const ClientConfig& args); +std::unique_ptr + CreateSynchronousUnaryClient(const ClientConfig& args); +std::unique_ptr + CreateSynchronousStreamingClient(const ClientConfig& args); +std::unique_ptr CreateAsyncUnaryClient(const ClientConfig& args); +std::unique_ptr CreateAsyncStreamingClient(const ClientConfig& args); } // namespace testing } // namespace grpc diff --git a/test/cpp/qps/client_async.cc b/test/cpp/qps/client_async.cc index f995a63f49f..590d56d8d00 100644 --- a/test/cpp/qps/client_async.cc +++ b/test/cpp/qps/client_async.cc @@ -46,6 +46,7 @@ #include #include #include +#include #include "test/core/util/grpc_profiler.h" #include "test/cpp/util/create_test_channel.h" #include "test/cpp/qps/qpstest.pb.h" @@ -59,13 +60,13 @@ class ClientRpcContext { public: ClientRpcContext() {} virtual ~ClientRpcContext() {} - virtual bool RunNextState() = 0; // do next state, return false if steps done + // next state, return false if done. Collect stats when appropriate + virtual bool RunNextState(bool, Histogram* hist) = 0; virtual void StartNewClone() = 0; static void* tag(ClientRpcContext* c) { return reinterpret_cast(c); } static ClientRpcContext* detag(void* t) { return reinterpret_cast(t); } - virtual void report_stats(Histogram* hist) = 0; }; template @@ -89,9 +90,12 @@ class ClientRpcContextUnaryImpl : public ClientRpcContext { response_reader_( start_req(stub_, &context_, req_, ClientRpcContext::tag(this))) {} ~ClientRpcContextUnaryImpl() GRPC_OVERRIDE {} - bool RunNextState() GRPC_OVERRIDE { return (this->*next_state_)(); } - void report_stats(Histogram* hist) GRPC_OVERRIDE { - hist->Add((Timer::Now() - start_) * 1e9); + bool RunNextState(bool ok, Histogram* hist) GRPC_OVERRIDE { + bool ret = (this->*next_state_)(ok); + if (!ret) { + hist->Add((Timer::Now() - start_) * 1e9); + } + return ret; } void StartNewClone() GRPC_OVERRIDE { @@ -99,16 +103,16 @@ class ClientRpcContextUnaryImpl : public ClientRpcContext { } private: - bool ReqSent() { + bool ReqSent(bool) { next_state_ = &ClientRpcContextUnaryImpl::RespDone; response_reader_->Finish(&response_, &status_, ClientRpcContext::tag(this)); return true; } - bool RespDone() { + bool RespDone(bool) { next_state_ = &ClientRpcContextUnaryImpl::DoCallBack; return false; } - bool DoCallBack() { + bool DoCallBack(bool) { callback_(status_, &response_); return false; } @@ -116,7 +120,7 @@ class ClientRpcContextUnaryImpl : public ClientRpcContext { TestService::Stub* stub_; RequestType req_; ResponseType response_; - bool (ClientRpcContextUnaryImpl::*next_state_)(); + bool (ClientRpcContextUnaryImpl::*next_state_)(bool); std::function callback_; std::function>( TestService::Stub*, grpc::ClientContext*, const RequestType&, void*)> @@ -127,9 +131,9 @@ class ClientRpcContextUnaryImpl : public ClientRpcContext { response_reader_; }; -class AsyncClient GRPC_FINAL : public Client { +class AsyncUnaryClient GRPC_FINAL : public Client { public: - explicit AsyncClient(const ClientConfig& config) : Client(config) { + explicit AsyncUnaryClient(const ClientConfig& config) : Client(config) { for (int i = 0; i < config.async_client_threads(); i++) { cli_cqs_.emplace_back(new CompletionQueue); } @@ -163,7 +167,7 @@ class AsyncClient GRPC_FINAL : public Client { StartThreads(config.async_client_threads()); } - ~AsyncClient() GRPC_OVERRIDE { + ~AsyncUnaryClient() GRPC_OVERRIDE { EndThreads(); for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) { @@ -182,10 +186,9 @@ class AsyncClient GRPC_FINAL : public Client { cli_cqs_[thread_idx]->Next(&got_tag, &ok); ClientRpcContext* ctx = ClientRpcContext::detag(got_tag); - if (ctx->RunNextState() == false) { + if (ctx->RunNextState(ok, histogram) == false) { // call the callback and then delete it - ctx->report_stats(histogram); - ctx->RunNextState(); + ctx->RunNextState(ok, histogram); ctx->StartNewClone(); delete ctx; } @@ -194,8 +197,145 @@ class AsyncClient GRPC_FINAL : public Client { std::vector> cli_cqs_; }; -std::unique_ptr CreateAsyncClient(const ClientConfig& args) { - return std::unique_ptr(new AsyncClient(args)); +template +class ClientRpcContextStreamingImpl : public ClientRpcContext { + public: + ClientRpcContextStreamingImpl( + TestService::Stub *stub, const RequestType &req, + std::function< + std::unique_ptr>( + TestService::Stub *, grpc::ClientContext *, void *)> start_req, + std::function on_done) + : context_(), + stub_(stub), + req_(req), + response_(), + next_state_(&ClientRpcContextStreamingImpl::ReqSent), + callback_(on_done), + start_req_(start_req), + start_(Timer::Now()), + stream_(start_req_(stub_, &context_, ClientRpcContext::tag(this))) {} + ~ClientRpcContextStreamingImpl() GRPC_OVERRIDE {} + bool RunNextState(bool ok, Histogram *hist) GRPC_OVERRIDE { + return (this->*next_state_)(ok, hist); + } + void StartNewClone() GRPC_OVERRIDE { + new ClientRpcContextStreamingImpl(stub_, req_, start_req_, callback_); + } + + private: + bool ReqSent(bool ok, Histogram *) { + return StartWrite(ok); + } + bool StartWrite(bool ok) { + if (!ok) { + return(false); + } + start_ = Timer::Now(); + next_state_ = &ClientRpcContextStreamingImpl::WriteDone; + stream_->Write(req_, ClientRpcContext::tag(this)); + return true; + } + bool WriteDone(bool ok, Histogram *) { + if (!ok) { + return(false); + } + next_state_ = &ClientRpcContextStreamingImpl::ReadDone; + stream_->Read(&response_, ClientRpcContext::tag(this)); + return true; + } + bool ReadDone(bool ok, Histogram *hist) { + hist->Add((Timer::Now() - start_) * 1e9); + return StartWrite(ok); + } + grpc::ClientContext context_; + TestService::Stub *stub_; + RequestType req_; + ResponseType response_; + bool (ClientRpcContextStreamingImpl::*next_state_)(bool, Histogram *); + std::function callback_; + std::function>( + TestService::Stub *, grpc::ClientContext *, void *)> start_req_; + grpc::Status status_; + double start_; + std::unique_ptr> + stream_; +}; + +class AsyncStreamingClient GRPC_FINAL : public Client { + public: + explicit AsyncStreamingClient(const ClientConfig &config) : Client(config) { + for (int i = 0; i < config.async_client_threads(); i++) { + cli_cqs_.emplace_back(new CompletionQueue); + } + + auto payload_size = config.payload_size(); + auto check_done = [payload_size](grpc::Status s, SimpleResponse *response) { + GPR_ASSERT(s.IsOk() && (response->payload().type() == + grpc::testing::PayloadType::COMPRESSABLE) && + (response->payload().body().length() == + static_cast(payload_size))); + }; + + int t = 0; + for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) { + for (auto channel = channels_.begin(); channel != channels_.end(); + channel++) { + auto* cq = cli_cqs_[t].get(); + t = (t + 1) % cli_cqs_.size(); + auto start_req = [cq](TestService::Stub *stub, grpc::ClientContext *ctx, + void *tag) { + auto stream = stub->AsyncStreamingCall(ctx, cq, tag); + return stream; + }; + + TestService::Stub *stub = channel->get_stub(); + const SimpleRequest &request = request_; + new ClientRpcContextStreamingImpl( + stub, request, start_req, check_done); + } + } + + StartThreads(config.async_client_threads()); + } + + ~AsyncStreamingClient() GRPC_OVERRIDE { + EndThreads(); + + for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) { + (*cq)->Shutdown(); + void *got_tag; + bool ok; + while ((*cq)->Next(&got_tag, &ok)) { + delete ClientRpcContext::detag(got_tag); + } + } + } + + void ThreadFunc(Histogram *histogram, size_t thread_idx) GRPC_OVERRIDE { + void *got_tag; + bool ok; + cli_cqs_[thread_idx]->Next(&got_tag, &ok); + + ClientRpcContext *ctx = ClientRpcContext::detag(got_tag); + if (ctx->RunNextState(ok, histogram) == false) { + // call the callback and then delete it + ctx->RunNextState(ok, histogram); + ctx->StartNewClone(); + delete ctx; + } + } + + std::vector> cli_cqs_; +}; + +std::unique_ptr CreateAsyncUnaryClient(const ClientConfig& args) { + return std::unique_ptr(new AsyncUnaryClient(args)); +} +std::unique_ptr CreateAsyncStreamingClient(const ClientConfig& args) { + return std::unique_ptr(new AsyncStreamingClient(args)); } } // namespace testing diff --git a/test/cpp/qps/client_sync.cc b/test/cpp/qps/client_sync.cc index 7bb7231c6f6..e4ee45a72d0 100644 --- a/test/cpp/qps/client_sync.cc +++ b/test/cpp/qps/client_sync.cc @@ -48,9 +48,11 @@ #include #include #include -#include #include #include +#include +#include +#include #include "test/core/util/grpc_profiler.h" #include "test/cpp/util/create_test_channel.h" #include "test/cpp/qps/client.h" @@ -61,18 +63,28 @@ namespace grpc { namespace testing { -class SynchronousClient GRPC_FINAL : public Client { +class SynchronousClient : public Client { public: SynchronousClient(const ClientConfig& config) : Client(config) { - size_t num_threads = - config.outstanding_rpcs_per_channel() * config.client_channels(); - responses_.resize(num_threads); - StartThreads(num_threads); + num_threads_ = + config.outstanding_rpcs_per_channel() * config.client_channels(); + responses_.resize(num_threads_); } - ~SynchronousClient() { EndThreads(); } + virtual ~SynchronousClient() { EndThreads(); } + + protected: + size_t num_threads_; + std::vector responses_; +}; - void ThreadFunc(Histogram* histogram, size_t thread_idx) { +class SynchronousUnaryClient GRPC_FINAL : public SynchronousClient { + public: + SynchronousUnaryClient(const ClientConfig& config): + SynchronousClient(config) {StartThreads(num_threads_);} + ~SynchronousUnaryClient() {} + + void ThreadFunc(Histogram* histogram, size_t thread_idx) GRPC_OVERRIDE { auto* stub = channels_[thread_idx % channels_.size()].get_stub(); double start = Timer::Now(); grpc::ClientContext context; @@ -80,13 +92,45 @@ class SynchronousClient GRPC_FINAL : public Client { stub->UnaryCall(&context, request_, &responses_[thread_idx]); histogram->Add((Timer::Now() - start) * 1e9); } +}; - private: - std::vector responses_; +class SynchronousStreamingClient GRPC_FINAL : public SynchronousClient { + public: + SynchronousStreamingClient(const ClientConfig& config): + SynchronousClient(config) { + for (size_t thread_idx=0;thread_idxStreamingCall(&context_); + } + StartThreads(num_threads_); + } + ~SynchronousStreamingClient() { + if (stream_) { + SimpleResponse response; + stream_->WritesDone(); + EXPECT_TRUE(stream_->Finish().IsOk()); + } + } + + void ThreadFunc(Histogram* histogram, size_t thread_idx) GRPC_OVERRIDE { + double start = Timer::Now(); + EXPECT_TRUE(stream_->Write(request_)); + EXPECT_TRUE(stream_->Read(&responses_[thread_idx])); + histogram->Add((Timer::Now() - start) * 1e9); + } + private: + grpc::ClientContext context_; + std::unique_ptr> stream_; }; -std::unique_ptr CreateSynchronousClient(const ClientConfig& config) { - return std::unique_ptr(new SynchronousClient(config)); +std::unique_ptr +CreateSynchronousUnaryClient(const ClientConfig& config) { + return std::unique_ptr(new SynchronousUnaryClient(config)); +} +std::unique_ptr +CreateSynchronousStreamingClient(const ClientConfig& config) { + return std::unique_ptr(new SynchronousStreamingClient(config)); } } // namespace testing diff --git a/test/cpp/qps/qps-sweep.sh b/test/cpp/qps/qps-sweep.sh new file mode 100755 index 00000000000..7bc6eade2c1 --- /dev/null +++ b/test/cpp/qps/qps-sweep.sh @@ -0,0 +1,25 @@ +#!/bin/sh + +if [ x"$QPS_WORKERS" == x ]; then + echo Error: Must set QPS_WORKERS variable in form \ + "host:port,host:port,..." 1>&2 + exit 1 +fi + +bins=`find . .. ../.. ../../.. -name bins | head -1` + +for channels in 1 2 4 8 +do + for client in SYNCHRONOUS_CLIENT ASYNC_CLIENT + do + for server in SYNCHRONOUS_SERVER ASYNC_SERVER + do + for rpc in UNARY STREAMING + do + echo "Test $rpc $client $server , $channels channels" + "$bins"/opt/qps_driver --rpc_type=$rpc \ + --client_type=$client --server_type=$server + done + done + done +done diff --git a/test/cpp/qps/qps_driver.cc b/test/cpp/qps/qps_driver.cc index 5e9a577f8e6..f7aa8e2aba6 100644 --- a/test/cpp/qps/qps_driver.cc +++ b/test/cpp/qps/qps_driver.cc @@ -42,6 +42,7 @@ DEFINE_int32(num_servers, 1, "Number of server binaries"); // Common config DEFINE_bool(enable_ssl, false, "Use SSL"); +DEFINE_string(rpc_type, "UNARY", "Type of RPC: UNARY or STREAMING"); // Server config DEFINE_int32(server_threads, 1, "Number of server threads"); @@ -59,6 +60,7 @@ using grpc::testing::ClientConfig; using grpc::testing::ServerConfig; using grpc::testing::ClientType; using grpc::testing::ServerType; +using grpc::testing::RpcType; using grpc::testing::ResourceUsage; using grpc::testing::sum; @@ -73,6 +75,9 @@ int main(int argc, char** argv) { grpc_init(); ParseCommandLineFlags(&argc, &argv, true); + RpcType rpc_type; + GPR_ASSERT(RpcType_Parse(FLAGS_rpc_type, &rpc_type)); + ClientType client_type; ServerType server_type; GPR_ASSERT(ClientType_Parse(FLAGS_client_type, &client_type)); @@ -86,6 +91,7 @@ int main(int argc, char** argv) { client_config.set_client_channels(FLAGS_client_channels); client_config.set_payload_size(FLAGS_payload_size); client_config.set_async_client_threads(FLAGS_async_client_threads); + client_config.set_rpc_type(rpc_type); ServerConfig server_config; server_config.set_server_type(server_type); diff --git a/test/cpp/qps/qpstest.proto b/test/cpp/qps/qpstest.proto index 6a7170bf580..1553ef5f07e 100644 --- a/test/cpp/qps/qpstest.proto +++ b/test/cpp/qps/qpstest.proto @@ -87,15 +87,21 @@ enum ServerType { ASYNC_SERVER = 2; } +enum RpcType { + UNARY = 1; + STREAMING = 2; +} + message ClientConfig { repeated string server_targets = 1; required ClientType client_type = 2; - required bool enable_ssl = 3; + optional bool enable_ssl = 3 [default=false]; required int32 outstanding_rpcs_per_channel = 4; required int32 client_channels = 5; required int32 payload_size = 6; // only for async client: optional int32 async_client_threads = 7; + optional RpcType rpc_type = 8 [default=UNARY]; } // Request current stats @@ -121,8 +127,8 @@ message ClientStatus { message ServerConfig { required ServerType server_type = 1; - required int32 threads = 2; - required bool enable_ssl = 3; + optional int32 threads = 2 [default=1]; + optional bool enable_ssl = 3 [default=false]; } message ServerArgs { @@ -144,7 +150,7 @@ message SimpleRequest { // Desired payload size in the response from the server. // If response_type is COMPRESSABLE, this denotes the size before compression. - optional int32 response_size = 2; + optional int32 response_size = 2 [default=0]; // Optional input payload sent along with the request. optional Payload payload = 3; @@ -154,72 +160,14 @@ message SimpleResponse { optional Payload payload = 1; } -message StreamingInputCallRequest { - // Optional input payload sent along with the request. - optional Payload payload = 1; - - // Not expecting any payload from the response. -} - -message StreamingInputCallResponse { - // Aggregated size of payloads received from the client. - optional int32 aggregated_payload_size = 1; -} - -message ResponseParameters { - // Desired payload sizes in responses from the server. - // If response_type is COMPRESSABLE, this denotes the size before compression. - required int32 size = 1; - - // Desired interval between consecutive responses in the response stream in - // microseconds. - required int32 interval_us = 2; -} - -message StreamingOutputCallRequest { - // Desired payload type in the response from the server. - // If response_type is RANDOM, the payload from each response in the stream - // might be of different types. This is to simulate a mixed type of payload - // stream. - optional PayloadType response_type = 1 [default=COMPRESSABLE]; - - repeated ResponseParameters response_parameters = 2; - - // Optional input payload sent along with the request. - optional Payload payload = 3; -} - -message StreamingOutputCallResponse { - optional Payload payload = 1; -} - service TestService { // One request followed by one response. // The server returns the client payload as-is. rpc UnaryCall(SimpleRequest) returns (SimpleResponse); - // One request followed by a sequence of responses (streamed download). - // The server returns the payload with client desired type and sizes. - rpc StreamingOutputCall(StreamingOutputCallRequest) - returns (stream StreamingOutputCallResponse); - - // A sequence of requests followed by one response (streamed upload). - // The server returns the aggregated size of client payload as the result. - rpc StreamingInputCall(stream StreamingInputCallRequest) - returns (StreamingInputCallResponse); - - // A sequence of requests with each request served by the server immediately. - // As one request could lead to multiple responses, this interface - // demonstrates the idea of full duplexing. - rpc FullDuplexCall(stream StreamingOutputCallRequest) - returns (stream StreamingOutputCallResponse); - - // A sequence of requests followed by a sequence of responses. - // The server buffers all the client requests and then serves them in order. A - // stream of responses are returned to the client when the server starts with - // first request. - rpc HalfDuplexCall(stream StreamingOutputCallRequest) - returns (stream StreamingOutputCallResponse); + // One request followed by one response. + // The server returns the client payload as-is. + rpc StreamingCall(stream SimpleRequest) returns (stream SimpleResponse); } service Worker { diff --git a/test/cpp/qps/server.cc b/test/cpp/qps/server.cc index e1907b069b3..258c5e145b8 100644 --- a/test/cpp/qps/server.cc +++ b/test/cpp/qps/server.cc @@ -115,7 +115,7 @@ class TestServiceImpl GRPC_FINAL : public TestService::Service { } Status UnaryCall(ServerContext* context, const SimpleRequest* request, SimpleResponse* response) { - if (request->has_response_size() && request->response_size() > 0) { + if (request->response_size() > 0) { if (!SetPayload(request->response_type(), request->response_size(), response->mutable_payload())) { return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); diff --git a/test/cpp/qps/server_async.cc b/test/cpp/qps/server_async.cc index 7b81bd35a2a..83bb08cd498 100644 --- a/test/cpp/qps/server_async.cc +++ b/test/cpp/qps/server_async.cc @@ -33,6 +33,7 @@ #include #include +#include #include #include #include @@ -48,6 +49,7 @@ #include #include #include +#include #include #include "src/cpp/server/thread_pool.h" #include "test/core/util/grpc_profiler.h" @@ -63,7 +65,8 @@ namespace testing { class AsyncQpsServerTest : public Server { public: AsyncQpsServerTest(const ServerConfig& config, int port) - : srv_cq_(), async_service_(&srv_cq_), server_(nullptr) { + : srv_cq_(), async_service_(&srv_cq_), server_(nullptr), + shutdown_(false) { char* server_address = NULL; gpr_join_host_port(&server_address, "::", port); @@ -78,10 +81,16 @@ class AsyncQpsServerTest : public Server { using namespace std::placeholders; request_unary_ = std::bind(&TestService::AsyncService::RequestUnaryCall, &async_service_, _1, _2, _3, &srv_cq_, _4); + request_streaming_ = + std::bind(&TestService::AsyncService::RequestStreamingCall, + &async_service_, _1, _2, &srv_cq_, _3); for (int i = 0; i < 100; i++) { contexts_.push_front( new ServerRpcContextUnaryImpl( - request_unary_, UnaryCall)); + request_unary_, ProcessRPC)); + contexts_.push_front( + new ServerRpcContextStreamingImpl( + request_streaming_, ProcessRPC)); } for (int i = 0; i < config.threads(); i++) { threads_.push_back(std::thread([=]() { @@ -89,14 +98,15 @@ class AsyncQpsServerTest : public Server { bool ok; void* got_tag; while (srv_cq_.Next(&got_tag, &ok)) { - if (ok) { - ServerRpcContext* ctx = detag(got_tag); - // The tag is a pointer to an RPC context to invoke - if (ctx->RunNextState() == false) { - // this RPC context is done, so refresh it + ServerRpcContext* ctx = detag(got_tag); + // The tag is a pointer to an RPC context to invoke + if (ctx->RunNextState(ok) == false) { + // this RPC context is done, so refresh it + std::lock_guard g(shutdown_mutex_); + if (!shutdown_) { ctx->Reset(); } - } + } } return; })); @@ -104,7 +114,11 @@ class AsyncQpsServerTest : public Server { } ~AsyncQpsServerTest() { server_->Shutdown(); - srv_cq_.Shutdown(); + { + std::lock_guard g(shutdown_mutex_); + shutdown_ = true; + srv_cq_.Shutdown(); + } for (auto thr = threads_.begin(); thr != threads_.end(); thr++) { thr->join(); } @@ -119,7 +133,7 @@ class AsyncQpsServerTest : public Server { public: ServerRpcContext() {} virtual ~ServerRpcContext(){}; - virtual bool RunNextState() = 0; // do next state, return false if all done + virtual bool RunNextState(bool) = 0; // next state, return false if done virtual void Reset() = 0; // start this back at a clean state }; static void* tag(ServerRpcContext* func) { @@ -130,7 +144,7 @@ class AsyncQpsServerTest : public Server { } template - class ServerRpcContextUnaryImpl : public ServerRpcContext { + class ServerRpcContextUnaryImpl GRPC_FINAL : public ServerRpcContext { public: ServerRpcContextUnaryImpl( std::function*next_state_)(); } + bool RunNextState(bool ok) GRPC_OVERRIDE {return (this->*next_state_)(ok);} void Reset() GRPC_OVERRIDE { srv_ctx_ = ServerContext(); req_ = RequestType(); @@ -160,8 +174,11 @@ class AsyncQpsServerTest : public Server { } private: - bool finisher() { return false; } - bool invoker() { + bool finisher(bool) { return false; } + bool invoker(bool ok) { + if (!ok) + return false; + ResponseType response; // Call the RPC processing function @@ -174,7 +191,7 @@ class AsyncQpsServerTest : public Server { } ServerContext srv_ctx_; RequestType req_; - bool (ServerRpcContextUnaryImpl::*next_state_)(); + bool (ServerRpcContextUnaryImpl::*next_state_)(bool); std::function*, void*)> request_method_; @@ -183,9 +200,88 @@ class AsyncQpsServerTest : public Server { grpc::ServerAsyncResponseWriter response_writer_; }; - static Status UnaryCall(const SimpleRequest* request, - SimpleResponse* response) { - if (request->has_response_size() && request->response_size() > 0) { + template + class ServerRpcContextStreamingImpl GRPC_FINAL : public ServerRpcContext { + public: + ServerRpcContextStreamingImpl( + std::function *, void *)> request_method, + std::function + invoke_method) + : next_state_(&ServerRpcContextStreamingImpl::request_done), + request_method_(request_method), + invoke_method_(invoke_method), + stream_(&srv_ctx_) { + request_method_(&srv_ctx_, &stream_, AsyncQpsServerTest::tag(this)); + } + ~ServerRpcContextStreamingImpl() GRPC_OVERRIDE { + } + bool RunNextState(bool ok) GRPC_OVERRIDE {return (this->*next_state_)(ok);} + void Reset() GRPC_OVERRIDE { + srv_ctx_ = ServerContext(); + req_ = RequestType(); + stream_ = grpc::ServerAsyncReaderWriter(&srv_ctx_); + + // Then request the method + next_state_ = &ServerRpcContextStreamingImpl::request_done; + request_method_(&srv_ctx_, &stream_, AsyncQpsServerTest::tag(this)); + } + + private: + bool request_done(bool ok) { + if (!ok) + return false; + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::read_done; + return true; + } + + bool read_done(bool ok) { + if (ok) { + // invoke the method + ResponseType response; + // Call the RPC processing function + grpc::Status status = invoke_method_(&req_, &response); + // initiate the write + stream_.Write(response, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::write_done; + } else { // client has sent writes done + // finish the stream + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::finish_done; + } + return true; + } + bool write_done(bool ok) { + // now go back and get another streaming read! + if (ok) { + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::read_done; + } + else { + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::finish_done; + } + return true; + } + bool finish_done(bool ok) {return false; /* reset the context */ } + + ServerContext srv_ctx_; + RequestType req_; + bool (ServerRpcContextStreamingImpl::*next_state_)(bool); + std::function *, void *)> request_method_; + std::function + invoke_method_; + grpc::ServerAsyncReaderWriter stream_; + }; + + static Status ProcessRPC(const SimpleRequest* request, + SimpleResponse* response) { + if (request->response_size() > 0) { if (!SetPayload(request->response_type(), request->response_size(), response->mutable_payload())) { return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); @@ -200,7 +296,13 @@ class AsyncQpsServerTest : public Server { std::function*, void*)> request_unary_; + std::function*, void*)> + request_streaming_; std::forward_list contexts_; + + std::mutex shutdown_mutex_; + bool shutdown_; }; std::unique_ptr CreateAsyncServer(const ServerConfig& config, diff --git a/test/cpp/qps/server_sync.cc b/test/cpp/qps/server_sync.cc index 3e15fb61c02..6724b8f79af 100644 --- a/test/cpp/qps/server_sync.cc +++ b/test/cpp/qps/server_sync.cc @@ -62,7 +62,7 @@ class TestServiceImpl GRPC_FINAL : public TestService::Service { public: Status UnaryCall(ServerContext* context, const SimpleRequest* request, SimpleResponse* response) GRPC_OVERRIDE { - if (request->has_response_size() && request->response_size() > 0) { + if (request->response_size() > 0) { if (!Server::SetPayload(request->response_type(), request->response_size(), response->mutable_payload())) { @@ -71,6 +71,23 @@ class TestServiceImpl GRPC_FINAL : public TestService::Service { } return Status::OK; } + Status StreamingCall(ServerContext *context, + ServerReaderWriter* + stream) GRPC_OVERRIDE { + SimpleRequest request; + while (stream->Read(&request)) { + SimpleResponse response; + if (request.response_size() > 0) { + if (!Server::SetPayload(request.response_type(), + request.response_size(), + response.mutable_payload())) { + return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); + } + } + stream->Write(response); + } + return Status::OK; + } }; class SynchronousServer GRPC_FINAL : public grpc::testing::Server { diff --git a/test/cpp/qps/worker.cc b/test/cpp/qps/worker.cc index fdcd9d50697..dddc4c98507 100644 --- a/test/cpp/qps/worker.cc +++ b/test/cpp/qps/worker.cc @@ -77,9 +77,12 @@ namespace testing { std::unique_ptr CreateClient(const ClientConfig& config) { switch (config.client_type()) { case ClientType::SYNCHRONOUS_CLIENT: - return CreateSynchronousClient(config); + return (config.rpc_type() == RpcType::UNARY) ? + CreateSynchronousUnaryClient(config) : + CreateSynchronousStreamingClient(config); case ClientType::ASYNC_CLIENT: - return CreateAsyncClient(config); + return (config.rpc_type() == RpcType::UNARY) ? + CreateAsyncUnaryClient(config) : CreateAsyncStreamingClient(config); } abort(); }