[security-handshaker] Simplify refcounting (#37345)

Make the refcounting in the class a little less manual

Closes #37345

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/37345 from ctiller:things-that-make-you-go-hmmm 67f23b6666
PiperOrigin-RevId: 660022927
pull/37358/head
Craig Tiller 7 months ago committed by Copybara-Service
parent 3de09c544d
commit 255bc5cb8b
  1. 131
      src/core/handshaker/security/security_handshaker.cc

@ -88,27 +88,27 @@ class SecurityHandshaker : public Handshaker {
private:
grpc_error_handle DoHandshakerNextLocked(const unsigned char* bytes_received,
size_t bytes_received_size);
size_t bytes_received_size)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
grpc_error_handle OnHandshakeNextDoneLocked(
tsi_result result, const unsigned char* bytes_to_send,
size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
void HandshakeFailedLocked(absl::Status error);
size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
void HandshakeFailedLocked(absl::Status error)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
void Finish(absl::Status status);
void OnHandshakeDataReceivedFromPeerFn(absl::Status error);
void OnHandshakeDataSentToPeerFn(absl::Status error);
static void OnHandshakeDataReceivedFromPeerFnScheduler(
void* arg, grpc_error_handle error);
static void OnHandshakeDataSentToPeerFnScheduler(void* arg,
grpc_error_handle error);
void OnHandshakeDataReceivedFromPeerFnScheduler(grpc_error_handle error);
void OnHandshakeDataSentToPeerFnScheduler(grpc_error_handle error);
static void OnHandshakeNextDoneGrpcWrapper(
tsi_result result, void* user_data, const unsigned char* bytes_to_send,
size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
static void OnPeerCheckedFn(void* arg, grpc_error_handle error);
void OnPeerCheckedInner(grpc_error_handle error);
void OnPeerCheckedFn(grpc_error_handle error);
size_t MoveReadBufferIntoHandshakeBuffer();
grpc_error_handle CheckPeerLocked();
grpc_error_handle CheckPeerLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// State set at creation time.
tsi_handshaker* handshaker_;
@ -125,13 +125,11 @@ class SecurityHandshaker : public Handshaker {
size_t handshake_buffer_size_;
unsigned char* handshake_buffer_;
SliceBuffer outgoing_;
grpc_closure on_handshake_data_sent_to_peer_;
grpc_closure on_handshake_data_received_from_peer_;
grpc_closure on_peer_checked_;
RefCountedPtr<grpc_auth_context> auth_context_;
tsi_handshaker_result* handshaker_result_ = nullptr;
size_t max_frame_size_ = 0;
std::string tsi_handshake_error_;
grpc_closure* on_peer_checked_ ABSL_GUARDED_BY(mu_) = nullptr;
};
SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
@ -143,10 +141,7 @@ SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
handshake_buffer_(
static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))),
max_frame_size_(
std::max(0, args.GetInt(GRPC_ARG_TSI_MAX_FRAME_SIZE).value_or(0))) {
GRPC_CLOSURE_INIT(&on_peer_checked_, &SecurityHandshaker::OnPeerCheckedFn,
this, grpc_schedule_on_exec_ctx);
}
std::max(0, args.GetInt(GRPC_ARG_TSI_MAX_FRAME_SIZE).value_or(0))) {}
SecurityHandshaker::~SecurityHandshaker() {
tsi_handshaker_destroy(handshaker_);
@ -220,8 +215,9 @@ MakeChannelzSecurityFromAuthContext(grpc_auth_context* auth_context) {
} // namespace
void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) {
void SecurityHandshaker::OnPeerCheckedFn(grpc_error_handle error) {
MutexLock lock(&mu_);
on_peer_checked_ = nullptr;
if (!error.ok() || is_shutdown_) {
HandshakeFailedLocked(error);
return;
@ -317,11 +313,6 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) {
Finish(absl::OkStatus());
}
void SecurityHandshaker::OnPeerCheckedFn(void* arg, grpc_error_handle error) {
RefCountedPtr<SecurityHandshaker>(static_cast<SecurityHandshaker*>(arg))
->OnPeerCheckedInner(error);
}
grpc_error_handle SecurityHandshaker::CheckPeerLocked() {
tsi_peer peer;
tsi_result result =
@ -330,8 +321,12 @@ grpc_error_handle SecurityHandshaker::CheckPeerLocked() {
return GRPC_ERROR_CREATE(absl::StrCat("Peer extraction failed (",
tsi_result_to_string(result), ")"));
}
on_peer_checked_ = NewClosure(
[self = RefAsSubclass<SecurityHandshaker>()](absl::Status status) {
self->OnPeerCheckedFn(std::move(status));
});
connector_->check_peer(peer, args_->endpoint.get(), args_->args,
&auth_context_, &on_peer_checked_);
&auth_context_, on_peer_checked_);
grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name(
auth_context_.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME);
const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it);
@ -356,10 +351,10 @@ grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked(
CHECK_EQ(bytes_to_send_size, 0u);
grpc_endpoint_read(
args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
GRPC_CLOSURE_INIT(
&on_handshake_data_received_from_peer_,
&SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
this, grpc_schedule_on_exec_ctx),
NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
absl::Status status) {
self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
}),
/*urgent=*/true, /*min_progress_size=*/1);
return error;
}
@ -387,19 +382,19 @@ grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked(
reinterpret_cast<const char*>(bytes_to_send), bytes_to_send_size));
grpc_endpoint_write(
args_->endpoint.get(), outgoing_.c_slice_buffer(),
GRPC_CLOSURE_INIT(
&on_handshake_data_sent_to_peer_,
&SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler, this,
grpc_schedule_on_exec_ctx),
NewClosure(
[self = RefAsSubclass<SecurityHandshaker>()](absl::Status status) {
self->OnHandshakeDataSentToPeerFnScheduler(std::move(status));
}),
nullptr, /*max_frame_size=*/INT_MAX);
} else if (handshaker_result == nullptr) {
// There is nothing to send, but need to read from peer.
grpc_endpoint_read(
args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
GRPC_CLOSURE_INIT(
&on_handshake_data_received_from_peer_,
&SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
this, grpc_schedule_on_exec_ctx),
NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
absl::Status status) {
self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
}),
/*urgent=*/true, /*min_progress_size=*/1);
} else {
// Handshake has finished, check peer and so on.
@ -418,8 +413,6 @@ void SecurityHandshaker::OnHandshakeNextDoneGrpcWrapper(
result, bytes_to_send, bytes_to_send_size, handshaker_result);
if (!error.ok()) {
h->HandshakeFailedLocked(std::move(error));
} else {
h.release(); // Avoid unref
}
}
@ -429,13 +422,15 @@ grpc_error_handle SecurityHandshaker::DoHandshakerNextLocked(
const unsigned char* bytes_to_send = nullptr;
size_t bytes_to_send_size = 0;
tsi_handshaker_result* hs_result = nullptr;
auto self = RefAsSubclass<SecurityHandshaker>();
tsi_result result = tsi_handshaker_next(
handshaker_, bytes_received, bytes_received_size, &bytes_to_send,
&bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper, this,
&tsi_handshake_error_);
&bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper,
self.get(), &tsi_handshake_error_);
if (result == TSI_ASYNC) {
// Handshaker operating asynchronously. Nothing else to do here;
// callback will be invoked in a TSI thread.
// Handshaker operating asynchronously. Callback will be invoked in a TSI
// thread. We no longer own the ref held in self.
self.release();
return absl::OkStatus();
}
// Handshaker returned synchronously. Invoke callback directly in
@ -449,18 +444,18 @@ grpc_error_handle SecurityHandshaker::DoHandshakerNextLocked(
// TODO(roth): This will no longer be necessary once we migrate to the
// EventEngine endpoint API.
void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler(
void* arg, grpc_error_handle error) {
SecurityHandshaker* handshaker = static_cast<SecurityHandshaker*>(arg);
handshaker->args_->event_engine->Run(
[handshaker, error = std::move(error)]() mutable {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
handshaker->OnHandshakeDataReceivedFromPeerFn(std::move(error));
});
grpc_error_handle error) {
args_->event_engine->Run([self = RefAsSubclass<SecurityHandshaker>(),
error = std::move(error)]() mutable {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
self->OnHandshakeDataReceivedFromPeerFn(std::move(error));
// Avoid destruction outside of an ExecCtx (since this is non-cancelable).
self.reset();
});
}
void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(absl::Status error) {
RefCountedPtr<SecurityHandshaker> handshaker(this);
MutexLock lock(&mu_);
if (!error.ok() || is_shutdown_) {
HandshakeFailedLocked(
@ -473,8 +468,6 @@ void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(absl::Status error) {
error = DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
if (!error.ok()) {
HandshakeFailedLocked(std::move(error));
} else {
handshaker.release(); // Avoid unref
}
}
@ -483,18 +476,18 @@ void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(absl::Status error) {
// TODO(roth): This will no longer be necessary once we migrate to the
// EventEngine endpoint API.
void SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler(
void* arg, grpc_error_handle error) {
SecurityHandshaker* handshaker = static_cast<SecurityHandshaker*>(arg);
handshaker->args_->event_engine->Run(
[handshaker, error = std::move(error)]() mutable {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
handshaker->OnHandshakeDataSentToPeerFn(std::move(error));
});
grpc_error_handle error) {
args_->event_engine->Run([self = RefAsSubclass<SecurityHandshaker>(),
error = std::move(error)]() mutable {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
self->OnHandshakeDataSentToPeerFn(std::move(error));
// Avoid destruction outside of an ExecCtx (since this is non-cancelable).
self.reset();
});
}
void SecurityHandshaker::OnHandshakeDataSentToPeerFn(absl::Status error) {
RefCountedPtr<SecurityHandshaker> handshaker(this);
MutexLock lock(&mu_);
if (!error.ok() || is_shutdown_) {
HandshakeFailedLocked(
@ -505,10 +498,10 @@ void SecurityHandshaker::OnHandshakeDataSentToPeerFn(absl::Status error) {
if (handshaker_result_ == nullptr) {
grpc_endpoint_read(
args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
GRPC_CLOSURE_INIT(
&on_handshake_data_received_from_peer_,
&SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
this, grpc_schedule_on_exec_ctx),
NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
absl::Status status) {
self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
}),
/*urgent=*/true, /*min_progress_size=*/1);
} else {
error = CheckPeerLocked();
@ -517,7 +510,6 @@ void SecurityHandshaker::OnHandshakeDataSentToPeerFn(absl::Status error) {
return;
}
}
handshaker.release(); // Avoid unref
}
//
@ -528,7 +520,7 @@ void SecurityHandshaker::Shutdown(grpc_error_handle error) {
MutexLock lock(&mu_);
if (!is_shutdown_) {
is_shutdown_ = true;
connector_->cancel_check_peer(&on_peer_checked_, std::move(error));
connector_->cancel_check_peer(on_peer_checked_, std::move(error));
tsi_handshaker_shutdown(handshaker_);
args_->endpoint.reset();
}
@ -537,7 +529,6 @@ void SecurityHandshaker::Shutdown(grpc_error_handle error) {
void SecurityHandshaker::DoHandshake(
HandshakerArgs* args,
absl::AnyInvocable<void(absl::Status)> on_handshake_done) {
auto ref = Ref();
MutexLock lock(&mu_);
args_ = args;
on_handshake_done_ = std::move(on_handshake_done);
@ -546,8 +537,6 @@ void SecurityHandshaker::DoHandshake(
DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
if (!error.ok()) {
HandshakeFailedLocked(error);
} else {
ref.release(); // Avoid unref
}
}

Loading…
Cancel
Save