Revert "[TokenFetcherCredentials] add backoff and pre-fetching (#37531)"

This reverts commit 158976a963.
pull/37567/head
Mark D. Roth 6 months ago
parent 66bb0ce119
commit 3701b84890
  1. 1
      doc/trace_flags.md
  2. 4
      src/core/BUILD
  3. 2
      src/core/lib/debug/trace_flags.cc
  4. 1
      src/core/lib/debug/trace_flags.h
  5. 3
      src/core/lib/debug/trace_flags.yaml
  6. 2
      src/core/lib/promise/map.h
  7. 6
      src/core/lib/security/credentials/external/external_account_credentials.cc
  8. 5
      src/core/lib/security/credentials/external/external_account_credentials.h
  9. 1
      src/core/lib/security/credentials/external/file_external_account_credentials.cc
  10. 284
      src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc
  11. 83
      src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h
  12. 335
      test/core/security/credentials_test.cc

1
doc/trace_flags.md generated

@ -73,7 +73,6 @@ processing requests via debug logs. Available tracers include:
- tcp - Bytes in and out of a channel.
- timer - Timers (alarms) in the grpc internals.
- timer_check - more detailed trace of timer logic in grpc internals.
- token_fetcher_credentials - Token fetcher call credentials framework, used for (e.g.) oauth2 token fetcher credentials.
- tsi - TSI transport security.
- weighted_round_robin_lb - Weighted round robin load balancing policy.
- weighted_target_lb - Weighted target LB policy.

@ -4335,17 +4335,14 @@ grpc_cc_library(
deps = [
"arena_promise",
"context",
"default_event_engine",
"metadata",
"poll",
"pollset_set",
"ref_counted",
"time",
"useful",
"//:backoff",
"//:gpr",
"//:grpc_security_base",
"//:grpc_trace",
"//:httpcli",
"//:iomgr",
"//:orphanable",
@ -4439,6 +4436,7 @@ grpc_cc_library(
language = "c++",
deps = [
"closure",
"default_event_engine",
"env",
"error",
"error_utils",

@ -114,7 +114,6 @@ TraceFlag subchannel_pool_trace(false, "subchannel_pool");
TraceFlag tcp_trace(false, "tcp");
TraceFlag timer_trace(false, "timer");
TraceFlag timer_check_trace(false, "timer_check");
TraceFlag token_fetcher_credentials_trace(false, "token_fetcher_credentials");
TraceFlag tsi_trace(false, "tsi");
TraceFlag weighted_round_robin_lb_trace(false, "weighted_round_robin_lb");
TraceFlag weighted_target_lb_trace(false, "weighted_target_lb");
@ -207,7 +206,6 @@ const absl::flat_hash_map<std::string, TraceFlag*>& GetAllTraceFlags() {
{"tcp", &tcp_trace},
{"timer", &timer_trace},
{"timer_check", &timer_check_trace},
{"token_fetcher_credentials", &token_fetcher_credentials_trace},
{"tsi", &tsi_trace},
{"weighted_round_robin_lb", &weighted_round_robin_lb_trace},
{"weighted_target_lb", &weighted_target_lb_trace},

@ -112,7 +112,6 @@ extern TraceFlag subchannel_pool_trace;
extern TraceFlag tcp_trace;
extern TraceFlag timer_trace;
extern TraceFlag timer_check_trace;
extern TraceFlag token_fetcher_credentials_trace;
extern TraceFlag tsi_trace;
extern TraceFlag weighted_round_robin_lb_trace;
extern TraceFlag weighted_target_lb_trace;

@ -308,9 +308,6 @@ timer:
timer_check:
default: false
description: more detailed trace of timer logic in grpc internals.
token_fetcher_credentials:
default: false
description: Token fetcher call credentials framework, used for (e.g.) oauth2 token fetcher credentials.
tsi:
default: false
description: TSI transport security.

@ -86,7 +86,7 @@ GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION auto CheckDelayed(Promise promise) {
delayed = true;
return Pending{};
}
return std::make_tuple(std::move(r.value()), delayed);
return std::make_tuple(r.value(), delayed);
};
}

@ -45,6 +45,7 @@
#include <grpc/support/port_platform.h>
#include <grpc/support/string_util.h>
#include "src/core/lib/event_engine/default_event_engine.h"
#include "src/core/lib/gprpp/status_helper.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/credentials/external/aws_external_account_credentials.h"
@ -590,7 +591,10 @@ ExternalAccountCredentials::Create(
ExternalAccountCredentials::ExternalAccountCredentials(
Options options, std::vector<std::string> scopes,
std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine)
: TokenFetcherCredentials(std::move(event_engine)),
: event_engine_(
event_engine == nullptr
? grpc_event_engine::experimental::GetDefaultEventEngine()
: std::move(event_engine)),
options_(std::move(options)) {
if (scopes.empty()) {
scopes.push_back(GOOGLE_CLOUD_PLATFORM_DEFAULT_SCOPE);

@ -185,6 +185,10 @@ class ExternalAccountCredentials : public TokenFetcherCredentials {
absl::string_view audience() const { return options_.audience; }
grpc_event_engine::experimental::EventEngine& event_engine() const {
return *event_engine_;
}
private:
OrphanablePtr<FetchRequest> FetchToken(
Timestamp deadline,
@ -200,6 +204,7 @@ class ExternalAccountCredentials : public TokenFetcherCredentials {
Timestamp deadline,
absl::AnyInvocable<void(absl::StatusOr<std::string>)> on_done) = 0;
std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_;
Options options_;
std::vector<std::string> scopes_;
};

@ -26,6 +26,7 @@
#include <grpc/support/json.h>
#include <grpc/support/port_platform.h>
#include "src/core/lib/event_engine/default_event_engine.h"
#include "src/core/lib/gprpp/load_file.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/slice/slice_internal.h"

@ -18,8 +18,6 @@
#include "src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h"
#include "src/core/lib/debug/trace.h"
#include "src/core/lib/event_engine/default_event_engine.h"
#include "src/core/lib/iomgr/pollset_set.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/poll.h"
@ -30,11 +28,8 @@ namespace grpc_core {
namespace {
// Amount of time before the token's expiration that we consider it
// invalid to account for server processing time and clock skew.
constexpr Duration kTokenExpirationAdjustmentDuration = Duration::Seconds(30);
// Amount of time before the token's expiration that we pre-fetch a new
// token. Also determines the timeout for the fetch request.
// invalid and start a new fetch. Also determines the timeout for the
// fetch request.
constexpr Duration kTokenRefreshDuration = Duration::Seconds(60);
} // namespace
@ -43,193 +38,18 @@ constexpr Duration kTokenRefreshDuration = Duration::Seconds(60);
// TokenFetcherCredentials::Token
//
TokenFetcherCredentials::Token::Token(Slice token, Timestamp expiration)
: token_(std::move(token)),
expiration_(expiration - kTokenExpirationAdjustmentDuration) {}
void TokenFetcherCredentials::Token::AddTokenToClientInitialMetadata(
ClientMetadata& metadata) const {
metadata.Append(GRPC_AUTHORIZATION_METADATA_KEY, token_.Ref(),
[](absl::string_view, const Slice&) { abort(); });
}
//
// TokenFetcherCredentials::FetchState::BackoffTimer
//
TokenFetcherCredentials::FetchState::BackoffTimer::BackoffTimer(
RefCountedPtr<FetchState> fetch_state)
: fetch_state_(std::move(fetch_state)) {
const Timestamp next_attempt_time = fetch_state_->backoff_.NextAttemptTime();
const Duration duration = next_attempt_time - Timestamp::Now();
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << fetch_state_->creds_.get()
<< "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
<< ": starting backoff timer for " << next_attempt_time << " ("
<< duration << " from now)";
timer_handle_ = fetch_state_->creds_->event_engine().RunAfter(
duration, [self = Ref()]() mutable {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
self->OnTimer();
self.reset();
});
}
void TokenFetcherCredentials::FetchState::BackoffTimer::Orphan() {
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << fetch_state_->creds_.get()
<< "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
<< ": backoff timer shut down";
if (timer_handle_.has_value()) {
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << fetch_state_->creds_.get()
<< "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
<< ": cancelling timer";
fetch_state_->creds_->event_engine().Cancel(*timer_handle_);
timer_handle_.reset();
fetch_state_->ResumeQueuedCalls(
absl::CancelledError("credentials shutdown"));
}
Unref();
}
void TokenFetcherCredentials::FetchState::BackoffTimer::OnTimer() {
MutexLock lock(&fetch_state_->creds_->mu_);
if (!timer_handle_.has_value()) return;
timer_handle_.reset();
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << fetch_state_->creds_.get()
<< "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
<< ": backoff timer fired";
if (fetch_state_->queued_calls_.empty()) {
// If there are no pending calls when the timer fires, then orphan
// the FetchState object. Note that this drops the backoff state,
// but that's probably okay, because if we didn't have any pending
// calls during the backoff period, we probably won't see any
// immediately now either.
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << fetch_state_->creds_.get()
<< "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
<< ": no pending calls, clearing state";
fetch_state_->creds_->fetch_state_.reset();
} else {
// If there are pending calls, then start a new fetch attempt.
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << fetch_state_->creds_.get()
<< "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
<< ": starting new fetch attempt";
fetch_state_->StartFetchAttempt();
}
}
//
// TokenFetcherCredentials::FetchState
//
TokenFetcherCredentials::FetchState::FetchState(
WeakRefCountedPtr<TokenFetcherCredentials> creds)
: creds_(std::move(creds)),
backoff_(BackOff::Options()
.set_initial_backoff(Duration::Seconds(1))
.set_multiplier(1.6)
.set_jitter(creds_->test_only_use_backoff_jitter_ ? 0.2 : 0)
.set_max_backoff(Duration::Seconds(120))) {
StartFetchAttempt();
}
void TokenFetcherCredentials::FetchState::Orphan() {
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << creds_.get()
<< "]: fetch_state=" << this << ": shutting down";
// Cancels fetch or backoff timer, if any.
state_ = Shutdown{};
Unref();
}
void TokenFetcherCredentials::FetchState::StartFetchAttempt() {
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << creds_.get()
<< "]: fetch_state=" << this << ": starting fetch";
state_ = creds_->FetchToken(
/*deadline=*/Timestamp::Now() + kTokenRefreshDuration,
[self = Ref()](absl::StatusOr<RefCountedPtr<Token>> token) mutable {
self->TokenFetchComplete(std::move(token));
});
}
void TokenFetcherCredentials::FetchState::TokenFetchComplete(
absl::StatusOr<RefCountedPtr<Token>> token) {
MutexLock lock(&creds_->mu_);
// If we were shut down, clean up.
if (absl::holds_alternative<Shutdown>(state_)) {
if (token.ok()) token = absl::CancelledError("credentials shutdown");
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << creds_.get()
<< "]: fetch_state=" << this
<< ": shut down before fetch completed: " << token.status();
ResumeQueuedCalls(std::move(token));
return;
}
// If succeeded, update cache in creds object.
if (token.ok()) {
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << creds_.get()
<< "]: fetch_state=" << this << ": token fetch succeeded";
creds_->token_ = *token;
creds_->fetch_state_.reset(); // Orphan ourselves.
} else {
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << creds_.get()
<< "]: fetch_state=" << this
<< ": token fetch failed: " << token.status();
// If failed, start backoff timer.
state_ = OrphanablePtr<BackoffTimer>(new BackoffTimer(Ref()));
}
ResumeQueuedCalls(std::move(token));
}
void TokenFetcherCredentials::FetchState::ResumeQueuedCalls(
absl::StatusOr<RefCountedPtr<Token>> token) {
// Invoke callbacks for all pending requests.
for (auto& queued_call : queued_calls_) {
queued_call->result = token;
queued_call->done.store(true, std::memory_order_release);
queued_call->waker.Wakeup();
grpc_polling_entity_del_from_pollset_set(
queued_call->pollent,
grpc_polling_entity_pollset_set(&creds_->pollent_));
}
queued_calls_.clear();
}
RefCountedPtr<TokenFetcherCredentials::QueuedCall>
TokenFetcherCredentials::FetchState::QueueCall(
ClientMetadataHandle initial_metadata) {
// Add call to pending list.
auto queued_call = MakeRefCounted<QueuedCall>();
queued_call->waker = GetContext<Activity>()->MakeNonOwningWaker();
queued_call->pollent = GetContext<grpc_polling_entity>();
grpc_polling_entity_add_to_pollset_set(
queued_call->pollent, grpc_polling_entity_pollset_set(&creds_->pollent_));
queued_call->md = std::move(initial_metadata);
queued_calls_.insert(queued_call);
return queued_call;
}
//
// TokenFetcherCredentials
//
TokenFetcherCredentials::TokenFetcherCredentials(
std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine,
bool test_only_use_backoff_jitter)
: event_engine_(
event_engine == nullptr
? grpc_event_engine::experimental::GetDefaultEventEngine()
: std::move(event_engine)),
test_only_use_backoff_jitter_(test_only_use_backoff_jitter),
pollent_(grpc_polling_entity_create_from_pollset_set(
TokenFetcherCredentials::TokenFetcherCredentials()
: pollent_(grpc_polling_entity_create_from_pollset_set(
grpc_pollset_set_create())) {}
TokenFetcherCredentials::~TokenFetcherCredentials() {
@ -238,63 +58,73 @@ TokenFetcherCredentials::~TokenFetcherCredentials() {
void TokenFetcherCredentials::Orphaned() {
MutexLock lock(&mu_);
fetch_state_.reset();
auto* fetch_request = absl::get_if<OrphanablePtr<FetchRequest>>(&token_);
if (fetch_request != nullptr) fetch_request->reset();
}
ArenaPromise<absl::StatusOr<ClientMetadataHandle>>
TokenFetcherCredentials::GetRequestMetadata(
ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs*) {
RefCountedPtr<QueuedCall> queued_call;
RefCountedPtr<PendingCall> pending_call;
{
MutexLock lock(&mu_);
// If we don't have a cached token or the token is within the
// refresh duration, start a new fetch if there isn't a pending one.
if ((token_ == nullptr || (token_->ExpirationTime() - Timestamp::Now()) <=
kTokenRefreshDuration) &&
fetch_state_ == nullptr) {
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << this
<< "]: " << GetContext<Activity>()->DebugTag()
<< " triggering new token fetch";
fetch_state_ = OrphanablePtr<FetchState>(
new FetchState(WeakRefAsSubclass<TokenFetcherCredentials>()));
}
// If we have a cached non-expired token, use it.
if (token_ != nullptr &&
(token_->ExpirationTime() - Timestamp::Now()) > Duration::Zero()) {
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << this
<< "]: " << GetContext<Activity>()->DebugTag()
<< " using cached token";
token_->AddTokenToClientInitialMetadata(*initial_metadata);
// Check if we can use the cached token.
auto* cached_token = absl::get_if<RefCountedPtr<Token>>(&token_);
if (cached_token != nullptr && *cached_token != nullptr &&
((*cached_token)->ExpirationTime() - Timestamp::Now()) >
kTokenRefreshDuration) {
(*cached_token)->AddTokenToClientInitialMetadata(*initial_metadata);
return Immediate(std::move(initial_metadata));
}
// If we don't have a cached token, this call will need to be queued.
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << this
<< "]: " << GetContext<Activity>()->DebugTag()
<< " no cached token; queuing call";
queued_call = fetch_state_->QueueCall(std::move(initial_metadata));
// Couldn't get the token from the cache.
// Add this call to the pending list.
pending_call = MakeRefCounted<PendingCall>();
pending_call->waker = GetContext<Activity>()->MakeNonOwningWaker();
pending_call->pollent = GetContext<grpc_polling_entity>();
grpc_polling_entity_add_to_pollset_set(
pending_call->pollent, grpc_polling_entity_pollset_set(&pollent_));
pending_call->md = std::move(initial_metadata);
pending_calls_.insert(pending_call);
// Start a new fetch if needed.
if (!absl::holds_alternative<OrphanablePtr<FetchRequest>>(token_)) {
token_ = FetchToken(
/*deadline=*/Timestamp::Now() + kTokenRefreshDuration,
[self = WeakRefAsSubclass<TokenFetcherCredentials>()](
absl::StatusOr<RefCountedPtr<Token>> token) mutable {
self->TokenFetchComplete(std::move(token));
});
}
}
return [this, queued_call = std::move(queued_call)]()
-> Poll<absl::StatusOr<ClientMetadataHandle>> {
if (!queued_call->done.load(std::memory_order_acquire)) {
return [pending_call = std::move(
pending_call)]() -> Poll<absl::StatusOr<ClientMetadataHandle>> {
if (!pending_call->done.load(std::memory_order_acquire)) {
return Pending{};
}
if (!queued_call->result.ok()) {
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << this
<< "]: " << GetContext<Activity>()->DebugTag()
<< " token fetch failed; failing call";
return queued_call->result.status();
if (!pending_call->result.ok()) {
return pending_call->result.status();
}
GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
<< "[TokenFetcherCredentials " << this
<< "]: " << GetContext<Activity>()->DebugTag()
<< " token fetch complete; resuming call";
(*queued_call->result)->AddTokenToClientInitialMetadata(*queued_call->md);
return std::move(queued_call->md);
(*pending_call->result)->AddTokenToClientInitialMetadata(*pending_call->md);
return std::move(pending_call->md);
};
}
void TokenFetcherCredentials::TokenFetchComplete(
absl::StatusOr<RefCountedPtr<Token>> token) {
// Update cache and grab list of pending requests.
absl::flat_hash_set<RefCountedPtr<PendingCall>> pending_calls;
{
MutexLock lock(&mu_);
token_ = token.value_or(nullptr);
pending_calls_.swap(pending_calls);
}
// Invoke callbacks for all pending requests.
for (auto& pending_call : pending_calls) {
pending_call->result = token;
pending_call->done.store(true, std::memory_order_release);
pending_call->waker.Wakeup();
grpc_polling_entity_del_from_pollset_set(
pending_call->pollent, grpc_polling_entity_pollset_set(&pollent_));
}
}
} // namespace grpc_core

@ -26,9 +26,6 @@
#include "absl/status/statusor.h"
#include "absl/types/variant.h"
#include <grpc/event_engine/event_engine.h>
#include "src/core/lib/backoff/backoff.h"
#include "src/core/lib/gprpp/orphanable.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
@ -50,7 +47,8 @@ class TokenFetcherCredentials : public grpc_call_credentials {
// Represents a token.
class Token : public RefCounted<Token> {
public:
Token(Slice token, Timestamp expiration);
Token(Slice token, Timestamp expiration)
: token_(std::move(token)), expiration_(expiration) {}
// Returns the token's expiration time.
Timestamp ExpirationTime() const { return expiration_; }
@ -75,10 +73,7 @@ class TokenFetcherCredentials : public grpc_call_credentials {
// Base class for fetch requests.
class FetchRequest : public InternallyRefCounted<FetchRequest> {};
explicit TokenFetcherCredentials(
std::shared_ptr<grpc_event_engine::experimental::EventEngine>
event_engine = nullptr,
bool test_only_use_backoff_jitter = true);
TokenFetcherCredentials();
// Fetches a token. The on_done callback will be invoked when complete.
virtual OrphanablePtr<FetchRequest> FetchToken(
@ -86,15 +81,11 @@ class TokenFetcherCredentials : public grpc_call_credentials {
absl::AnyInvocable<void(absl::StatusOr<RefCountedPtr<Token>>)>
on_done) = 0;
grpc_event_engine::experimental::EventEngine& event_engine() const {
return *event_engine_;
}
grpc_polling_entity* pollent() { return &pollent_; }
private:
// A call that is waiting for a token fetch request to complete.
struct QueuedCall : public RefCounted<QueuedCall> {
struct PendingCall : public RefCounted<PendingCall> {
std::atomic<bool> done{false};
Waker waker;
grpc_polling_entity* pollent;
@ -102,72 +93,20 @@ class TokenFetcherCredentials : public grpc_call_credentials {
absl::StatusOr<RefCountedPtr<Token>> result;
};
class FetchState : public InternallyRefCounted<FetchState> {
public:
explicit FetchState(WeakRefCountedPtr<TokenFetcherCredentials> creds)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_);
// Disabling thread safety annotations, since Orphan() is called
// by OrpahanablePtr<>, which does not have the right lock
// annotations.
void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS;
RefCountedPtr<QueuedCall> QueueCall(ClientMetadataHandle initial_metadata)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_);
private:
class BackoffTimer : public InternallyRefCounted<BackoffTimer> {
public:
explicit BackoffTimer(RefCountedPtr<FetchState> fetch_state)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_);
// Disabling thread safety annotations, since Orphan() is called
// by OrpahanablePtr<>, which does not have the right lock
// annotations.
void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS;
private:
void OnTimer();
RefCountedPtr<FetchState> fetch_state_;
absl::optional<grpc_event_engine::experimental::EventEngine::TaskHandle>
timer_handle_ ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_);
};
struct Shutdown {};
void StartFetchAttempt()
ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_);
void TokenFetchComplete(absl::StatusOr<RefCountedPtr<Token>> token);
void ResumeQueuedCalls(absl::StatusOr<RefCountedPtr<Token>> token)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_);
WeakRefCountedPtr<TokenFetcherCredentials> creds_;
// Pending token-fetch request or backoff timer, if any.
absl::variant<OrphanablePtr<FetchRequest>, OrphanablePtr<BackoffTimer>,
Shutdown>
state_ ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_);
// Calls that are queued up waiting for the token.
absl::flat_hash_set<RefCountedPtr<QueuedCall>> queued_calls_
ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_);
// Backoff state.
BackOff backoff_ ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_);
};
int cmp_impl(const grpc_call_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return QsortCompare(static_cast<const grpc_call_credentials*>(this), other);
}
std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_;
const bool test_only_use_backoff_jitter_;
void TokenFetchComplete(absl::StatusOr<RefCountedPtr<Token>> token);
Mutex mu_;
// Cached token, if any.
RefCountedPtr<Token> token_ ABSL_GUARDED_BY(&mu_);
// Fetch state, if any.
OrphanablePtr<FetchState> fetch_state_ ABSL_GUARDED_BY(&mu_);
// Either the cached token or a pending request to fetch the token.
absl::variant<RefCountedPtr<Token>, OrphanablePtr<FetchRequest>> token_
ABSL_GUARDED_BY(&mu_);
// Calls that are queued up waiting for the token.
absl::flat_hash_set<RefCountedPtr<PendingCall>> pending_calls_
ABSL_GUARDED_BY(&mu_);
grpc_polling_entity pollent_ ABSL_GUARDED_BY(&mu_);
};

@ -50,7 +50,6 @@
#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/iomgr/timer_manager.h"
#include "src/core/lib/promise/exec_ctx_wakeup_scheduler.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/security/context/security_context.h"
@ -80,7 +79,6 @@
namespace grpc_core {
using grpc_event_engine::experimental::FuzzingEventEngine;
using internal::grpc_flush_cached_google_default_credentials;
using internal::set_gce_tenancy_checker_for_testing;
@ -428,19 +426,16 @@ TEST_F(CredentialsTest,
class RequestMetadataState : public RefCounted<RequestMetadataState> {
public:
static RefCountedPtr<RequestMetadataState> NewInstance(
grpc_error_handle expected_error, std::string expected,
absl::optional<bool> expect_delay = absl::nullopt) {
grpc_error_handle expected_error, std::string expected) {
return MakeRefCounted<RequestMetadataState>(
expected_error, std::move(expected), expect_delay,
expected_error, std::move(expected),
grpc_polling_entity_create_from_pollset_set(grpc_pollset_set_create()));
}
RequestMetadataState(grpc_error_handle expected_error, std::string expected,
absl::optional<bool> expect_delay,
grpc_polling_entity pollent)
: expected_error_(expected_error),
expected_(std::move(expected)),
expect_delay_(expect_delay),
pollent_(pollent) {}
~RequestMetadataState() override {
@ -458,18 +453,12 @@ class RequestMetadataState : public RefCounted<RequestMetadataState> {
activity_ = MakeActivity(
[this, creds] {
return Seq(
CheckDelayed(creds->GetRequestMetadata(
creds->GetRequestMetadata(
ClientMetadataHandle(&md_, Arena::PooledDeleter(nullptr)),
&get_request_metadata_args_)),
[this](std::tuple<absl::StatusOr<ClientMetadataHandle>, bool>
metadata_and_delayed) {
auto& metadata = std::get<0>(metadata_and_delayed);
const bool delayed = std::get<1>(metadata_and_delayed);
if (expect_delay_.has_value()) {
EXPECT_EQ(delayed, *expect_delay_);
}
&get_request_metadata_args_),
[this](absl::StatusOr<ClientMetadataHandle> metadata) {
if (metadata.ok()) {
EXPECT_EQ(metadata->get(), &md_);
CHECK(metadata->get() == &md_);
}
return metadata.status();
});
@ -534,7 +523,6 @@ class RequestMetadataState : public RefCounted<RequestMetadataState> {
grpc_error_handle expected_error_;
std::string expected_;
absl::optional<bool> expect_delay_;
RefCountedPtr<Arena> arena_ = SimpleArenaAllocator()->MakeArena();
grpc_metadata_batch md_;
grpc_call_credentials::GetRequestMetadataArgs get_request_metadata_args_;
@ -2381,315 +2369,6 @@ int aws_external_account_creds_httpcli_post_success(
return 1;
}
class TokenFetcherCredentialsTest : public ::testing::Test {
protected:
class TestTokenFetcherCredentials final : public TokenFetcherCredentials {
public:
explicit TestTokenFetcherCredentials(
std::shared_ptr<grpc_event_engine::experimental::EventEngine>
event_engine = nullptr)
: TokenFetcherCredentials(std::move(event_engine),
/*test_only_use_backoff_jitter=*/false) {}
~TestTokenFetcherCredentials() override { CHECK_EQ(queue_.size(), 0); }
void AddResult(absl::StatusOr<RefCountedPtr<Token>> result) {
MutexLock lock(&mu_);
queue_.push_front(std::move(result));
}
size_t num_fetches() const { return num_fetches_; }
private:
class TestFetchRequest final : public FetchRequest {
public:
TestFetchRequest(
grpc_event_engine::experimental::EventEngine& event_engine,
absl::AnyInvocable<void(absl::StatusOr<RefCountedPtr<Token>>)>
on_done,
absl::StatusOr<RefCountedPtr<Token>> result) {
event_engine.Run([on_done = std::move(on_done),
result = std::move(result)]() mutable {
ApplicationCallbackExecCtx application_exec_ctx;
ExecCtx exec_ctx;
std::exchange(on_done, nullptr)(std::move(result));
});
}
void Orphan() override { Unref(); }
};
OrphanablePtr<FetchRequest> FetchToken(
Timestamp deadline,
absl::AnyInvocable<void(absl::StatusOr<RefCountedPtr<Token>>)> on_done)
override {
absl::StatusOr<RefCountedPtr<Token>> result;
{
MutexLock lock(&mu_);
CHECK(!queue_.empty());
result = std::move(queue_.back());
queue_.pop_back();
}
num_fetches_.fetch_add(1);
return MakeOrphanable<TestFetchRequest>(
event_engine(), std::move(on_done), std::move(result));
}
std::string debug_string() override {
return "TestTokenFetcherCredentials";
}
UniqueTypeName type() const override {
static UniqueTypeName::Factory kFactory("TestTokenFetcherCredentials");
return kFactory.Create();
}
Mutex mu_;
std::deque<absl::StatusOr<RefCountedPtr<Token>>> queue_
ABSL_GUARDED_BY(&mu_);
std::atomic<size_t> num_fetches_{0};
};
void SetUp() override {
grpc_timer_manager_set_start_threaded(false);
grpc_init();
}
void TearDown() override {
event_engine_->FuzzingDone();
event_engine_->TickUntilIdle();
event_engine_->UnsetGlobalHooks();
creds_.reset();
grpc_event_engine::experimental::WaitForSingleOwner(
std::move(event_engine_));
grpc_shutdown_blocking();
}
static RefCountedPtr<TokenFetcherCredentials::Token> MakeToken(
absl::string_view token, Timestamp expiration = Timestamp::InfFuture()) {
return MakeRefCounted<TokenFetcherCredentials::Token>(
Slice::FromCopiedString(token), expiration);
}
std::shared_ptr<FuzzingEventEngine> event_engine_ =
std::make_shared<FuzzingEventEngine>(FuzzingEventEngine::Options(),
fuzzing_event_engine::Actions());
RefCountedPtr<TestTokenFetcherCredentials> creds_ =
MakeRefCounted<TestTokenFetcherCredentials>(event_engine_);
};
TEST_F(TokenFetcherCredentialsTest, Basic) {
const auto kExpirationTime = Timestamp::Now() + Duration::Hours(1);
ExecCtx exec_ctx;
creds_->AddResult(MakeToken("foo", kExpirationTime));
// First request will trigger a fetch.
auto state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: foo", /*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 1);
// Second request while fetch is still outstanding will be delayed but
// will not trigger a new fetch.
state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: foo", /*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 1);
// Now tick to finish the fetch.
event_engine_->TickUntilIdle();
// Next request will be served from cache with no delay.
state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: foo", /*expect_delay=*/false);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 1);
// Advance time to expiration minus expiration adjustment and prefetch time.
exec_ctx.TestOnlySetNow(kExpirationTime - Duration::Seconds(90));
// No new fetch yet.
EXPECT_EQ(creds_->num_fetches(), 1);
// Next request will trigger a new fetch but will still use the
// cached token.
creds_->AddResult(MakeToken("bar"));
state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: foo", /*expect_delay=*/false);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 2);
event_engine_->TickUntilIdle();
// Next request will use the new data.
state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: bar", /*expect_delay=*/false);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 2);
}
TEST_F(TokenFetcherCredentialsTest, Expires30SecondsEarly) {
const auto kExpirationTime = Timestamp::Now() + Duration::Hours(1);
ExecCtx exec_ctx;
creds_->AddResult(MakeToken("foo", kExpirationTime));
// First request will trigger a fetch.
auto state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: foo", /*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 1);
event_engine_->TickUntilIdle();
// Advance time to expiration minus 30 seconds.
exec_ctx.TestOnlySetNow(kExpirationTime - Duration::Seconds(30));
// No new fetch yet.
EXPECT_EQ(creds_->num_fetches(), 1);
// Next request will trigger a new fetch and will delay the call until
// the fetch completes.
creds_->AddResult(MakeToken("bar"));
state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: bar", /*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 2);
event_engine_->TickUntilIdle();
}
TEST_F(TokenFetcherCredentialsTest, FetchFails) {
const absl::Status kExpectedError = absl::UnavailableError("bummer, dude");
absl::optional<FuzzingEventEngine::Duration> run_after_duration;
event_engine_->SetRunAfterDurationCallback(
[&](FuzzingEventEngine::Duration duration) {
run_after_duration = duration;
});
ExecCtx exec_ctx;
creds_->AddResult(kExpectedError);
// First request will trigger a fetch, which will fail.
auto state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 1);
while (!run_after_duration.has_value()) event_engine_->Tick();
// Make sure backoff was set for the right period.
// This is 1 second (initial backoff) minus 1ms for the tick needed above.
EXPECT_EQ(run_after_duration, std::chrono::seconds(1));
run_after_duration.reset();
// Start a new call now, which will be queued and then eventually
// resumed when the next fetch happens.
state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: foo", /*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
// Tick until the next fetch starts.
creds_->AddResult(MakeToken("foo"));
event_engine_->TickUntilIdle();
EXPECT_EQ(creds_->num_fetches(), 2);
// A call started now should use the new cached data.
state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: foo", /*expect_delay=*/false);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 2);
}
TEST_F(TokenFetcherCredentialsTest, Backoff) {
const absl::Status kExpectedError = absl::UnavailableError("bummer, dude");
absl::optional<FuzzingEventEngine::Duration> run_after_duration;
event_engine_->SetRunAfterDurationCallback(
[&](FuzzingEventEngine::Duration duration) {
run_after_duration = duration;
});
ExecCtx exec_ctx;
creds_->AddResult(kExpectedError);
// First request will trigger a fetch, which will fail.
auto state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 1);
while (!run_after_duration.has_value()) event_engine_->Tick();
// Make sure backoff was set for the right period.
EXPECT_EQ(run_after_duration, std::chrono::seconds(1));
run_after_duration.reset();
// Start a new call now, which will be queued and then eventually
// resumed when the next fetch happens.
state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
// Tick until the next fetch fails and the backoff timer starts again.
creds_->AddResult(kExpectedError);
while (!run_after_duration.has_value()) event_engine_->Tick();
EXPECT_EQ(creds_->num_fetches(), 2);
// The backoff time should be longer now. We account for jitter here.
EXPECT_EQ(run_after_duration, std::chrono::milliseconds(1600))
<< "actual: " << run_after_duration->count();
run_after_duration.reset();
// Start another new call to trigger another new fetch once the
// backoff expires.
state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
// Tick until the next fetch starts.
creds_->AddResult(kExpectedError);
while (!run_after_duration.has_value()) event_engine_->Tick();
EXPECT_EQ(creds_->num_fetches(), 3);
// Check backoff time again.
EXPECT_EQ(run_after_duration, std::chrono::milliseconds(2560))
<< "actual: " << run_after_duration->count();
}
TEST_F(TokenFetcherCredentialsTest, FetchNotStartedAfterBackoffWithoutRpc) {
const absl::Status kExpectedError = absl::UnavailableError("bummer, dude");
absl::optional<FuzzingEventEngine::Duration> run_after_duration;
event_engine_->SetRunAfterDurationCallback(
[&](FuzzingEventEngine::Duration duration) {
run_after_duration = duration;
});
ExecCtx exec_ctx;
creds_->AddResult(kExpectedError);
// First request will trigger a fetch, which will fail.
auto state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 1);
while (!run_after_duration.has_value()) event_engine_->Tick();
// Make sure backoff was set for the right period.
EXPECT_EQ(run_after_duration, std::chrono::seconds(1));
run_after_duration.reset();
// Tick until the backoff expires. No new fetch should be started.
event_engine_->TickUntilIdle();
EXPECT_EQ(creds_->num_fetches(), 1);
// Now start a new request, which will trigger a new fetch.
creds_->AddResult(MakeToken("foo"));
state = RequestMetadataState::NewInstance(
absl::OkStatus(), "authorization: foo", /*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 2);
}
TEST_F(TokenFetcherCredentialsTest, ShutdownWhileBackoffTimerPending) {
const absl::Status kExpectedError = absl::UnavailableError("bummer, dude");
absl::optional<FuzzingEventEngine::Duration> run_after_duration;
event_engine_->SetRunAfterDurationCallback(
[&](FuzzingEventEngine::Duration duration) {
run_after_duration = duration;
});
ExecCtx exec_ctx;
creds_->AddResult(kExpectedError);
// First request will trigger a fetch, which will fail.
auto state = RequestMetadataState::NewInstance(kExpectedError, "",
/*expect_delay=*/true);
state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority,
kTestPath);
EXPECT_EQ(creds_->num_fetches(), 1);
while (!run_after_duration.has_value()) event_engine_->Tick();
// Make sure backoff was set for the right period.
EXPECT_EQ(run_after_duration, std::chrono::seconds(1));
run_after_duration.reset();
// Do nothing else. Make sure the creds shut down correctly.
}
// The subclass of ExternalAccountCredentials for testing.
// ExternalAccountCredentials is an abstract class so we can't directly test
// against it.
@ -2805,6 +2484,8 @@ TEST_F(CredentialsTest,
grpc_version_string()));
}
using grpc_event_engine::experimental::FuzzingEventEngine;
class ExternalAccountCredentialsTest : public ::testing::Test {
protected:
void SetUp() override {

Loading…
Cancel
Save