/*
 *
 * Copyright 2015 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// NOTE(review): the targets of the next sixteen #include directives are
// missing in this copy of the file, and so are template parameter/argument
// lists throughout (e.g. `template class`, `std::function register_service`,
// `std::unique_ptr builder`, `static_cast(func)`). This looks like damage from
// an extraction step that dropped everything between angle brackets -- restore
// from the upstream file before building. TODO confirm.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "src/core/lib/gprpp/host_port.h"
#include "src/core/lib/surface/completion_queue.h"
#include "src/proto/grpc/testing/benchmark_service.grpc.pb.h"
#include "test/core/util/test_config.h"
#include "test/cpp/qps/qps_server_builder.h"
#include "test/cpp/qps/server.h"

namespace grpc {
namespace testing {

// Benchmark server driven by completion queues.
//
// The constructor registers the service, creates the completion queues,
// starts the server, pre-posts a pool of RPC contexts on every queue and then
// spawns the worker threads. Each worker (ThreadFunc) pulls tags off its
// assigned queue; a tag is a pointer to a ServerRpcContext whose state machine
// is advanced with RunNextState() and re-armed with Reset() once it reports
// that it is done.
//
// The body refers to the types ServiceType, ServerContextType, RequestType
// and ResponseType; presumably these are the template parameters whose
// declaration is missing from the line below -- TODO confirm.
template class AsyncQpsServerTest final : public grpc::testing::Server {
 public:
  // Builds, starts and populates the server.
  //
  //   config            - server configuration; read here for the thread
  //                       count, threads-per-queue, payload config and
  //                       credentials.
  //   register_service  - called with the builder and &async_service_.
  //   request_*_function - one callable per RPC kind. Each one that tests
  //                       true is bound to &async_service_ and a completion
  //                       queue and used to post contexts of that kind.
  //                       request_streaming_both_ways_function is accepted
  //                       but not used yet (see the TODO below).
  //   process_rpc       - the RPC handler; bound to config.payload_config()
  //                       and invoked by the contexts.
  AsyncQpsServerTest(
      const ServerConfig& config,
      std::function register_service,
      std::function*, CompletionQueue*, ServerCompletionQueue*, void*)>
          request_unary_function,
      std::function*, CompletionQueue*, ServerCompletionQueue*, void*)>
          request_streaming_function,
      std::function*, CompletionQueue*, ServerCompletionQueue*, void*)>
          request_streaming_from_client_function,
      std::function*, CompletionQueue*, ServerCompletionQueue*, void*)>
          request_streaming_from_server_function,
      std::function*, CompletionQueue*, ServerCompletionQueue*, void*)>
          request_streaming_both_ways_function,
      std::function process_rpc)
      : Server(config) {
    std::unique_ptr builder = CreateQpsServerBuilder();

    auto port_num = port();
    // Negative port number means inproc server, so no listen port needed
    if (port_num >= 0) {
      std::unique_ptr server_address;
      grpc_core::JoinHostPort(&server_address, "::", port_num);
      builder->AddListeningPort(server_address.get(),
                                Server::CreateServerCredentials(config));
    }

    register_service(builder.get(), &async_service_);

    int num_threads = config.async_server_threads();
    if (num_threads <= 0) {  // dynamic sizing
      num_threads = cores();
      gpr_log(GPR_INFO, "Sizing async server to %d threads", num_threads);
    }

    int tpc = std::max(1, config.threads_per_cq());  // 1 if unspecified
    int num_cqs = (num_threads + tpc - 1) / tpc;     // ceiling operator
    for (int i = 0; i < num_cqs; i++) {
      srv_cqs_.emplace_back(builder->AddCompletionQueue());
    }
    // Thread i is served by completion queue (i mod number-of-queues);
    // cq_[i] stores that index into srv_cqs_.
    for (int i = 0; i < num_threads; i++) {
      cq_.emplace_back(i % srv_cqs_.size());
    }

    ApplyConfigToBuilder(config, builder.get());

    server_ = builder->BuildAndStart();

    // Fix the payload config as the handler's first argument; the two
    // placeholders are the arguments the contexts pass at call time
    // (&req_ and &response_).
    auto process_rpc_bound =
        std::bind(process_rpc, config.payload_config(), std::placeholders::_1,
                  std::placeholders::_2);

    // Pre-post 5000 contexts per enabled RPC kind on every completion queue.
    // Each context's constructor immediately calls its request method, so
    // constructing it is what posts the request. The same queue pointer is
    // passed for both queue arguments of the request function.
    for (int i = 0; i < 5000; i++) {
      for (int j = 0; j < num_cqs; j++) {
        if (request_unary_function) {
          auto request_unary = std::bind(
              request_unary_function, &async_service_,
              std::placeholders::_1, std::placeholders::_2,
              std::placeholders::_3, srv_cqs_[j].get(), srv_cqs_[j].get(),
              std::placeholders::_4);
          contexts_.emplace_back(
              new ServerRpcContextUnaryImpl(request_unary, process_rpc_bound));
        }
        if (request_streaming_function) {
          auto request_streaming = std::bind(
              request_streaming_function, &async_service_,
              std::placeholders::_1, std::placeholders::_2,
              srv_cqs_[j].get(), srv_cqs_[j].get(), std::placeholders::_3);
          contexts_.emplace_back(new ServerRpcContextStreamingImpl(
              request_streaming, process_rpc_bound));
        }
        if (request_streaming_from_client_function) {
          auto request_streaming_from_client = std::bind(
              request_streaming_from_client_function, &async_service_,
              std::placeholders::_1, std::placeholders::_2,
              srv_cqs_[j].get(), srv_cqs_[j].get(), std::placeholders::_3);
          contexts_.emplace_back(new ServerRpcContextStreamingFromClientImpl(
              request_streaming_from_client, process_rpc_bound));
        }
        if (request_streaming_from_server_function) {
          auto request_streaming_from_server =
              std::bind(request_streaming_from_server_function, &async_service_,
                        std::placeholders::_1, std::placeholders::_2,
                        std::placeholders::_3, srv_cqs_[j].get(),
                        srv_cqs_[j].get(), std::placeholders::_4);
          contexts_.emplace_back(new ServerRpcContextStreamingFromServerImpl(
              request_streaming_from_server, process_rpc_bound));
        }
        if (request_streaming_both_ways_function) {
          // TODO(vjpai): Add this code
        }
      }
    }

    // One shutdown-state record and one worker thread per configured thread;
    // the record is created before the thread that reads it is started.
    for (int i = 0; i < num_threads; i++) {
      shutdown_state_.emplace_back(new PerThreadShutdownState());
      threads_.emplace_back(&AsyncQpsServerTest::ThreadFunc, this, i);
    }
  }

  // Tears the server down in this order:
  //   1. set every thread's shutdown flag while holding that thread's mutex;
  //   2. start server_->Shutdown() on a helper thread;
  //   3. shut down every completion queue;
  //   4. join the worker threads;
  //   5. drain whatever events are still queued on each completion queue;
  //   6. join the helper thread.
  ~AsyncQpsServerTest() {
    for (auto ss = shutdown_state_.begin(); ss != shutdown_state_.end(); ++ss) {
      std::lock_guard lock((*ss)->mutex);
      (*ss)->shutdown = true;
    }
    std::thread shutdown_thread(&AsyncQpsServerTest::ShutdownThreadFunc, this);
    for (auto cq = srv_cqs_.begin(); cq != srv_cqs_.end(); ++cq) {
      (*cq)->Shutdown();
    }
    for (auto thr = threads_.begin(); thr != threads_.end(); thr++) {
      thr->join();
    }
    for (auto cq = srv_cqs_.begin(); cq != srv_cqs_.end(); ++cq) {
      bool ok;
      void* got_tag;
      // Discard events until Next() returns false.
      while ((*cq)->Next(&got_tag, &ok))
        ;
    }
    shutdown_thread.join();
  }

  // Returns the sum of grpc_get_cq_poll_num() over all server completion
  // queues.
  int GetPollCount() override {
    int count = 0;
    for (auto cq = srv_cqs_.begin(); cq != srv_cqs_.end(); cq++) {
      count += grpc_get_cq_poll_num((*cq)->cq());
    }
    return count;
  }

  // Forwards to server_->InProcessChannel(args).
  std::shared_ptr InProcessChannel(
      const ChannelArguments& args) override {
    return server_->InProcessChannel(args);
  }

 private:
  // Runs on the helper thread started by the destructor: shuts the server
  // down with a deadline three seconds from now.
  void ShutdownThreadFunc() {
    // TODO (vpai): Remove this deadline and allow Shutdown to finish properly
    auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(3);
    server_->Shutdown(deadline);
  }

  // Worker loop for thread `thread_idx`, which serves the completion queue
  // srv_cqs_[cq_[thread_idx]].
  //
  // Locking: the per-thread mutex is locked at the end of each loop body and
  // is released either right there (shutdown path, before returning) or at
  // the end of the lambda handed to DoThenAsyncNext(), i.e. after the context
  // has been advanced. The destructor takes the same mutex to set the
  // shutdown flag.
  void ThreadFunc(int thread_idx) {
    // Wait until work is available or we are shutting down
    bool ok;
    void* got_tag;
    if (!srv_cqs_[cq_[thread_idx]]->Next(&got_tag, &ok)) {
      return;
    }
    ServerRpcContext* ctx;
    std::mutex* mu_ptr = &shutdown_state_[thread_idx]->mutex;
    do {
      ctx = detag(got_tag);
      // The tag is a pointer to an RPC context to invoke
      // Proceed while holding a lock to make sure that
      // this thread isn't supposed to shut down
      mu_ptr->lock();
      if (shutdown_state_[thread_idx]->shutdown) {
        mu_ptr->unlock();
        return;
      }
    } while (srv_cqs_[cq_[thread_idx]]->DoThenAsyncNext(
        // ctx, ok and mu_ptr are captured by value, so the lambda works on
        // the event fetched by the previous call even though got_tag/ok are
        // also passed below as outputs for the next event.
        [&, ctx, ok, mu_ptr]() {
          ctx->lock();
          // A false return means the context is finished: re-arm it.
          if (!ctx->RunNextState(ok)) {
            ctx->Reset();
          }
          ctx->unlock();
          mu_ptr->unlock();
        },
        &got_tag, &ok, gpr_inf_future(GPR_CLOCK_REALTIME)));
  }

  // Abstract base for one in-flight RPC. Carries its own mutex, exposed via
  // lock()/unlock(), which ThreadFunc holds around RunNextState()/Reset().
  class ServerRpcContext {
   public:
    ServerRpcContext() {}
    void lock() { mu_.lock(); }
    void unlock() { mu_.unlock(); }
    virtual ~ServerRpcContext(){};
    virtual bool RunNextState(bool) = 0;  // next state, return false if done
    virtual void Reset() = 0;             // start this back at a clean state
   private:
    std::mutex mu_;
  };

  // Converts a context pointer to the void* tag given to the completion
  // queue.
  static void* tag(ServerRpcContext* func) {
    return static_cast(func);
  }
  // Inverse of tag(): recovers the context pointer from a completion-queue
  // tag.
  static ServerRpcContext* detag(void* tag) {
    return static_cast(tag);
  }

  // Context for a unary RPC. States: invoker -> finisher.
  class ServerRpcContextUnaryImpl final : public ServerRpcContext {
   public:
    // Stores the two callables and immediately posts the first request with
    // this object as the tag.
    ServerRpcContextUnaryImpl(
        std::function*, void*)>
            request_method,
        std::function invoke_method)
        : srv_ctx_(new ServerContextType),
          next_state_(&ServerRpcContextUnaryImpl::invoker),
          request_method_(request_method),
          invoke_method_(invoke_method),
          response_writer_(srv_ctx_.get()) {
      request_method_(srv_ctx_.get(), &req_, &response_writer_,
                      AsyncQpsServerTest::tag(this));
    }
    ~ServerRpcContextUnaryImpl() override {}
    // Dispatches to whichever member function next_state_ points at.
    bool RunNextState(bool ok) override { return (this->*next_state_)(ok); }
    // Replaces the server context, request and response writer with fresh
    // ones, then posts a new request. response_ is not cleared here.
    void Reset() override {
      srv_ctx_.reset(new ServerContextType);
      req_ = RequestType();
      response_writer_ =
          grpc::ServerAsyncResponseWriter(srv_ctx_.get());

      // Then request the method
      next_state_ = &ServerRpcContextUnaryImpl::invoker;
      request_method_(srv_ctx_.get(), &req_, &response_writer_,
                      AsyncQpsServerTest::tag(this));
    }

   private:
    // Final state: always reports "done" so the caller resets the context.
    bool finisher(bool) { return false; }
    // First state: runs the handler and sends the response together with the
    // handler's status. Returns false (done) when the event was not ok.
    bool invoker(bool ok) {
      if (!ok) {
        return false;
      }

      // Call the RPC processing function
      grpc::Status status = invoke_method_(&req_, &response_);

      // Have the response writer work and invoke on_finish when done
      next_state_ = &ServerRpcContextUnaryImpl::finisher;
      response_writer_.Finish(response_, status, AsyncQpsServerTest::tag(this));
      return true;
    }
    std::unique_ptr srv_ctx_;
    RequestType req_;
    ResponseType response_;
    bool (ServerRpcContextUnaryImpl::*next_state_)(bool);
    std::function*, void*)>
        request_method_;
    std::function invoke_method_;
    grpc::ServerAsyncResponseWriter response_writer_;
  };

  // Context for a read/write streaming RPC (one write per read).
  // States: request_done -> read_done -> write_done -> read_done -> ...
  //         -> finish_done.
  class ServerRpcContextStreamingImpl final : public ServerRpcContext {
   public:
    // Stores the two callables and immediately posts the first request with
    // this object as the tag.
    ServerRpcContextStreamingImpl(
        std::function*, void*)>
            request_method,
        std::function invoke_method)
        : srv_ctx_(new ServerContextType),
          next_state_(&ServerRpcContextStreamingImpl::request_done),
          request_method_(request_method),
          invoke_method_(invoke_method),
          stream_(srv_ctx_.get()) {
      request_method_(srv_ctx_.get(), &stream_, AsyncQpsServerTest::tag(this));
    }
    ~ServerRpcContextStreamingImpl() override {}
    // Dispatches to whichever member function next_state_ points at.
    bool RunNextState(bool ok) override { return (this->*next_state_)(ok); }
    // Replaces the server context, request and stream with fresh ones, then
    // posts a new request.
    void Reset() override {
      srv_ctx_.reset(new ServerContextType);
      req_ = RequestType();
      stream_ = grpc::ServerAsyncReaderWriter(
          srv_ctx_.get());

      // Then request the method
      next_state_ = &ServerRpcContextStreamingImpl::request_done;
      request_method_(srv_ctx_.get(), &stream_, AsyncQpsServerTest::tag(this));
    }

   private:
    // A call has arrived: start the first read. Returns false (done) when the
    // event was not ok.
    bool request_done(bool ok) {
      if (!ok) {
        return false;
      }
      next_state_ = &ServerRpcContextStreamingImpl::read_done;
      stream_.Read(&req_, AsyncQpsServerTest::tag(this));
      return true;
    }

    // A read completed: on success run the handler and write the response;
    // otherwise finish the stream with Status::OK.
    // NOTE(review): `status` is assigned but never read in this function, so
    // a failing handler status is not sent to the client -- confirm intended.
    bool read_done(bool ok) {
      if (ok) {
        // invoke the method
        // Call the RPC processing function
        grpc::Status status = invoke_method_(&req_, &response_);
        // initiate the write
        next_state_ = &ServerRpcContextStreamingImpl::write_done;
        stream_.Write(response_, AsyncQpsServerTest::tag(this));
      } else {  // client has sent writes done
        // finish the stream
        next_state_ = &ServerRpcContextStreamingImpl::finish_done;
        stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this));
      }
      return true;
    }
    // A write completed: on success start the next read; otherwise finish the
    // stream with Status::OK.
    bool write_done(bool ok) {
      // now go back and get another streaming read!
      if (ok) {
        next_state_ = &ServerRpcContextStreamingImpl::read_done;
        stream_.Read(&req_, AsyncQpsServerTest::tag(this));
      } else {
        next_state_ = &ServerRpcContextStreamingImpl::finish_done;
        stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this));
      }
      return true;
    }
    // Final state: always reports "done" so the caller resets the context.
    bool finish_done(bool /*ok*/) { return false; /*reset the context*/ }

    std::unique_ptr srv_ctx_;
    RequestType req_;
    ResponseType response_;
    bool (ServerRpcContextStreamingImpl::*next_state_)(bool);
    std::function*, void*)>
        request_method_;
    std::function invoke_method_;
    grpc::ServerAsyncReaderWriter stream_;
  };

  // Context for a client-streaming RPC: keeps reading until a read fails,
  // then runs the handler once and finishes with the response.
  // States: request_done -> read_done (repeated) -> finish_done.
  class ServerRpcContextStreamingFromClientImpl final
      : public ServerRpcContext {
   public:
    // Stores the two callables and immediately posts the first request with
    // this object as the tag.
    ServerRpcContextStreamingFromClientImpl(
        std::function*, void*)>
            request_method,
        std::function invoke_method)
        : srv_ctx_(new ServerContextType),
          next_state_(&ServerRpcContextStreamingFromClientImpl::request_done),
          request_method_(request_method),
          invoke_method_(invoke_method),
          stream_(srv_ctx_.get()) {
      request_method_(srv_ctx_.get(), &stream_, AsyncQpsServerTest::tag(this));
    }
    ~ServerRpcContextStreamingFromClientImpl() override {}
    // Dispatches to whichever member function next_state_ points at.
    bool RunNextState(bool ok) override { return (this->*next_state_)(ok); }
    // Replaces the server context, request and stream with fresh ones, then
    // posts a new request.
    void Reset() override {
      srv_ctx_.reset(new ServerContextType);
      req_ = RequestType();
      stream_ =
          grpc::ServerAsyncReader(srv_ctx_.get());

      // Then request the method
      next_state_ = &ServerRpcContextStreamingFromClientImpl::request_done;
      request_method_(srv_ctx_.get(), &stream_, AsyncQpsServerTest::tag(this));
    }

   private:
    // A call has arrived: start the first read. Returns false (done) when the
    // event was not ok.
    bool request_done(bool ok) {
      if (!ok) {
        return false;
      }
      next_state_ = &ServerRpcContextStreamingFromClientImpl::read_done;
      stream_.Read(&req_, AsyncQpsServerTest::tag(this));
      return true;
    }

    // A read completed: on success issue another read; on failure run the
    // handler on req_ as it stands and finish with response_ and Status::OK.
    // NOTE(review): `status` is assigned but never read in this function --
    // confirm intended.
    bool read_done(bool ok) {
      if (ok) {
        // In this case, just do another read
        // next_state_ is unchanged
        stream_.Read(&req_, AsyncQpsServerTest::tag(this));
        return true;
      } else {  // client has sent writes done
        // invoke the method
        // Call the RPC processing function
        grpc::Status status = invoke_method_(&req_, &response_);
        // finish the stream
        next_state_ = &ServerRpcContextStreamingFromClientImpl::finish_done;
        stream_.Finish(response_, Status::OK, AsyncQpsServerTest::tag(this));
      }
      return true;
    }
    // Final state: always reports "done" so the caller resets the context.
    bool finish_done(bool /*ok*/) { return false; /*reset the context*/ }

    std::unique_ptr srv_ctx_;
    RequestType req_;
    ResponseType response_;
    bool (ServerRpcContextStreamingFromClientImpl::*next_state_)(bool);
    std::function*, void*)>
        request_method_;
    std::function invoke_method_;
    grpc::ServerAsyncReader stream_;
  };

  // Context for a server-streaming RPC: runs the handler once when the call
  // arrives and then writes the same response repeatedly until a write fails.
  // States: request_done -> write_done (repeated) -> finish_done.
  class ServerRpcContextStreamingFromServerImpl final
      : public ServerRpcContext {
   public:
    // Stores the two callables and immediately posts the first request with
    // this object as the tag.
    ServerRpcContextStreamingFromServerImpl(
        std::function*, void*)>
            request_method,
        std::function invoke_method)
        : srv_ctx_(new ServerContextType),
          next_state_(&ServerRpcContextStreamingFromServerImpl::request_done),
          request_method_(request_method),
          invoke_method_(invoke_method),
          stream_(srv_ctx_.get()) {
      request_method_(srv_ctx_.get(), &req_, &stream_,
                      AsyncQpsServerTest::tag(this));
    }
    ~ServerRpcContextStreamingFromServerImpl() override {}
    // Dispatches to whichever member function next_state_ points at.
    bool RunNextState(bool ok) override { return (this->*next_state_)(ok); }
    // Replaces the server context, request and stream with fresh ones, then
    // posts a new request.
    void Reset() override {
      srv_ctx_.reset(new ServerContextType);
      req_ = RequestType();
      stream_ = grpc::ServerAsyncWriter(srv_ctx_.get());

      // Then request the method
      next_state_ = &ServerRpcContextStreamingFromServerImpl::request_done;
      request_method_(srv_ctx_.get(), &req_, &stream_,
                      AsyncQpsServerTest::tag(this));
    }

   private:
    // A call has arrived: run the handler and start the first write. Returns
    // false (done) when the event was not ok.
    // NOTE(review): `status` is assigned but never read in this function --
    // confirm intended.
    bool request_done(bool ok) {
      if (!ok) {
        return false;
      }
      // invoke the method
      // Call the RPC processing function
      grpc::Status status = invoke_method_(&req_, &response_);

      next_state_ = &ServerRpcContextStreamingFromServerImpl::write_done;
      stream_.Write(response_, AsyncQpsServerTest::tag(this));
      return true;
    }

    // A write completed: on success write response_ again; otherwise finish
    // the stream with Status::OK.
    bool write_done(bool ok) {
      if (ok) {
        // Do another write!
        // next_state_ is unchanged
        stream_.Write(response_, AsyncQpsServerTest::tag(this));
      } else {  // must be done so let's finish
        next_state_ = &ServerRpcContextStreamingFromServerImpl::finish_done;
        stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this));
      }
      return true;
    }
    // Final state: always reports "done" so the caller resets the context.
    bool finish_done(bool /*ok*/) { return false; /*reset the context*/ }

    std::unique_ptr srv_ctx_;
    RequestType req_;
    ResponseType response_;
    bool (ServerRpcContextStreamingFromServerImpl::*next_state_)(bool);
    std::function*, void*)>
        request_method_;
    std::function invoke_method_;
    grpc::ServerAsyncWriter stream_;
  };

  // Worker threads, one per configured thread (started in the constructor,
  // joined in the destructor).
  std::vector threads_;
  // The server returned by builder->BuildAndStart().
  std::unique_ptr server_;
  // Completion queues obtained from builder->AddCompletionQueue().
  std::vector> srv_cqs_;
  // cq_[i] is the index into srv_cqs_ of the queue served by thread i.
  std::vector cq_;
  // Service instance handed to register_service and bound into the request
  // functions.
  ServiceType async_service_;
  // All pre-posted RPC contexts; their addresses are used as queue tags.
  std::vector> contexts_;

  // Shutdown flag for one worker thread together with the mutex guarding it.
  struct PerThreadShutdownState {
    mutable std::mutex mutex;
    bool shutdown;
    PerThreadShutdownState() : shutdown(false) {}
  };

  // One entry per worker thread, indexed by thread_idx.
  std::vector> shutdown_state_;
};

// Registers the benchmark async service with the builder.
static void RegisterBenchmarkService(ServerBuilder* builder,
                                     BenchmarkService::AsyncService* service) {
  builder->RegisterService(service);
}
// Registers the generic async service with the builder.
static void RegisterGenericService(ServerBuilder* builder,
                                   grpc::AsyncGenericService* service) {
  builder->RegisterAsyncGenericService(service);
}

// Handler for typed benchmark RPCs. When the request asks for a response of
// positive size, fills the response payload via Server::SetPayload and
// returns an INTERNAL status if that fails. Clears the request before
// returning OK. The PayloadConfig argument is unused.
static Status ProcessSimpleRPC(const PayloadConfig&, SimpleRequest* request,
                               SimpleResponse* response) {
  if (request->response_size() > 0) {
    if (!Server::SetPayload(request->response_type(), request->response_size(),
                            response->mutable_payload())) {
      return Status(grpc::StatusCode::INTERNAL, "Error creating payload.");
    }
  }
  // We are done using the request. Clear it to reduce working memory.
  // This proves to reduce cache misses in large message size cases.
  request->Clear();
  return Status::OK;
}

// Handler for generic (ByteBuffer) RPCs. Clears the request, then sets the
// response to a single-slice ByteBuffer built from a zero-filled buffer of
// payload_config.bytebuf_params().resp_size() bytes. Always returns OK.
// NOTE(review): resp_size is a signed int used directly as an allocation
// size with no check for negative values -- confirm the config guarantees
// resp_size >= 0.
static Status ProcessGenericRPC(const PayloadConfig& payload_config,
                                ByteBuffer* request, ByteBuffer* response) {
  // We are done using the request. Clear it to reduce working memory.
  // This proves to reduce cache misses in large message size cases.
  request->Clear();
  int resp_size = payload_config.bytebuf_params().resp_size();
  std::unique_ptr buf(new char[resp_size]);
  memset(buf.get(), 0, static_cast(resp_size));
  Slice slice(buf.get(), resp_size);
  *response = ByteBuffer(&slice, 1);
  return Status::OK;
}

// Creates the async server for the typed BenchmarkService, wiring in all five
// Request* methods of BenchmarkService::AsyncService and ProcessSimpleRPC as
// the handler.
std::unique_ptr CreateAsyncServer(const ServerConfig& config) {
  return std::unique_ptr(
      new AsyncQpsServerTest(
          config, RegisterBenchmarkService,
          &BenchmarkService::AsyncService::RequestUnaryCall,
          &BenchmarkService::AsyncService::RequestStreamingCall,
          &BenchmarkService::AsyncService::RequestStreamingFromClient,
          &BenchmarkService::AsyncService::RequestStreamingFromServer,
          &BenchmarkService::AsyncService::RequestStreamingBothWays,
          ProcessSimpleRPC));
}
// Creates the async server for the generic service. Only the streaming
// request function (AsyncGenericService::RequestCall) is supplied; the other
// request functions are nullptr, so no contexts of those kinds are posted.
// ProcessGenericRPC is the handler.
std::unique_ptr CreateAsyncGenericServer(const ServerConfig& config) {
  return std::unique_ptr(
      new AsyncQpsServerTest(
          config, RegisterGenericService, nullptr,
          &grpc::AsyncGenericService::RequestCall, nullptr, nullptr, nullptr,
          ProcessGenericRPC));
}

}  // namespace testing
}  // namespace grpc