diff --git a/src/core/lib/iomgr/iomgr_posix.cc b/src/core/lib/iomgr/iomgr_posix.cc index 437d0327507..0a862ccad64 100644 --- a/src/core/lib/iomgr/iomgr_posix.cc +++ b/src/core/lib/iomgr/iomgr_posix.cc @@ -80,6 +80,7 @@ void grpc_set_default_iomgr_platform() { grpc_set_pollset_vtable(&grpc_posix_pollset_vtable); grpc_set_pollset_set_vtable(&grpc_posix_pollset_set_vtable); grpc_core::SetDNSResolver(grpc_core::NativeDNSResolver::GetOrCreate()); + grpc_tcp_client_global_init(); grpc_set_iomgr_platform_vtable(&vtable); } diff --git a/src/core/lib/iomgr/iomgr_posix_cfstream.cc b/src/core/lib/iomgr/iomgr_posix_cfstream.cc index 74cd19cc25a..90677e516a1 100644 --- a/src/core/lib/iomgr/iomgr_posix_cfstream.cc +++ b/src/core/lib/iomgr/iomgr_posix_cfstream.cc @@ -177,6 +177,7 @@ void grpc_set_default_iomgr_platform() { grpc_set_pollset_set_vtable(&grpc_apple_pollset_set_vtable); grpc_set_iomgr_platform_vtable(&apple_vtable); } + grpc_tcp_client_global_init(); grpc_set_timer_impl(&grpc_generic_timer_vtable); grpc_core::SetDNSResolver(grpc_core::NativeDNSResolver::GetOrCreate()); } diff --git a/src/core/lib/iomgr/tcp_client.h b/src/core/lib/iomgr/tcp_client.h index 59a363c89d7..24235cde277 100644 --- a/src/core/lib/iomgr/tcp_client.h +++ b/src/core/lib/iomgr/tcp_client.h @@ -53,10 +53,13 @@ int64_t grpc_tcp_client_connect(grpc_closure* on_connect, grpc_core::Timestamp deadline); // Returns true if a connect attempt corresponding to the provided handle -// is successfully cancelled. Otherwise it returns false. +// is successfully cancelled. Otherwise it returns false. If the connect +// attempt is successfully cancelled, then the on_connect closure passed to +// grpc_tcp_client_connect will not be executed. It's up to the caller to free +// up any resources that may have been allocated to create the closure. 
bool grpc_tcp_client_cancel_connect(int64_t connection_handle); -void grpc_tcp_client_global_init(); +extern void grpc_tcp_client_global_init(); void grpc_set_tcp_client_impl(grpc_tcp_client_vtable* impl); diff --git a/src/core/lib/iomgr/tcp_client_posix.cc b/src/core/lib/iomgr/tcp_client_posix.cc index 389c2734aeb..849a1d903ea 100644 --- a/src/core/lib/iomgr/tcp_client_posix.cc +++ b/src/core/lib/iomgr/tcp_client_posix.cc @@ -27,6 +27,7 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include @@ -62,8 +63,33 @@ struct async_connect { grpc_endpoint** ep; grpc_closure* closure; grpc_channel_args* channel_args; + int64_t connection_handle; + bool connect_cancelled; }; +struct ConnectionShard { + grpc_core::Mutex mu; + absl::flat_hash_map<int64_t, async_connect*> pending_connections + ABSL_GUARDED_BY(&mu); +}; + +namespace { + +gpr_once g_tcp_client_posix_init = GPR_ONCE_INIT; +std::vector<ConnectionShard>* g_connection_shards = nullptr; +std::atomic<int64_t> g_connection_id{1}; + +void do_tcp_client_global_init(void) { + size_t num_shards = std::max(2 * gpr_cpu_num_cores(), 1u); + g_connection_shards = new std::vector<ConnectionShard>(num_shards); +} + +} // namespace + +void grpc_tcp_client_global_init() { + gpr_once_init(&g_tcp_client_posix_init, do_tcp_client_global_init); +} + static grpc_error_handle prepare_socket(const grpc_resolved_address* addr, int fd, const grpc_channel_args* channel_args) { @@ -150,6 +176,7 @@ static void on_writable(void* acp, grpc_error_handle error) { GPR_ASSERT(ac->fd); fd = ac->fd; ac->fd = nullptr; + bool connect_cancelled = ac->connect_cancelled; gpr_mu_unlock(&ac->mu); grpc_timer_cancel(&ac->alarm); @@ -161,6 +188,12 @@ static void on_writable(void* acp, grpc_error_handle error) { goto finish; } + if (connect_cancelled) { + // The callback should not get scheduled in this case. 
+ error = GRPC_ERROR_NONE; + goto finish; + } + do { so_error_size = sizeof(so_error); err = getsockopt(grpc_fd_wrapped_fd(fd), SOL_SOCKET, SO_ERROR, &so_error, @@ -208,6 +241,14 @@ static void on_writable(void* acp, grpc_error_handle error) { } finish: + if (!connect_cancelled) { + int shard_number = ac->connection_handle % (*g_connection_shards).size(); + struct ConnectionShard* shard = &(*g_connection_shards)[shard_number]; + { + grpc_core::MutexLock lock(&shard->mu); + shard->pending_connections.erase(ac->connection_handle); + } + } if (fd != nullptr) { grpc_pollset_set_del_fd(ac->interested_parties, fd); grpc_fd_orphan(fd, nullptr, nullptr, "tcp_client_orphan"); @@ -234,7 +275,12 @@ finish: // Push async connect closure to the executor since this may actually be // called during the shutdown process, in which case a deadlock could form // between the core shutdown mu and the connector mu (b/188239051) - grpc_core::Executor::Run(closure, error); + if (!connect_cancelled) { + grpc_core::Executor::Run(closure, error); + } else if (!GRPC_ERROR_IS_NONE(error)) { + // Unref the error here because it is not used. 
+ (void)GRPC_ERROR_UNREF(error); + } } grpc_error_handle grpc_tcp_client_prepare_fd( @@ -267,7 +313,7 @@ grpc_error_handle grpc_tcp_client_prepare_fd( return GRPC_ERROR_NONE; } -void grpc_tcp_client_create_from_prepared_fd( +int64_t grpc_tcp_client_create_from_prepared_fd( grpc_pollset_set* interested_parties, grpc_closure* closure, const int fd, const grpc_channel_args* channel_args, const grpc_resolved_address* addr, grpc_core::Timestamp deadline, grpc_endpoint** ep) { @@ -282,24 +328,33 @@ void grpc_tcp_client_create_from_prepared_fd( grpc_error_handle error = GRPC_ERROR_CREATE_FROM_CPP_STRING(addr_uri.status().ToString()); grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, error); - return; + return 0; } std::string name = absl::StrCat("tcp-client:", addr_uri.value()); grpc_fd* fdobj = grpc_fd_create(fd, name.c_str(), true); + int64_t connection_id = 0; + if (errno == EWOULDBLOCK || errno == EINPROGRESS) { + // Connection is still in progress. + connection_id = g_connection_id.fetch_add(1, std::memory_order_acq_rel); + } if (err >= 0) { + // Connection already succeeded. Return 0 to discourage any cancellation + // attempts. *ep = grpc_tcp_client_create_from_fd(fdobj, channel_args, addr_uri.value()); grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); - return; + return 0; } if (errno != EWOULDBLOCK && errno != EINPROGRESS) { + // Connection already failed. Return 0 to discourage any cancellation + // attempts. 
grpc_error_handle error = GRPC_OS_ERROR(errno, "connect"); error = grpc_error_set_str(error, GRPC_ERROR_STR_TARGET_ADDRESS, addr_uri.value()); grpc_fd_orphan(fdobj, nullptr, nullptr, "tcp_client_connect_error"); grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, error); - return; + return 0; } grpc_pollset_set_add_fd(interested_parties, fdobj); @@ -310,6 +365,8 @@ void grpc_tcp_client_create_from_prepared_fd( ac->fd = fdobj; ac->interested_parties = interested_parties; ac->addr_str = addr_uri.value(); + ac->connection_handle = connection_id; + ac->connect_cancelled = false; gpr_mu_init(&ac->mu); ac->refs = 2; GRPC_CLOSURE_INIT(&ac->write_closure, on_writable, ac, @@ -321,11 +378,19 @@ void grpc_tcp_client_create_from_prepared_fd( ac->addr_str.c_str(), fdobj); } + int shard_number = connection_id % (*g_connection_shards).size(); + struct ConnectionShard* shard = &(*g_connection_shards)[shard_number]; + { + grpc_core::MutexLock lock(&shard->mu); + shard->pending_connections.insert_or_assign(connection_id, ac); + } + gpr_mu_lock(&ac->mu); GRPC_CLOSURE_INIT(&ac->on_alarm, tc_on_alarm, ac, grpc_schedule_on_exec_ctx); grpc_timer_init(&ac->alarm, deadline, &ac->on_alarm); grpc_fd_notify_on_write(ac->fd, &ac->write_closure); gpr_mu_unlock(&ac->mu); + return connection_id; } static int64_t tcp_connect(grpc_closure* closure, grpc_endpoint** ep, @@ -342,13 +407,62 @@ static int64_t tcp_connect(grpc_closure* closure, grpc_endpoint** ep, grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, error); return 0; } - grpc_tcp_client_create_from_prepared_fd(interested_parties, closure, fd, - channel_args, &mapped_addr, deadline, - ep); - return 0; + return grpc_tcp_client_create_from_prepared_fd(interested_parties, closure, + fd, channel_args, &mapped_addr, + deadline, ep); } -static bool tcp_cancel_connect(int64_t /*connection_handle*/) { return false; } +static bool tcp_cancel_connect(int64_t connection_handle) { + if (connection_handle <= 0) { + return false; + } + int shard_number = 
connection_handle % (*g_connection_shards).size(); + struct ConnectionShard* shard = &(*g_connection_shards)[shard_number]; + async_connect* ac = nullptr; + { + grpc_core::MutexLock lock(&shard->mu); + auto it = shard->pending_connections.find(connection_handle); + if (it != shard->pending_connections.end()) { + ac = it->second; + GPR_ASSERT(ac != nullptr); + // Trying to acquire ac->mu here could cause a deadlock because + // the on_writable method tries to acquire the two mutexes used + // here in the reverse order. But we don't need to acquire ac->mu before + // incrementing ac->refs here. This is because the on_writable + // method decrements ac->refs only after deleting the connection handle + // from the corresponding hashmap. If the code enters here, it means that + // deletion hasn't happened yet. The deletion can only happen after the + // corresponding shard->mu is unlocked. + ++ac->refs; + // Remove connection from list of active connections. + shard->pending_connections.erase(it); + } + } + if (ac == nullptr) { + return false; + } + gpr_mu_lock(&ac->mu); + bool connection_cancel_success = (ac->fd != nullptr); + if (connection_cancel_success) { + // Connection is still pending. The on_writable callback hasn't executed + // yet because ac->fd != nullptr. + ac->connect_cancelled = true; + // Shut down the fd. This would cause on_writable to run as soon as possible. + // We don't need to pass a custom error here because it won't be used since + // the on_connect_closure is not run if connect cancellation is successful. + grpc_fd_shutdown(ac->fd, GRPC_ERROR_NONE); + } + bool done = (--ac->refs == 0); + gpr_mu_unlock(&ac->mu); + if (done) { + // This is safe even outside the lock, because "done", the sentinel, is + // populated *inside* the lock. 
+ gpr_mu_destroy(&ac->mu); + grpc_channel_args_destroy(ac->channel_args); + delete ac; + } + return connection_cancel_success; +} grpc_tcp_client_vtable grpc_posix_tcp_client_vtable = {tcp_connect, tcp_cancel_connect}; diff --git a/src/core/lib/iomgr/tcp_client_posix.h b/src/core/lib/iomgr/tcp_client_posix.h index eb87b9fb79e..ddcb6e0691d 100644 --- a/src/core/lib/iomgr/tcp_client_posix.h +++ b/src/core/lib/iomgr/tcp_client_posix.h @@ -61,7 +61,7 @@ grpc_error_handle grpc_tcp_client_prepare_fd( deadline: connection deadline ep: out parameter. Set before closure is called if successful */ -void grpc_tcp_client_create_from_prepared_fd( +int64_t grpc_tcp_client_create_from_prepared_fd( grpc_pollset_set* interested_parties, grpc_closure* closure, const int fd, const grpc_channel_args* channel_args, const grpc_resolved_address* addr, grpc_core::Timestamp deadline, grpc_endpoint** ep); diff --git a/test/core/iomgr/tcp_client_posix_test.cc b/test/core/iomgr/tcp_client_posix_test.cc index 71acff6ea8f..ae7e26a101c 100644 --- a/test/core/iomgr/tcp_client_posix_test.cc +++ b/test/core/iomgr/tcp_client_posix_test.cc @@ -112,8 +112,9 @@ void test_succeeds(void) { .channel_args_preconditioning() .PreconditionChannelArgs(nullptr) .ToC(); - grpc_tcp_client_connect(&done, &g_connecting, g_pollset_set, args, - &resolved_addr, grpc_core::Timestamp::InfFuture()); + int64_t connection_handle = grpc_tcp_client_connect( + &done, &g_connecting, g_pollset_set, args, &resolved_addr, + grpc_core::Timestamp::InfFuture()); grpc_channel_args_destroy(args); /* await the connection */ do { @@ -139,6 +140,10 @@ void test_succeeds(void) { } gpr_mu_unlock(g_mu); + + // A cancellation attempt should fail because connect already succeeded. 
+ GPR_ASSERT(grpc_tcp_client_cancel_connect(connection_handle) == false); + gpr_log(GPR_ERROR, "---- finished test_succeeds() ----"); } @@ -161,8 +166,9 @@ void test_fails(void) { /* connect to a broken address */ GRPC_CLOSURE_INIT(&done, must_fail, nullptr, grpc_schedule_on_exec_ctx); - grpc_tcp_client_connect(&done, &g_connecting, g_pollset_set, nullptr, - &resolved_addr, grpc_core::Timestamp::InfFuture()); + int64_t connection_handle = grpc_tcp_client_connect( + &done, &g_connecting, g_pollset_set, nullptr, &resolved_addr, + grpc_core::Timestamp::InfFuture()); gpr_mu_lock(g_mu); /* wait for the connection callback to finish */ @@ -187,9 +193,52 @@ void test_fails(void) { } gpr_mu_unlock(g_mu); + + // A cancellation attempt should fail because connect already failed. + GPR_ASSERT(grpc_tcp_client_cancel_connect(connection_handle) == false); + gpr_log(GPR_ERROR, "---- finished test_fails() ----"); } +void test_connect_cancellation_succeeds(void) { + gpr_log(GPR_ERROR, "---- starting test_connect_cancellation_succeeds() ----"); + grpc_resolved_address resolved_addr; + struct sockaddr_in* addr = + reinterpret_cast<struct sockaddr_in*>(resolved_addr.addr); + int svr_fd; + grpc_closure done; + grpc_core::ExecCtx exec_ctx; + + memset(&resolved_addr, 0, sizeof(resolved_addr)); + resolved_addr.len = static_cast<socklen_t>(sizeof(struct sockaddr_in)); + addr->sin_family = AF_INET; + + /* create a phony server */ + svr_fd = socket(AF_INET, SOCK_STREAM, 0); + GPR_ASSERT(svr_fd >= 0); + GPR_ASSERT( + 0 == bind(svr_fd, (struct sockaddr*)addr, (socklen_t)resolved_addr.len)); + GPR_ASSERT(0 == listen(svr_fd, 1)); + + // connect to it. accept() is not called on the bind socket. So the connection + // should appear to be stuck giving ample time to try to cancel it. 
+ GPR_ASSERT(getsockname(svr_fd, (struct sockaddr*)addr, + (socklen_t*)&resolved_addr.len) == 0); + GRPC_CLOSURE_INIT(&done, must_succeed, nullptr, grpc_schedule_on_exec_ctx); + const grpc_channel_args* args = grpc_core::CoreConfiguration::Get() + .channel_args_preconditioning() + .PreconditionChannelArgs(nullptr) + .ToC(); + int64_t connection_handle = grpc_tcp_client_connect( + &done, &g_connecting, g_pollset_set, args, &resolved_addr, + grpc_core::Timestamp::InfFuture()); + grpc_channel_args_destroy(args); + GPR_ASSERT(connection_handle > 0); + GPR_ASSERT(grpc_tcp_client_cancel_connect(connection_handle) == true); + close(svr_fd); + gpr_log(GPR_ERROR, "---- finished test_connect_cancellation_succeeds() ----"); +} + void test_fails_bad_addr_no_leak(void) { gpr_log(GPR_ERROR, "---- starting test_fails_bad_addr_no_leak() ----"); grpc_resolved_address resolved_addr; @@ -250,6 +299,7 @@ int main(int argc, char** argv) { grpc_pollset_set_add_pollset(g_pollset_set, g_pollset); test_succeeds(); + test_connect_cancellation_succeeds(); test_fails(); test_fails_bad_addr_no_leak(); grpc_pollset_set_destroy(g_pollset_set);