Adjust stream cancellation point and fix races in sync client

pull/13576/head
Vijay Pai 7 years ago
parent d1945788c5
commit 6389457ed2
  1. 163
      test/cpp/qps/client_sync.cc

@ -62,11 +62,13 @@ class SynchronousClient
virtual ~SynchronousClient(){}; virtual ~SynchronousClient(){};
virtual void InitThreadFuncImpl(size_t thread_idx) = 0; virtual bool InitThreadFuncImpl(size_t thread_idx) = 0;
virtual bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) = 0; virtual bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) = 0;
void ThreadFunc(size_t thread_idx, Thread* t) override { void ThreadFunc(size_t thread_idx, Thread* t) override {
InitThreadFuncImpl(thread_idx); if (!InitThreadFuncImpl(thread_idx)) {
return;
}
for (;;) { for (;;) {
// run the loop body // run the loop body
HistogramEntry entry; HistogramEntry entry;
@ -109,9 +111,6 @@ class SynchronousClient
size_t num_threads_; size_t num_threads_;
std::vector<SimpleResponse> responses_; std::vector<SimpleResponse> responses_;
private:
void DestroyMultithreading() override final { EndThreads(); }
}; };
class SynchronousUnaryClient final : public SynchronousClient { class SynchronousUnaryClient final : public SynchronousClient {
@ -122,7 +121,7 @@ class SynchronousUnaryClient final : public SynchronousClient {
} }
~SynchronousUnaryClient() {} ~SynchronousUnaryClient() {}
void InitThreadFuncImpl(size_t thread_idx) override {} bool InitThreadFuncImpl(size_t thread_idx) override { return true; }
bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override {
if (!WaitToIssue(thread_idx)) { if (!WaitToIssue(thread_idx)) {
@ -140,6 +139,9 @@ class SynchronousUnaryClient final : public SynchronousClient {
entry->set_status(s.error_code()); entry->set_status(s.error_code());
return true; return true;
} }
private:
void DestroyMultithreading() override final { EndThreads(); }
}; };
template <class StreamType> template <class StreamType>
@ -149,31 +151,30 @@ class SynchronousStreamingClient : public SynchronousClient {
: SynchronousClient(config), : SynchronousClient(config),
context_(num_threads_), context_(num_threads_),
stream_(num_threads_), stream_(num_threads_),
stream_mu_(num_threads_),
shutdown_(num_threads_),
messages_per_stream_(config.messages_per_stream()), messages_per_stream_(config.messages_per_stream()),
messages_issued_(num_threads_) { messages_issued_(num_threads_) {
StartThreads(num_threads_); StartThreads(num_threads_);
} }
virtual ~SynchronousStreamingClient() { virtual ~SynchronousStreamingClient() {
std::vector<std::thread> cleanup_threads; OnAllStreams([](ClientContext* ctx, StreamType* s) -> bool {
for (size_t i = 0; i < num_threads_; i++) { // don't log any kind of error since we might have canceled it
cleanup_threads.emplace_back([this, i]() { s->Finish().IgnoreError();
auto stream = &stream_[i]; return true;
if (*stream) {
// forcibly cancel the streams, then finish
context_[i].TryCancel();
(*stream)->Finish().IgnoreError();
// don't log any error message on !ok since this was canceled
}
}); });
} }
for (auto& th : cleanup_threads) {
th.join();
}
}
protected: protected:
std::vector<grpc::ClientContext> context_; std::vector<grpc::ClientContext> context_;
std::vector<std::unique_ptr<StreamType>> stream_; std::vector<std::unique_ptr<StreamType>> stream_;
// stream_mu_ is only needed when changing an element of stream_ or context_
std::vector<std::mutex> stream_mu_;
struct Bool {
bool val;
Bool() : val(false) {}
};
std::vector<Bool> shutdown_;
const int messages_per_stream_; const int messages_per_stream_;
std::vector<int> messages_issued_; std::vector<int> messages_issued_;
@ -185,9 +186,34 @@ class SynchronousStreamingClient : public SynchronousClient {
gpr_log(GPR_ERROR, "Stream %" PRIuPTR " received an error %s", thread_idx, gpr_log(GPR_ERROR, "Stream %" PRIuPTR " received an error %s", thread_idx,
s.error_message().c_str()); s.error_message().c_str());
} }
// Lock the stream_mu_ now because the client context could change
std::lock_guard<std::mutex> l(stream_mu_[thread_idx]);
context_[thread_idx].~ClientContext(); context_[thread_idx].~ClientContext();
new (&context_[thread_idx]) ClientContext(); new (&context_[thread_idx]) ClientContext();
} }
void OnAllStreams(std::function<bool(ClientContext*, StreamType*)> cleaner) {
std::vector<std::thread> cleanup_threads;
for (size_t i = 0; i < num_threads_; i++) {
cleanup_threads.emplace_back([this, i, cleaner]() {
std::lock_guard<std::mutex> l(stream_mu_[i]);
if (stream_[i]) {
shutdown_[i].val = cleaner(&context_[i], stream_[i].get());
}
});
}
for (auto& th : cleanup_threads) {
th.join();
}
}
private:
void DestroyMultithreading() override final {
OnAllStreams([](ClientContext* ctx, StreamType* s) -> bool {
ctx->TryCancel();
return true;
});
EndThreads();
}
}; };
class SynchronousStreamingPingPongClient final class SynchronousStreamingPingPongClient final
@ -197,24 +223,24 @@ class SynchronousStreamingPingPongClient final
SynchronousStreamingPingPongClient(const ClientConfig& config) SynchronousStreamingPingPongClient(const ClientConfig& config)
: SynchronousStreamingClient(config) {} : SynchronousStreamingClient(config) {}
~SynchronousStreamingPingPongClient() { ~SynchronousStreamingPingPongClient() {
std::vector<std::thread> cleanup_threads; OnAllStreams(
for (size_t i = 0; i < num_threads_; i++) { [](ClientContext* ctx,
cleanup_threads.emplace_back([this, i]() { grpc::ClientReaderWriter<SimpleRequest, SimpleResponse>* s) -> bool {
auto stream = &stream_[i]; s->WritesDone();
if (*stream) { return true;
(*stream)->WritesDone();
}
}); });
} }
for (auto& th : cleanup_threads) {
th.join();
}
}
void InitThreadFuncImpl(size_t thread_idx) override { bool InitThreadFuncImpl(size_t thread_idx) override {
auto* stub = channels_[thread_idx % channels_.size()].get_stub(); auto* stub = channels_[thread_idx % channels_.size()].get_stub();
std::lock_guard<std::mutex> l(stream_mu_[thread_idx]);
if (!shutdown_[thread_idx].val) {
stream_[thread_idx] = stub->StreamingCall(&context_[thread_idx]); stream_[thread_idx] = stub->StreamingCall(&context_[thread_idx]);
} else {
return false;
}
messages_issued_[thread_idx] = 0; messages_issued_[thread_idx] = 0;
return true;
} }
bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override {
@ -239,7 +265,13 @@ class SynchronousStreamingPingPongClient final
stream_[thread_idx]->WritesDone(); stream_[thread_idx]->WritesDone();
FinishStream(entry, thread_idx); FinishStream(entry, thread_idx);
auto* stub = channels_[thread_idx % channels_.size()].get_stub(); auto* stub = channels_[thread_idx % channels_.size()].get_stub();
std::lock_guard<std::mutex> l(stream_mu_[thread_idx]);
if (!shutdown_[thread_idx].val) {
stream_[thread_idx] = stub->StreamingCall(&context_[thread_idx]); stream_[thread_idx] = stub->StreamingCall(&context_[thread_idx]);
} else {
stream_[thread_idx].reset();
return false;
}
messages_issued_[thread_idx] = 0; messages_issued_[thread_idx] = 0;
return true; return true;
} }
@ -251,25 +283,24 @@ class SynchronousStreamingFromClientClient final
SynchronousStreamingFromClientClient(const ClientConfig& config) SynchronousStreamingFromClientClient(const ClientConfig& config)
: SynchronousStreamingClient(config), last_issue_(num_threads_) {} : SynchronousStreamingClient(config), last_issue_(num_threads_) {}
~SynchronousStreamingFromClientClient() { ~SynchronousStreamingFromClientClient() {
std::vector<std::thread> cleanup_threads; OnAllStreams(
for (size_t i = 0; i < num_threads_; i++) { [](ClientContext* ctx, grpc::ClientWriter<SimpleRequest>* s) -> bool {
cleanup_threads.emplace_back([this, i]() { s->WritesDone();
auto stream = &stream_[i]; return true;
if (*stream) {
(*stream)->WritesDone();
}
}); });
} }
for (auto& th : cleanup_threads) {
th.join();
}
}
void InitThreadFuncImpl(size_t thread_idx) override { bool InitThreadFuncImpl(size_t thread_idx) override {
auto* stub = channels_[thread_idx % channels_.size()].get_stub(); auto* stub = channels_[thread_idx % channels_.size()].get_stub();
std::lock_guard<std::mutex> l(stream_mu_[thread_idx]);
if (!shutdown_[thread_idx].val) {
stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx], stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx],
&responses_[thread_idx]); &responses_[thread_idx]);
} else {
return false;
}
last_issue_[thread_idx] = UsageTimer::Now(); last_issue_[thread_idx] = UsageTimer::Now();
return true;
} }
bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override {
@ -287,8 +318,14 @@ class SynchronousStreamingFromClientClient final
stream_[thread_idx]->WritesDone(); stream_[thread_idx]->WritesDone();
FinishStream(entry, thread_idx); FinishStream(entry, thread_idx);
auto* stub = channels_[thread_idx % channels_.size()].get_stub(); auto* stub = channels_[thread_idx % channels_.size()].get_stub();
std::lock_guard<std::mutex> l(stream_mu_[thread_idx]);
if (!shutdown_[thread_idx].val) {
stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx], stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx],
&responses_[thread_idx]); &responses_[thread_idx]);
} else {
stream_[thread_idx].reset();
return false;
}
return true; return true;
} }
@ -301,11 +338,17 @@ class SynchronousStreamingFromServerClient final
public: public:
SynchronousStreamingFromServerClient(const ClientConfig& config) SynchronousStreamingFromServerClient(const ClientConfig& config)
: SynchronousStreamingClient(config), last_recv_(num_threads_) {} : SynchronousStreamingClient(config), last_recv_(num_threads_) {}
void InitThreadFuncImpl(size_t thread_idx) override { bool InitThreadFuncImpl(size_t thread_idx) override {
auto* stub = channels_[thread_idx % channels_.size()].get_stub(); auto* stub = channels_[thread_idx % channels_.size()].get_stub();
std::lock_guard<std::mutex> l(stream_mu_[thread_idx]);
if (!shutdown_[thread_idx].val) {
stream_[thread_idx] = stream_[thread_idx] =
stub->StreamingFromServer(&context_[thread_idx], request_); stub->StreamingFromServer(&context_[thread_idx], request_);
} else {
return false;
}
last_recv_[thread_idx] = UsageTimer::Now(); last_recv_[thread_idx] = UsageTimer::Now();
return true;
} }
bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override {
GPR_TIMER_SCOPE("SynchronousStreamingFromServerClient::ThreadFunc", 0); GPR_TIMER_SCOPE("SynchronousStreamingFromServerClient::ThreadFunc", 0);
@ -317,8 +360,14 @@ class SynchronousStreamingFromServerClient final
} }
FinishStream(entry, thread_idx); FinishStream(entry, thread_idx);
auto* stub = channels_[thread_idx % channels_.size()].get_stub(); auto* stub = channels_[thread_idx % channels_.size()].get_stub();
std::lock_guard<std::mutex> l(stream_mu_[thread_idx]);
if (!shutdown_[thread_idx].val) {
stream_[thread_idx] = stream_[thread_idx] =
stub->StreamingFromServer(&context_[thread_idx], request_); stub->StreamingFromServer(&context_[thread_idx], request_);
} else {
stream_[thread_idx].reset();
return false;
}
return true; return true;
} }
@ -333,23 +382,23 @@ class SynchronousStreamingBothWaysClient final
SynchronousStreamingBothWaysClient(const ClientConfig& config) SynchronousStreamingBothWaysClient(const ClientConfig& config)
: SynchronousStreamingClient(config) {} : SynchronousStreamingClient(config) {}
~SynchronousStreamingBothWaysClient() { ~SynchronousStreamingBothWaysClient() {
std::vector<std::thread> cleanup_threads; OnAllStreams(
for (size_t i = 0; i < num_threads_; i++) { [](ClientContext* ctx,
cleanup_threads.emplace_back([this, i]() { grpc::ClientReaderWriter<SimpleRequest, SimpleResponse>* s) -> bool {
auto stream = &stream_[i]; s->WritesDone();
if (*stream) { return true;
(*stream)->WritesDone();
}
}); });
} }
for (auto& th : cleanup_threads) {
th.join();
}
}
void InitThreadFuncImpl(size_t thread_idx) override { bool InitThreadFuncImpl(size_t thread_idx) override {
auto* stub = channels_[thread_idx % channels_.size()].get_stub(); auto* stub = channels_[thread_idx % channels_.size()].get_stub();
std::lock_guard<std::mutex> l(stream_mu_[thread_idx]);
if (!shutdown_[thread_idx].val) {
stream_[thread_idx] = stub->StreamingBothWays(&context_[thread_idx]); stream_[thread_idx] = stub->StreamingBothWays(&context_[thread_idx]);
} else {
return false;
}
return true;
} }
bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override {
// TODO (vjpai): Do this // TODO (vjpai): Do this

Loading…
Cancel
Save