Merge pull request #5258 from vjpai/poll_override2

Avoid a race when overriding default poll function
pull/5347/head
Craig Tiller 9 years ago
commit fc6fec2a2d
  1. 4
      src/core/iomgr/pollset_posix.h
  2. 40
      test/cpp/end2end/async_end2end_test.cc

@ -142,6 +142,10 @@ int grpc_pollset_has_workers(grpc_pollset *pollset);
void grpc_remove_fd_from_all_epoll_sets(int fd);
/* override to allow tests to hook poll() usage */
/* NOTE: Any changes to grpc_poll_function must take place when the gRPC
is certainly not doing any polling anywhere.
Otherwise, there might be a race between changing the variable and actually
doing a polling operation */
typedef int (*grpc_poll_function_type)(struct pollfd *, nfds_t, int);
extern grpc_poll_function_type grpc_poll_function;
extern grpc_wakeup_fd grpc_global_wakeup_fd;

@ -43,6 +43,7 @@
#include <grpc/grpc.h>
#include <grpc/support/thd.h>
#include <grpc/support/time.h>
#include <grpc/support/tls.h>
#include <gtest/gtest.h>
#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
@ -59,6 +60,8 @@ using grpc::testing::EchoRequest;
using grpc::testing::EchoResponse;
using std::chrono::system_clock;
GPR_TLS_DECL(g_is_async_end2end_test);
namespace grpc {
namespace testing {
@ -67,9 +70,11 @@ namespace {
void* tag(int i) { return (void*)(intptr_t)i; }
#ifdef GPR_POSIX_SOCKET
static int assert_non_blocking_poll(struct pollfd* pfds, nfds_t nfds,
int timeout) {
GPR_ASSERT(timeout == 0);
static int maybe_assert_non_blocking_poll(struct pollfd* pfds, nfds_t nfds,
int timeout) {
if (gpr_tls_get(&g_is_async_end2end_test)) {
GPR_ASSERT(timeout == 0);
}
return poll(pfds, nfds, timeout);
}
@ -86,21 +91,21 @@ class PollOverride {
grpc_poll_function_type prev_;
};
class PollingCheckRegion : public PollOverride {
class PollingOverrider : public PollOverride {
public:
explicit PollingCheckRegion(bool allow_blocking)
: PollOverride(allow_blocking ? poll : assert_non_blocking_poll) {}
explicit PollingOverrider(bool allow_blocking)
: PollOverride(allow_blocking ? poll : maybe_assert_non_blocking_poll) {}
};
#else
class PollingCheckRegion {
class PollingOverrider {
public:
explicit PollingCheckRegion(bool allow_blocking) {}
explicit PollingOverrider(bool allow_blocking) {}
};
#endif
class Verifier : public PollingCheckRegion {
class Verifier {
public:
explicit Verifier(bool spin) : PollingCheckRegion(!spin), spin_(spin) {}
explicit Verifier(bool spin) : spin_(spin) {}
Verifier& Expect(int i, bool expect_ok) {
expectations_[tag(i)] = expect_ok;
return *this;
@ -183,6 +188,8 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<bool> {
AsyncEnd2endTest() {}
void SetUp() GRPC_OVERRIDE {
poll_overrider_.reset(new PollingOverrider(!GetParam()));
int port = grpc_pick_unused_port_or_die();
server_address_ << "localhost:" << port;
@ -193,6 +200,8 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<bool> {
builder.RegisterService(&service_);
cq_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart();
gpr_tls_set(&g_is_async_end2end_test, 1);
}
void TearDown() GRPC_OVERRIDE {
@ -202,6 +211,8 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<bool> {
cq_->Shutdown();
while (cq_->Next(&ignored_tag, &ignored_ok))
;
poll_overrider_.reset();
gpr_tls_set(&g_is_async_end2end_test, 0);
}
void ResetStub() {
@ -249,6 +260,8 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<bool> {
std::unique_ptr<Server> server_;
grpc::testing::EchoTestService::AsyncService service_;
std::ostringstream server_address_;
std::unique_ptr<PollingOverrider> poll_overrider_;
};
TEST_P(AsyncEnd2endTest, SimpleRpc) {
@ -1087,7 +1100,7 @@ class AsyncEnd2endServerTryCancelTest : public AsyncEnd2endTest {
Verifier(GetParam()).Expect(7, true).Verify(cq_.get());
// This is expected to fail in all cases i.e for all values of
// server_try_cancel. This is becasue at this point, either there are no
// server_try_cancel. This is because at this point, either there are no
// more msgs from the client (because client called WritesDone) or the RPC
// is cancelled on the server
srv_stream.Read(&recv_request, tag(8));
@ -1164,6 +1177,9 @@ INSTANTIATE_TEST_CASE_P(AsyncEnd2endServerTryCancel,
int main(int argc, char** argv) {
grpc_test_init(argc, argv);
gpr_tls_init(&g_is_async_end2end_test);
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
int ret = RUN_ALL_TESTS();
gpr_tls_destroy(&g_is_async_end2end_test);
return ret;
}

Loading…
Cancel
Save