diff --git a/src/core/ext/transport/chttp2/server/chttp2_server.cc b/src/core/ext/transport/chttp2/server/chttp2_server.cc index b206eef59fb..58f755572b5 100644 --- a/src/core/ext/transport/chttp2/server/chttp2_server.cc +++ b/src/core/ext/transport/chttp2/server/chttp2_server.cc @@ -136,15 +136,15 @@ class Chttp2ServerListener : public Server::ListenerInterface { grpc_pollset_set* const interested_parties_; }; - ActiveConnection(RefCountedPtr listener, - grpc_pollset* accepting_pollset, + ActiveConnection(grpc_pollset* accepting_pollset, grpc_tcp_server_acceptor* acceptor, grpc_channel_args* args); ~ActiveConnection() override; void Orphan() override; - void Start(grpc_endpoint* endpoint, grpc_channel_args* args); + void Start(RefCountedPtr listener, + grpc_endpoint* endpoint, grpc_channel_args* args); // Needed to be able to grab an external ref in // Chttp2ServerListener::OnAccept() @@ -153,7 +153,7 @@ class Chttp2ServerListener : public Server::ListenerInterface { private: static void OnClose(void* arg, grpc_error* error); - RefCountedPtr const listener_; + RefCountedPtr listener_; Mutex mu_ ACQUIRED_AFTER(&listener_->mu_); // Set by HandshakingState before the handshaking begins and reset when // handshaking is done. @@ -165,6 +165,9 @@ class Chttp2ServerListener : public Server::ListenerInterface { bool shutdown_ ABSL_GUARDED_BY(&mu_) = false; }; + // To allow access to RefCounted<> like interface. + friend class RefCountedPtr; + // Should only be called once so as to start the TCP server. void StartListening(); @@ -177,6 +180,33 @@ class Chttp2ServerListener : public Server::ListenerInterface { static void DestroyListener(Server* /*server*/, void* arg, grpc_closure* destroy_done); + // The interface required by RefCountedPtr<> has been manually implemented + // here to take a ref on tcp_server_ instead. Note that, the handshaker needs + // tcp_server_ to exist for the lifetime of the handshake since it's needed by + // acceptor. Sharing refs between the listener and tcp_server_ is just an + // optimization to avoid taking additional refs on the listener, since + // TcpServerShutdownComplete already holds a ref to the listener. + void IncrementRefCount() { grpc_tcp_server_ref(tcp_server_); } + void IncrementRefCount(const DebugLocation& /* location */, + const char* /* reason */) { + IncrementRefCount(); + } + + RefCountedPtr Ref() GRPC_MUST_USE_RESULT { + IncrementRefCount(); + return RefCountedPtr(this); + } + RefCountedPtr Ref(const DebugLocation& /* location */, + const char* /* reason */) + GRPC_MUST_USE_RESULT { + return Ref(); + } + + void Unref() { grpc_tcp_server_unref(tcp_server_); } + void Unref(const DebugLocation& /* location */, const char* /* reason */) { + Unref(); + } + Server* const server_; grpc_tcp_server* tcp_server_; grpc_resolved_address resolved_address_; @@ -299,13 +329,16 @@ void Chttp2ServerListener::ActiveConnection::HandshakingState::Orphan() { } void Chttp2ServerListener::ActiveConnection::HandshakingState::Start( - grpc_endpoint* endpoint, - grpc_channel_args* args) ABSL_NO_THREAD_SAFETY_ANALYSIS { + grpc_endpoint* endpoint, grpc_channel_args* args) { Ref().release(); // Held by OnHandshakeDone - // Not acquiring a lock for handshake_mgr_ since it is only reset in - // OnHandshakeDone or on destruction. - handshake_mgr_->DoHandshake(endpoint, args, deadline_, acceptor_, - OnHandshakeDone, this); + RefCountedPtr handshake_mgr; + { + MutexLock lock(&connection_->mu_); + if (handshake_mgr_ == nullptr) return; + handshake_mgr = handshake_mgr_; + } + handshake_mgr->DoHandshake(endpoint, args, deadline_, acceptor_, + OnHandshakeDone, this); } void Chttp2ServerListener::ActiveConnection::HandshakingState::OnTimeout( @@ -452,11 +485,9 @@ void Chttp2ServerListener::ActiveConnection::HandshakingState::OnHandshakeDone( // Chttp2ServerListener::ActiveConnection::ActiveConnection( - RefCountedPtr listener, grpc_pollset* accepting_pollset, grpc_tcp_server_acceptor* acceptor, grpc_channel_args* args) - : listener_(std::move(listener)), - handshaking_state_(MakeOrphanable( + : handshaking_state_(MakeOrphanable( Ref(), accepting_pollset, acceptor, args)) { GRPC_CLOSURE_INIT(&on_close_, ActiveConnection::OnClose, this, grpc_schedule_on_exec_ctx); @@ -488,9 +519,11 @@ void Chttp2ServerListener::ActiveConnection::Orphan() { Unref(); } -void Chttp2ServerListener::ActiveConnection::Start(grpc_endpoint* endpoint, - grpc_channel_args* args) { +void Chttp2ServerListener::ActiveConnection::Start( + RefCountedPtr listener, grpc_endpoint* endpoint, + grpc_channel_args* args) { RefCountedPtr handshaking_state_ref; + listener_ = std::move(listener); { MutexLock lock(&mu_); if (shutdown_) return; @@ -655,11 +688,12 @@ void Chttp2ServerListener::OnAccept(void* arg, grpc_endpoint* tcp, MutexLock lock(&self->channel_args_mu_); args = grpc_channel_args_copy(self->args_); } - auto connection = MakeOrphanable( - self->Ref(), accepting_pollset, acceptor, args); + auto connection = + MakeOrphanable(accepting_pollset, acceptor, args); // Hold a ref to connection to allow starting handshake outside the // critical region RefCountedPtr connection_ref = connection->Ref(); + RefCountedPtr listener_ref; { MutexLock lock(&self->mu_); // Shutdown the the connection if listener's stopped serving. @@ -673,6 +707,12 @@ void Chttp2ServerListener::OnAccept(void* arg, grpc_endpoint* tcp, GPR_ERROR, "Memory quota exhausted, rejecting connection, no handshaking."); } else { + // This ref needs to be taken in the critical region after having made + // sure that the listener has not been Orphaned, so as to avoid + // heap-use-after-free issues where `Ref()` is invoked when the ref of + // tcp_server_ has already reached 0. (Ref() implementation of + // Chttp2ServerListener is grpc_tcp_server_ref().) + listener_ref = self->Ref(); self->connections_.emplace(connection.get(), std::move(connection)); } } @@ -682,7 +722,7 @@ void Chttp2ServerListener::OnAccept(void* arg, grpc_endpoint* tcp, grpc_endpoint_destroy(tcp); gpr_free(acceptor); } else { - connection_ref->Start(tcp, args); + connection_ref->Start(std::move(listener_ref), tcp, args); } grpc_channel_args_destroy(args); } @@ -690,17 +730,9 @@ void Chttp2ServerListener::OnAccept(void* arg, grpc_endpoint* tcp, void Chttp2ServerListener::TcpServerShutdownComplete(void* arg, grpc_error* error) { Chttp2ServerListener* self = static_cast(arg); - std::map> connections; - /* ensure all threads have unlocked */ - { - MutexLock lock(&self->mu_); - self->is_serving_ = false; - // Orphan the connections so that they can start cleaning up. - connections = std::move(self->connections_); - self->channelz_listen_socket_.reset(); - } + self->channelz_listen_socket_.reset(); GRPC_ERROR_UNREF(error); - self->Unref(); + delete self; } /* Server callback: destroy the tcp listener (so we don't generate further @@ -711,10 +743,14 @@ void Chttp2ServerListener::Orphan() { if (config_fetcher_watcher_ != nullptr) { server_->config_fetcher()->CancelWatch(config_fetcher_watcher_); } + std::map> connections; grpc_tcp_server* tcp_server; { MutexLock lock(&mu_); shutdown_ = true; + is_serving_ = false; + // Orphan the connections so that they can start cleaning up. + connections = std::move(connections_); // If the listener is currently set to be serving but has not been started // yet, it means that `grpc_tcp_server_start` is in progress. Wait for the // operation to finish to avoid causing races. diff --git a/src/core/lib/surface/server.h b/src/core/lib/surface/server.h index bc2f4a6c627..b0f422cf8df 100644 --- a/src/core/lib/surface/server.h +++ b/src/core/lib/surface/server.h @@ -71,7 +71,7 @@ class Server : public InternallyRefCounted { /// Interface for listeners. /// Implementations must override the Orphan() method, which should stop /// listening and initiate destruction of the listener. - class ListenerInterface : public InternallyRefCounted { + class ListenerInterface : public Orphanable { public: ~ListenerInterface() override = default;