[TokenFetcherCredentials] fix backoff behavior (#38004)

As per discussion in https://github.com/grpc/proposal/pull/438#discussion_r1809068625.

Closes #38004

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/38004 from markdroth:token_fetcher_creds_backoff_fix 8de2fecb8f
PiperOrigin-RevId: 690678927
pull/37986/head
Mark D. Roth 4 months ago committed by Copybara-Service
parent 5c31076cc9
commit 88b5c9e3ab
  1. 51
      src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc
  2. 8
      src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h
  3. 102
      test/core/security/credentials_test.cc

@ -58,8 +58,8 @@ void TokenFetcherCredentials::Token::AddTokenToClientInitialMetadata(
// //
TokenFetcherCredentials::FetchState::BackoffTimer::BackoffTimer( TokenFetcherCredentials::FetchState::BackoffTimer::BackoffTimer(
RefCountedPtr<FetchState> fetch_state) RefCountedPtr<FetchState> fetch_state, absl::Status status)
: fetch_state_(std::move(fetch_state)) { : fetch_state_(std::move(fetch_state)), status_(status) {
const Duration delay = fetch_state_->backoff_.NextAttemptDelay(); const Duration delay = fetch_state_->backoff_.NextAttemptDelay();
GRPC_TRACE_LOG(token_fetcher_credentials, INFO) GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << fetch_state_->creds_.get() << "[TokenFetcherCredentials " << fetch_state_->creds_.get()
@ -100,24 +100,13 @@ void TokenFetcherCredentials::FetchState::BackoffTimer::OnTimer() {
<< "[TokenFetcherCredentials " << fetch_state_->creds_.get() << "[TokenFetcherCredentials " << fetch_state_->creds_.get()
<< "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
<< ": backoff timer fired"; << ": backoff timer fired";
if (fetch_state_->queued_calls_.empty()) { auto* self_ptr =
// If there are no pending calls when the timer fires, then orphan absl::get_if<OrphanablePtr<BackoffTimer>>(&fetch_state_->state_);
// the FetchState object. Note that this drops the backoff state, // This condition should always be true, but check to be defensive.
// but that's probably okay, because if we didn't have any pending if (self_ptr != nullptr && self_ptr->get() == this) {
// calls during the backoff period, we probably won't see any // Reset pointer in fetch_state_, so that subsequent RPCs know that
// immediately now either. // we're no longer in backoff and they can trigger a new fetch.
GRPC_TRACE_LOG(token_fetcher_credentials, INFO) self_ptr->reset();
<< "[TokenFetcherCredentials " << fetch_state_->creds_.get()
<< "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
<< ": no pending calls, clearing state";
fetch_state_->creds_->fetch_state_.reset();
} else {
// If there are pending calls, then start a new fetch attempt.
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << fetch_state_->creds_.get()
<< "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
<< ": starting new fetch attempt";
fetch_state_->StartFetchAttempt();
} }
} }
@ -145,6 +134,14 @@ void TokenFetcherCredentials::FetchState::Orphan() {
Unref(); Unref();
} }
// Reports the backoff status of this fetch state.
//
// Returns the status stored in the backoff timer when state_ currently
// holds a non-null BackoffTimer.  In every other case -- state_ holds a
// different alternative, or the timer pointer is null -- returns OK.
absl::Status TokenFetcherCredentials::FetchState::status() const {
  const auto* timer_slot = absl::get_if<OrphanablePtr<BackoffTimer>>(&state_);
  // Only a live (non-null) timer carries an error to report.
  const bool has_live_timer = timer_slot != nullptr && *timer_slot != nullptr;
  if (has_live_timer) {
    return (*timer_slot)->status();
  }
  return absl::OkStatus();
}
void TokenFetcherCredentials::FetchState::StartFetchAttempt() { void TokenFetcherCredentials::FetchState::StartFetchAttempt() {
GRPC_TRACE_LOG(token_fetcher_credentials, INFO) GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << creds_.get() << "[TokenFetcherCredentials " << creds_.get()
@ -182,7 +179,8 @@ void TokenFetcherCredentials::FetchState::TokenFetchComplete(
<< "]: fetch_state=" << this << "]: fetch_state=" << this
<< ": token fetch failed: " << token.status(); << ": token fetch failed: " << token.status();
// If failed, start backoff timer. // If failed, start backoff timer.
state_ = OrphanablePtr<BackoffTimer>(new BackoffTimer(Ref())); state_ =
OrphanablePtr<BackoffTimer>(new BackoffTimer(Ref(), token.status()));
} }
ResumeQueuedCalls(std::move(token)); ResumeQueuedCalls(std::move(token));
} }
@ -204,7 +202,6 @@ void TokenFetcherCredentials::FetchState::ResumeQueuedCalls(
RefCountedPtr<TokenFetcherCredentials::QueuedCall> RefCountedPtr<TokenFetcherCredentials::QueuedCall>
TokenFetcherCredentials::FetchState::QueueCall( TokenFetcherCredentials::FetchState::QueueCall(
ClientMetadataHandle initial_metadata) { ClientMetadataHandle initial_metadata) {
// Add call to pending list.
auto queued_call = MakeRefCounted<QueuedCall>(); auto queued_call = MakeRefCounted<QueuedCall>();
queued_call->waker = GetContext<Activity>()->MakeNonOwningWaker(); queued_call->waker = GetContext<Activity>()->MakeNonOwningWaker();
queued_call->pollent = GetContext<grpc_polling_entity>(); queued_call->pollent = GetContext<grpc_polling_entity>();
@ -212,6 +209,11 @@ TokenFetcherCredentials::FetchState::QueueCall(
queued_call->pollent, grpc_polling_entity_pollset_set(&creds_->pollent_)); queued_call->pollent, grpc_polling_entity_pollset_set(&creds_->pollent_));
queued_call->md = std::move(initial_metadata); queued_call->md = std::move(initial_metadata);
queued_calls_.insert(queued_call); queued_calls_.insert(queued_call);
// If backoff has expired since the last attempt, trigger a new one.
auto* backoff_ptr = absl::get_if<OrphanablePtr<BackoffTimer>>(&state_);
if (backoff_ptr != nullptr && backoff_ptr->get() == nullptr) {
StartFetchAttempt();
}
return queued_call; return queued_call;
} }
@ -267,6 +269,11 @@ TokenFetcherCredentials::GetRequestMetadata(
token_->AddTokenToClientInitialMetadata(*initial_metadata); token_->AddTokenToClientInitialMetadata(*initial_metadata);
return Immediate(std::move(initial_metadata)); return Immediate(std::move(initial_metadata));
} }
// If we're in backoff, fail the call.
if (fetch_state_ != nullptr) {
absl::Status status = fetch_state_->status();
if (!status.ok()) return Immediate(std::move(status));
}
// If we don't have a cached token, this call will need to be queued. // If we don't have a cached token, this call will need to be queued.
GRPC_TRACE_LOG(token_fetcher_credentials, INFO) GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << this << "[TokenFetcherCredentials " << this

@ -111,13 +111,16 @@ class TokenFetcherCredentials : public grpc_call_credentials {
// annotations. // annotations.
void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS; void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS;
// Returns non-OK when we're in backoff.
absl::Status status() const;
RefCountedPtr<QueuedCall> QueueCall(ClientMetadataHandle initial_metadata) RefCountedPtr<QueuedCall> QueueCall(ClientMetadataHandle initial_metadata)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_);
private: private:
class BackoffTimer : public InternallyRefCounted<BackoffTimer> { class BackoffTimer : public InternallyRefCounted<BackoffTimer> {
public: public:
explicit BackoffTimer(RefCountedPtr<FetchState> fetch_state) BackoffTimer(RefCountedPtr<FetchState> fetch_state, absl::Status status)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_);
// Disabling thread safety annotations, since Orphan() is called // Disabling thread safety annotations, since Orphan() is called
@ -125,10 +128,13 @@ class TokenFetcherCredentials : public grpc_call_credentials {
// annotations. // annotations.
void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS; void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS;
absl::Status status() const { return status_; }
private: private:
void OnTimer(); void OnTimer();
RefCountedPtr<FetchState> fetch_state_; RefCountedPtr<FetchState> fetch_state_;
const absl::Status status_;
absl::optional<grpc_event_engine::experimental::EventEngine::TaskHandle> absl::optional<grpc_event_engine::experimental::EventEngine::TaskHandle>
timer_handle_ ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_); timer_handle_ ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_);
}; };

@ -2569,8 +2569,9 @@ TEST_F(TokenFetcherCredentialsTest, FetchFails) {
run_after_duration = duration; run_after_duration = duration;
}); });
ExecCtx exec_ctx; ExecCtx exec_ctx;
creds_->AddResult(kExpectedError);
// First request will trigger a fetch, which will fail. // First request will trigger a fetch, which will fail.
LOG(INFO) << "Sending first RPC.";
creds_->AddResult(kExpectedError);
auto state = RequestMetadataState::NewInstance(kExpectedError, "", auto state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true); /*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
@ -2578,22 +2579,25 @@ TEST_F(TokenFetcherCredentialsTest, FetchFails) {
EXPECT_EQ(creds_->num_fetches(), 1); EXPECT_EQ(creds_->num_fetches(), 1);
while (!run_after_duration.has_value()) event_engine_->Tick(); while (!run_after_duration.has_value()) event_engine_->Tick();
// Make sure backoff was set for the right period. // Make sure backoff was set for the right period.
// This is 1 second (initial backoff) minus 1ms for the tick needed above.
EXPECT_EQ(run_after_duration, std::chrono::seconds(1)); EXPECT_EQ(run_after_duration, std::chrono::seconds(1));
run_after_duration.reset(); run_after_duration.reset();
// Start a new call now, which will be queued and then eventually // Start a new call now, which will fail because we're in backoff.
// resumed when the next fetch happens. LOG(INFO) << "Sending second RPC.";
state = RequestMetadataState::NewInstance( state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); kExpectedError, "authorization: foo", /*expect_delay=*/false);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath); kTestPath);
// Tick until the next fetch starts. EXPECT_EQ(creds_->num_fetches(), 1);
creds_->AddResult(MakeToken("foo")); // Tick until backoff expires.
LOG(INFO) << "Waiting for backoff.";
event_engine_->TickUntilIdle(); event_engine_->TickUntilIdle();
EXPECT_EQ(creds_->num_fetches(), 2); EXPECT_EQ(creds_->num_fetches(), 1);
// A call started now should use the new cached data. // Starting another call should trigger a new fetch, which will
// succeed this time.
LOG(INFO) << "Sending third RPC.";
creds_->AddResult(MakeToken("foo"));
state = RequestMetadataState::NewInstance( state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: foo", /*expect_delay=*/false); absl::OkStatus(), "authorization: foo", /*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath); kTestPath);
EXPECT_EQ(creds_->num_fetches(), 2); EXPECT_EQ(creds_->num_fetches(), 2);
@ -2607,8 +2611,9 @@ TEST_F(TokenFetcherCredentialsTest, Backoff) {
run_after_duration = duration; run_after_duration = duration;
}); });
ExecCtx exec_ctx; ExecCtx exec_ctx;
creds_->AddResult(kExpectedError);
// First request will trigger a fetch, which will fail. // First request will trigger a fetch, which will fail.
LOG(INFO) << "Sending first RPC.";
creds_->AddResult(kExpectedError);
auto state = RequestMetadataState::NewInstance(kExpectedError, "", auto state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true); /*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
@ -2618,64 +2623,53 @@ TEST_F(TokenFetcherCredentialsTest, Backoff) {
// Make sure backoff was set for the right period. // Make sure backoff was set for the right period.
EXPECT_EQ(run_after_duration, std::chrono::seconds(1)); EXPECT_EQ(run_after_duration, std::chrono::seconds(1));
run_after_duration.reset(); run_after_duration.reset();
// Start a new call now, which will be queued and then eventually // Start a new call now, which will fail because we're in backoff.
// resumed when the next fetch happens. LOG(INFO) << "Sending second RPC.";
state = RequestMetadataState::NewInstance(kExpectedError, "", state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true); /*expect_delay=*/false);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath); kTestPath);
// Tick until the next fetch fails and the backoff timer starts again. EXPECT_EQ(creds_->num_fetches(), 1);
// Tick until backoff expires.
LOG(INFO) << "Waiting for backoff.";
event_engine_->TickUntilIdle();
EXPECT_EQ(creds_->num_fetches(), 1);
// Starting another call should trigger a new fetch, which will again fail.
LOG(INFO) << "Sending third RPC.";
creds_->AddResult(kExpectedError); creds_->AddResult(kExpectedError);
while (!run_after_duration.has_value()) event_engine_->Tick();
EXPECT_EQ(creds_->num_fetches(), 2);
// The backoff time should be longer now. We account for jitter here.
EXPECT_EQ(run_after_duration, std::chrono::milliseconds(1600))
<< "actual: " << run_after_duration->count();
run_after_duration.reset();
// Start another new call to trigger another new fetch once the
// backoff expires.
state = RequestMetadataState::NewInstance(kExpectedError, "", state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true); /*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath); kTestPath);
// Tick until the next fetch starts. EXPECT_EQ(creds_->num_fetches(), 2);
creds_->AddResult(kExpectedError);
while (!run_after_duration.has_value()) event_engine_->Tick(); while (!run_after_duration.has_value()) event_engine_->Tick();
EXPECT_EQ(creds_->num_fetches(), 3); // The backoff time should be longer now.
// Check backoff time again. EXPECT_EQ(run_after_duration, std::chrono::milliseconds(1600))
EXPECT_EQ(run_after_duration, std::chrono::milliseconds(2560))
<< "actual: " << run_after_duration->count(); << "actual: " << run_after_duration->count();
} run_after_duration.reset();
// Start a new call now, which will fail because we're in backoff.
TEST_F(TokenFetcherCredentialsTest, FetchNotStartedAfterBackoffWithoutRpc) { LOG(INFO) << "Sending fourth RPC.";
const absl::Status kExpectedError = absl::UnavailableError("bummer, dude"); state = RequestMetadataState::NewInstance(kExpectedError, "",
absl::optional<FuzzingEventEngine::Duration> run_after_duration; /*expect_delay=*/false);
event_engine_->SetRunAfterDurationCallback(
[&](FuzzingEventEngine::Duration duration) {
run_after_duration = duration;
});
ExecCtx exec_ctx;
creds_->AddResult(kExpectedError);
// First request will trigger a fetch, which will fail.
auto state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath); kTestPath);
EXPECT_EQ(creds_->num_fetches(), 1); EXPECT_EQ(creds_->num_fetches(), 2);
while (!run_after_duration.has_value()) event_engine_->Tick(); // Tick until backoff expires.
// Make sure backoff was set for the right period. LOG(INFO) << "Waiting for backoff.";
EXPECT_EQ(run_after_duration, std::chrono::seconds(1));
run_after_duration.reset();
// Tick until the backoff expires. No new fetch should be started.
event_engine_->TickUntilIdle(); event_engine_->TickUntilIdle();
EXPECT_EQ(creds_->num_fetches(), 1); EXPECT_EQ(creds_->num_fetches(), 2);
// Now start a new request, which will trigger a new fetch. // Starting another call should trigger a new fetch, which will again fail.
creds_->AddResult(MakeToken("foo")); LOG(INFO) << "Sending fifth RPC.";
state = RequestMetadataState::NewInstance( creds_->AddResult(kExpectedError);
absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath); kTestPath);
EXPECT_EQ(creds_->num_fetches(), 2); EXPECT_EQ(creds_->num_fetches(), 3);
while (!run_after_duration.has_value()) event_engine_->Tick();
// The backoff time should be longer now.
EXPECT_EQ(run_after_duration, std::chrono::milliseconds(2560))
<< "actual: " << run_after_duration->count();
} }
TEST_F(TokenFetcherCredentialsTest, ShutdownWhileBackoffTimerPending) { TEST_F(TokenFetcherCredentialsTest, ShutdownWhileBackoffTimerPending) {

Loading…
Cancel
Save