diff --git a/doc/trace_flags.md b/doc/trace_flags.md index 4ea7129a28f..56ad7191781 100644 --- a/doc/trace_flags.md +++ b/doc/trace_flags.md @@ -73,7 +73,6 @@ processing requests via debug logs. Available tracers include: - tcp - Bytes in and out of a channel. - timer - Timers (alarms) in the grpc internals. - timer_check - more detailed trace of timer logic in grpc internals. - - token_fetcher_credentials - Token fetcher call credentials framework, used for (e.g.) oauth2 token fetcher credentials. - tsi - TSI transport security. - weighted_round_robin_lb - Weighted round robin load balancing policy. - weighted_target_lb - Weighted target LB policy. diff --git a/src/core/BUILD b/src/core/BUILD index 641ba363ab8..50ce0f5c9cc 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -4335,17 +4335,14 @@ grpc_cc_library( deps = [ "arena_promise", "context", - "default_event_engine", "metadata", "poll", "pollset_set", "ref_counted", "time", "useful", - "//:backoff", "//:gpr", "//:grpc_security_base", - "//:grpc_trace", "//:httpcli", "//:iomgr", "//:orphanable", @@ -4439,6 +4436,7 @@ grpc_cc_library( language = "c++", deps = [ "closure", + "default_event_engine", "env", "error", "error_utils", diff --git a/src/core/lib/debug/trace_flags.cc b/src/core/lib/debug/trace_flags.cc index 8f7183782c8..f3fa6c63314 100644 --- a/src/core/lib/debug/trace_flags.cc +++ b/src/core/lib/debug/trace_flags.cc @@ -114,7 +114,6 @@ TraceFlag subchannel_pool_trace(false, "subchannel_pool"); TraceFlag tcp_trace(false, "tcp"); TraceFlag timer_trace(false, "timer"); TraceFlag timer_check_trace(false, "timer_check"); -TraceFlag token_fetcher_credentials_trace(false, "token_fetcher_credentials"); TraceFlag tsi_trace(false, "tsi"); TraceFlag weighted_round_robin_lb_trace(false, "weighted_round_robin_lb"); TraceFlag weighted_target_lb_trace(false, "weighted_target_lb"); @@ -207,7 +206,6 @@ const absl::flat_hash_map& GetAllTraceFlags() { {"tcp", &tcp_trace}, {"timer", &timer_trace}, {"timer_check", &timer_check_trace}, - {"token_fetcher_credentials", &token_fetcher_credentials_trace}, {"tsi", &tsi_trace}, {"weighted_round_robin_lb", &weighted_round_robin_lb_trace}, {"weighted_target_lb", &weighted_target_lb_trace}, diff --git a/src/core/lib/debug/trace_flags.h b/src/core/lib/debug/trace_flags.h index 9aa4df691f0..4aaf5e01111 100644 --- a/src/core/lib/debug/trace_flags.h +++ b/src/core/lib/debug/trace_flags.h @@ -112,7 +112,6 @@ extern TraceFlag subchannel_pool_trace; extern TraceFlag tcp_trace; extern TraceFlag timer_trace; extern TraceFlag timer_check_trace; -extern TraceFlag token_fetcher_credentials_trace; extern TraceFlag tsi_trace; extern TraceFlag weighted_round_robin_lb_trace; extern TraceFlag weighted_target_lb_trace; diff --git a/src/core/lib/debug/trace_flags.yaml b/src/core/lib/debug/trace_flags.yaml index 64c665af799..247d830e81f 100644 --- a/src/core/lib/debug/trace_flags.yaml +++ b/src/core/lib/debug/trace_flags.yaml @@ -308,9 +308,6 @@ timer: timer_check: default: false description: more detailed trace of timer logic in grpc internals. -token_fetcher_credentials: - default: false - description: Token fetcher call credentials framework, used for (e.g.) oauth2 token fetcher credentials. tsi: default: false description: TSI transport security. diff --git a/src/core/lib/promise/map.h b/src/core/lib/promise/map.h index 3ba8c19c2f6..a2a2a773eea 100644 --- a/src/core/lib/promise/map.h +++ b/src/core/lib/promise/map.h @@ -86,7 +86,7 @@ GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION auto CheckDelayed(Promise promise) { delayed = true; return Pending{}; } - return std::make_tuple(std::move(r.value()), delayed); + return std::make_tuple(r.value(), delayed); }; } diff --git a/src/core/lib/security/credentials/external/external_account_credentials.cc b/src/core/lib/security/credentials/external/external_account_credentials.cc index c83351e67e2..100b4ddd8e4 100644 --- a/src/core/lib/security/credentials/external/external_account_credentials.cc +++ b/src/core/lib/security/credentials/external/external_account_credentials.cc @@ -45,6 +45,7 @@ #include #include +#include "src/core/lib/event_engine/default_event_engine.h" #include "src/core/lib/gprpp/status_helper.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/credentials/external/aws_external_account_credentials.h" @@ -590,7 +591,10 @@ ExternalAccountCredentials::Create( ExternalAccountCredentials::ExternalAccountCredentials( Options options, std::vector scopes, std::shared_ptr event_engine) - : TokenFetcherCredentials(std::move(event_engine)), + : event_engine_( + event_engine == nullptr + ? grpc_event_engine::experimental::GetDefaultEventEngine() + : std::move(event_engine)), options_(std::move(options)) { if (scopes.empty()) { scopes.push_back(GOOGLE_CLOUD_PLATFORM_DEFAULT_SCOPE); diff --git a/src/core/lib/security/credentials/external/external_account_credentials.h b/src/core/lib/security/credentials/external/external_account_credentials.h index 5dfbb24a36e..0617c3e6f07 100644 --- a/src/core/lib/security/credentials/external/external_account_credentials.h +++ b/src/core/lib/security/credentials/external/external_account_credentials.h @@ -185,6 +185,10 @@ class ExternalAccountCredentials : public TokenFetcherCredentials { absl::string_view audience() const { return options_.audience; } + grpc_event_engine::experimental::EventEngine& event_engine() const { + return *event_engine_; + } + private: OrphanablePtr FetchToken( Timestamp deadline, @@ -200,6 +204,7 @@ class ExternalAccountCredentials : public TokenFetcherCredentials { Timestamp deadline, absl::AnyInvocable)> on_done) = 0; + std::shared_ptr event_engine_; Options options_; std::vector scopes_; }; diff --git a/src/core/lib/security/credentials/external/file_external_account_credentials.cc b/src/core/lib/security/credentials/external/file_external_account_credentials.cc index 086af8c79c7..cad9b7f7ee4 100644 --- a/src/core/lib/security/credentials/external/file_external_account_credentials.cc +++ b/src/core/lib/security/credentials/external/file_external_account_credentials.cc @@ -26,6 +26,7 @@ #include #include +#include "src/core/lib/event_engine/default_event_engine.h" #include "src/core/lib/gprpp/load_file.h" #include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_internal.h" diff --git a/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc index 596641f0fff..eb5bee6478a 100644 --- a/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc +++ b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc @@ -18,8 +18,6 @@ #include "src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h" -#include "src/core/lib/debug/trace.h" -#include "src/core/lib/event_engine/default_event_engine.h" #include "src/core/lib/iomgr/pollset_set.h" #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/poll.h" @@ -30,11 +28,8 @@ namespace grpc_core { namespace { // Amount of time before the token's expiration that we consider it -// invalid to account for server processing time and clock skew. -constexpr Duration kTokenExpirationAdjustmentDuration = Duration::Seconds(30); - -// Amount of time before the token's expiration that we pre-fetch a new -// token. Also determines the timeout for the fetch request. +// invalid and start a new fetch. Also determines the timeout for the +// fetch request. constexpr Duration kTokenRefreshDuration = Duration::Seconds(60); } // namespace @@ -43,193 +38,18 @@ constexpr Duration kTokenRefreshDuration = Duration::Seconds(60); // TokenFetcherCredentials::Token // -TokenFetcherCredentials::Token::Token(Slice token, Timestamp expiration) - : token_(std::move(token)), - expiration_(expiration - kTokenExpirationAdjustmentDuration) {} - void TokenFetcherCredentials::Token::AddTokenToClientInitialMetadata( ClientMetadata& metadata) const { metadata.Append(GRPC_AUTHORIZATION_METADATA_KEY, token_.Ref(), [](absl::string_view, const Slice&) { abort(); }); } -// -// TokenFetcherCredentials::FetchState::BackoffTimer -// - -TokenFetcherCredentials::FetchState::BackoffTimer::BackoffTimer( - RefCountedPtr fetch_state) - : fetch_state_(std::move(fetch_state)) { - const Timestamp next_attempt_time = fetch_state_->backoff_.NextAttemptTime(); - const Duration duration = next_attempt_time - Timestamp::Now(); - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << fetch_state_->creds_.get() - << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this - << ": starting backoff timer for " << next_attempt_time << " (" - << duration << " from now)"; - timer_handle_ = fetch_state_->creds_->event_engine().RunAfter( - duration, [self = Ref()]() mutable { - ApplicationCallbackExecCtx callback_exec_ctx; - ExecCtx exec_ctx; - self->OnTimer(); - self.reset(); - }); -} - -void TokenFetcherCredentials::FetchState::BackoffTimer::Orphan() { - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << fetch_state_->creds_.get() - << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this - << ": backoff timer shut down"; - if (timer_handle_.has_value()) { - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << fetch_state_->creds_.get() - << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this - << ": cancelling timer"; - fetch_state_->creds_->event_engine().Cancel(*timer_handle_); - timer_handle_.reset(); - fetch_state_->ResumeQueuedCalls( - absl::CancelledError("credentials shutdown")); - } - Unref(); -} - -void TokenFetcherCredentials::FetchState::BackoffTimer::OnTimer() { - MutexLock lock(&fetch_state_->creds_->mu_); - if (!timer_handle_.has_value()) return; - timer_handle_.reset(); - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << fetch_state_->creds_.get() - << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this - << ": backoff timer fired"; - if (fetch_state_->queued_calls_.empty()) { - // If there are no pending calls when the timer fires, then orphan - // the FetchState object. Note that this drops the backoff state, - // but that's probably okay, because if we didn't have any pending - // calls during the backoff period, we probably won't see any - // immediately now either. - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << fetch_state_->creds_.get() - << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this - << ": no pending calls, clearing state"; - fetch_state_->creds_->fetch_state_.reset(); - } else { - // If there are pending calls, then start a new fetch attempt. - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << fetch_state_->creds_.get() - << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this - << ": starting new fetch attempt"; - fetch_state_->StartFetchAttempt(); - } -} - -// -// TokenFetcherCredentials::FetchState -// - -TokenFetcherCredentials::FetchState::FetchState( - WeakRefCountedPtr creds) - : creds_(std::move(creds)), - backoff_(BackOff::Options() - .set_initial_backoff(Duration::Seconds(1)) - .set_multiplier(1.6) - .set_jitter(creds_->test_only_use_backoff_jitter_ ? 0.2 : 0) - .set_max_backoff(Duration::Seconds(120))) { - StartFetchAttempt(); -} - -void TokenFetcherCredentials::FetchState::Orphan() { - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << creds_.get() - << "]: fetch_state=" << this << ": shutting down"; - // Cancels fetch or backoff timer, if any. - state_ = Shutdown{}; - Unref(); -} - -void TokenFetcherCredentials::FetchState::StartFetchAttempt() { - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << creds_.get() - << "]: fetch_state=" << this << ": starting fetch"; - state_ = creds_->FetchToken( - /*deadline=*/Timestamp::Now() + kTokenRefreshDuration, - [self = Ref()](absl::StatusOr> token) mutable { - self->TokenFetchComplete(std::move(token)); - }); -} - -void TokenFetcherCredentials::FetchState::TokenFetchComplete( - absl::StatusOr> token) { - MutexLock lock(&creds_->mu_); - // If we were shut down, clean up. - if (absl::holds_alternative(state_)) { - if (token.ok()) token = absl::CancelledError("credentials shutdown"); - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << creds_.get() - << "]: fetch_state=" << this - << ": shut down before fetch completed: " << token.status(); - ResumeQueuedCalls(std::move(token)); - return; - } - // If succeeded, update cache in creds object. - if (token.ok()) { - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << creds_.get() - << "]: fetch_state=" << this << ": token fetch succeeded"; - creds_->token_ = *token; - creds_->fetch_state_.reset(); // Orphan ourselves. - } else { - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << creds_.get() - << "]: fetch_state=" << this - << ": token fetch failed: " << token.status(); - // If failed, start backoff timer. - state_ = OrphanablePtr(new BackoffTimer(Ref())); - } - ResumeQueuedCalls(std::move(token)); -} - -void TokenFetcherCredentials::FetchState::ResumeQueuedCalls( - absl::StatusOr> token) { - // Invoke callbacks for all pending requests. - for (auto& queued_call : queued_calls_) { - queued_call->result = token; - queued_call->done.store(true, std::memory_order_release); - queued_call->waker.Wakeup(); - grpc_polling_entity_del_from_pollset_set( - queued_call->pollent, - grpc_polling_entity_pollset_set(&creds_->pollent_)); - } - queued_calls_.clear(); -} - -RefCountedPtr -TokenFetcherCredentials::FetchState::QueueCall( - ClientMetadataHandle initial_metadata) { - // Add call to pending list. - auto queued_call = MakeRefCounted(); - queued_call->waker = GetContext()->MakeNonOwningWaker(); - queued_call->pollent = GetContext(); - grpc_polling_entity_add_to_pollset_set( - queued_call->pollent, grpc_polling_entity_pollset_set(&creds_->pollent_)); - queued_call->md = std::move(initial_metadata); - queued_calls_.insert(queued_call); - return queued_call; -} - // // TokenFetcherCredentials // -TokenFetcherCredentials::TokenFetcherCredentials( - std::shared_ptr event_engine, - bool test_only_use_backoff_jitter) - : event_engine_( - event_engine == nullptr - ? grpc_event_engine::experimental::GetDefaultEventEngine() - : std::move(event_engine)), - test_only_use_backoff_jitter_(test_only_use_backoff_jitter), - pollent_(grpc_polling_entity_create_from_pollset_set( +TokenFetcherCredentials::TokenFetcherCredentials() + : pollent_(grpc_polling_entity_create_from_pollset_set( grpc_pollset_set_create())) {} TokenFetcherCredentials::~TokenFetcherCredentials() { @@ -238,63 +58,73 @@ TokenFetcherCredentials::~TokenFetcherCredentials() { void TokenFetcherCredentials::Orphaned() { MutexLock lock(&mu_); - fetch_state_.reset(); + auto* fetch_request = absl::get_if>(&token_); + if (fetch_request != nullptr) fetch_request->reset(); } ArenaPromise> TokenFetcherCredentials::GetRequestMetadata( ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs*) { - RefCountedPtr queued_call; + RefCountedPtr pending_call; { MutexLock lock(&mu_); - // If we don't have a cached token or the token is within the - // refresh duration, start a new fetch if there isn't a pending one. - if ((token_ == nullptr || (token_->ExpirationTime() - Timestamp::Now()) <= - kTokenRefreshDuration) && - fetch_state_ == nullptr) { - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << this - << "]: " << GetContext()->DebugTag() - << " triggering new token fetch"; - fetch_state_ = OrphanablePtr( - new FetchState(WeakRefAsSubclass())); - } - // If we have a cached non-expired token, use it. - if (token_ != nullptr && - (token_->ExpirationTime() - Timestamp::Now()) > Duration::Zero()) { - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << this - << "]: " << GetContext()->DebugTag() - << " using cached token"; - token_->AddTokenToClientInitialMetadata(*initial_metadata); + // Check if we can use the cached token. + auto* cached_token = absl::get_if>(&token_); + if (cached_token != nullptr && *cached_token != nullptr && + ((*cached_token)->ExpirationTime() - Timestamp::Now()) > + kTokenRefreshDuration) { + (*cached_token)->AddTokenToClientInitialMetadata(*initial_metadata); return Immediate(std::move(initial_metadata)); } - // If we don't have a cached token, this call will need to be queued. - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << this - << "]: " << GetContext()->DebugTag() - << " no cached token; queuing call"; - queued_call = fetch_state_->QueueCall(std::move(initial_metadata)); + // Couldn't get the token from the cache. + // Add this call to the pending list. + pending_call = MakeRefCounted(); + pending_call->waker = GetContext()->MakeNonOwningWaker(); + pending_call->pollent = GetContext(); + grpc_polling_entity_add_to_pollset_set( + pending_call->pollent, grpc_polling_entity_pollset_set(&pollent_)); + pending_call->md = std::move(initial_metadata); + pending_calls_.insert(pending_call); + // Start a new fetch if needed. + if (!absl::holds_alternative>(token_)) { + token_ = FetchToken( + /*deadline=*/Timestamp::Now() + kTokenRefreshDuration, + [self = WeakRefAsSubclass()]( + absl::StatusOr> token) mutable { + self->TokenFetchComplete(std::move(token)); + }); + } } - return [this, queued_call = std::move(queued_call)]() - -> Poll> { - if (!queued_call->done.load(std::memory_order_acquire)) { + return [pending_call = std::move( + pending_call)]() -> Poll> { + if (!pending_call->done.load(std::memory_order_acquire)) { return Pending{}; } - if (!queued_call->result.ok()) { - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << this - << "]: " << GetContext()->DebugTag() - << " token fetch failed; failing call"; - return queued_call->result.status(); + if (!pending_call->result.ok()) { + return pending_call->result.status(); } - GRPC_TRACE_LOG(token_fetcher_credentials, INFO) - << "[TokenFetcherCredentials " << this - << "]: " << GetContext()->DebugTag() - << " token fetch complete; resuming call"; - (*queued_call->result)->AddTokenToClientInitialMetadata(*queued_call->md); - return std::move(queued_call->md); + (*pending_call->result)->AddTokenToClientInitialMetadata(*pending_call->md); + return std::move(pending_call->md); }; } +void TokenFetcherCredentials::TokenFetchComplete( + absl::StatusOr> token) { + // Update cache and grab list of pending requests. + absl::flat_hash_set> pending_calls; + { + MutexLock lock(&mu_); + token_ = token.value_or(nullptr); + pending_calls_.swap(pending_calls); + } + // Invoke callbacks for all pending requests. + for (auto& pending_call : pending_calls) { + pending_call->result = token; + pending_call->done.store(true, std::memory_order_release); + pending_call->waker.Wakeup(); + grpc_polling_entity_del_from_pollset_set( + pending_call->pollent, grpc_polling_entity_pollset_set(&pollent_)); + } +} + } // namespace grpc_core diff --git a/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h index c73d469234b..72793afcdee 100644 --- a/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h +++ b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h @@ -26,9 +26,6 @@ #include "absl/status/statusor.h" #include "absl/types/variant.h" -#include - -#include "src/core/lib/backoff/backoff.h" #include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" @@ -50,7 +47,8 @@ class TokenFetcherCredentials : public grpc_call_credentials { // Represents a token. class Token : public RefCounted { public: - Token(Slice token, Timestamp expiration); + Token(Slice token, Timestamp expiration) + : token_(std::move(token)), expiration_(expiration) {} // Returns the token's expiration time. Timestamp ExpirationTime() const { return expiration_; } @@ -75,10 +73,7 @@ class TokenFetcherCredentials : public grpc_call_credentials { // Base class for fetch requests. class FetchRequest : public InternallyRefCounted {}; - explicit TokenFetcherCredentials( - std::shared_ptr - event_engine = nullptr, - bool test_only_use_backoff_jitter = true); + TokenFetcherCredentials(); // Fetches a token. The on_done callback will be invoked when complete. virtual OrphanablePtr FetchToken( @@ -86,15 +81,11 @@ class TokenFetcherCredentials : public grpc_call_credentials { absl::AnyInvocable>)> on_done) = 0; - grpc_event_engine::experimental::EventEngine& event_engine() const { - return *event_engine_; - } - grpc_polling_entity* pollent() { return &pollent_; } private: // A call that is waiting for a token fetch request to complete. - struct QueuedCall : public RefCounted { + struct PendingCall : public RefCounted { std::atomic done{false}; Waker waker; grpc_polling_entity* pollent; @@ -102,72 +93,20 @@ class TokenFetcherCredentials : public grpc_call_credentials { absl::StatusOr> result; }; - class FetchState : public InternallyRefCounted { - public: - explicit FetchState(WeakRefCountedPtr creds) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); - - // Disabling thread safety annotations, since Orphan() is called - // by OrpahanablePtr<>, which does not have the right lock - // annotations. - void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS; - - RefCountedPtr QueueCall(ClientMetadataHandle initial_metadata) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); - - private: - class BackoffTimer : public InternallyRefCounted { - public: - explicit BackoffTimer(RefCountedPtr fetch_state) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); - - // Disabling thread safety annotations, since Orphan() is called - // by OrpahanablePtr<>, which does not have the right lock - // annotations. - void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS; - - private: - void OnTimer(); - - RefCountedPtr fetch_state_; - absl::optional - timer_handle_ ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_); - }; - - struct Shutdown {}; - - void StartFetchAttempt() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); - void TokenFetchComplete(absl::StatusOr> token); - void ResumeQueuedCalls(absl::StatusOr> token) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); - - WeakRefCountedPtr creds_; - // Pending token-fetch request or backoff timer, if any. - absl::variant, OrphanablePtr, - Shutdown> - state_ ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_); - // Calls that are queued up waiting for the token. - absl::flat_hash_set> queued_calls_ - ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_); - // Backoff state. - BackOff backoff_ ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_); - }; - int cmp_impl(const grpc_call_credentials* other) const override { // TODO(yashykt): Check if we can do something better here return QsortCompare(static_cast(this), other); } - std::shared_ptr event_engine_; - const bool test_only_use_backoff_jitter_; + void TokenFetchComplete(absl::StatusOr> token); Mutex mu_; - // Cached token, if any. - RefCountedPtr token_ ABSL_GUARDED_BY(&mu_); - // Fetch state, if any. - OrphanablePtr fetch_state_ ABSL_GUARDED_BY(&mu_); - + // Either the cached token or a pending request to fetch the token. + absl::variant, OrphanablePtr> token_ + ABSL_GUARDED_BY(&mu_); + // Calls that are queued up waiting for the token. + absl::flat_hash_set> pending_calls_ + ABSL_GUARDED_BY(&mu_); grpc_polling_entity pollent_ ABSL_GUARDED_BY(&mu_); }; diff --git a/test/core/security/credentials_test.cc b/test/core/security/credentials_test.cc index 03467770b4d..fb93d522a7c 100644 --- a/test/core/security/credentials_test.cc +++ b/test/core/security/credentials_test.cc @@ -50,7 +50,6 @@ #include "src/core/lib/iomgr/error.h" #include "src/core/lib/iomgr/timer_manager.h" #include "src/core/lib/promise/exec_ctx_wakeup_scheduler.h" -#include "src/core/lib/promise/map.h" #include "src/core/lib/promise/promise.h" #include "src/core/lib/promise/seq.h" #include "src/core/lib/security/context/security_context.h" @@ -80,7 +79,6 @@ namespace grpc_core { -using grpc_event_engine::experimental::FuzzingEventEngine; using internal::grpc_flush_cached_google_default_credentials; using internal::set_gce_tenancy_checker_for_testing; @@ -428,19 +426,16 @@ TEST_F(CredentialsTest, class RequestMetadataState : public RefCounted { public: static RefCountedPtr NewInstance( - grpc_error_handle expected_error, std::string expected, - absl::optional expect_delay = absl::nullopt) { + grpc_error_handle expected_error, std::string expected) { return MakeRefCounted( - expected_error, std::move(expected), expect_delay, + expected_error, std::move(expected), grpc_polling_entity_create_from_pollset_set(grpc_pollset_set_create())); } RequestMetadataState(grpc_error_handle expected_error, std::string expected, - absl::optional expect_delay, grpc_polling_entity pollent) : expected_error_(expected_error), expected_(std::move(expected)), - expect_delay_(expect_delay), pollent_(pollent) {} ~RequestMetadataState() override { @@ -458,18 +453,12 @@ class RequestMetadataState : public RefCounted { activity_ = MakeActivity( [this, creds] { return Seq( - CheckDelayed(creds->GetRequestMetadata( + creds->GetRequestMetadata( ClientMetadataHandle(&md_, Arena::PooledDeleter(nullptr)), - &get_request_metadata_args_)), - [this](std::tuple, bool> - metadata_and_delayed) { - auto& metadata = std::get<0>(metadata_and_delayed); - const bool delayed = std::get<1>(metadata_and_delayed); - if (expect_delay_.has_value()) { - EXPECT_EQ(delayed, *expect_delay_); - } + &get_request_metadata_args_), + [this](absl::StatusOr metadata) { if (metadata.ok()) { - EXPECT_EQ(metadata->get(), &md_); + CHECK(metadata->get() == &md_); } return metadata.status(); }); @@ -534,7 +523,6 @@ class RequestMetadataState : public RefCounted { grpc_error_handle expected_error_; std::string expected_; - absl::optional expect_delay_; RefCountedPtr arena_ = SimpleArenaAllocator()->MakeArena(); grpc_metadata_batch md_; grpc_call_credentials::GetRequestMetadataArgs get_request_metadata_args_; @@ -2381,315 +2369,6 @@ int aws_external_account_creds_httpcli_post_success( return 1; } -class TokenFetcherCredentialsTest : public ::testing::Test { - protected: - class TestTokenFetcherCredentials final : public TokenFetcherCredentials { - public: - explicit TestTokenFetcherCredentials( - std::shared_ptr - event_engine = nullptr) - : TokenFetcherCredentials(std::move(event_engine), - /*test_only_use_backoff_jitter=*/false) {} - - ~TestTokenFetcherCredentials() override { CHECK_EQ(queue_.size(), 0); } - - void AddResult(absl::StatusOr> result) { - MutexLock lock(&mu_); - queue_.push_front(std::move(result)); - } - - size_t num_fetches() const { return num_fetches_; } - - private: - class TestFetchRequest final : public FetchRequest { - public: - TestFetchRequest( - grpc_event_engine::experimental::EventEngine& event_engine, - absl::AnyInvocable>)> - on_done, - absl::StatusOr> result) { - event_engine.Run([on_done = std::move(on_done), - result = std::move(result)]() mutable { - ApplicationCallbackExecCtx application_exec_ctx; - ExecCtx exec_ctx; - std::exchange(on_done, nullptr)(std::move(result)); - }); - } - - void Orphan() override { Unref(); } - }; - - OrphanablePtr FetchToken( - Timestamp deadline, - absl::AnyInvocable>)> on_done) - override { - absl::StatusOr> result; - { - MutexLock lock(&mu_); - CHECK(!queue_.empty()); - result = std::move(queue_.back()); - queue_.pop_back(); - } - num_fetches_.fetch_add(1); - return MakeOrphanable( - event_engine(), std::move(on_done), std::move(result)); - } - - std::string debug_string() override { - return "TestTokenFetcherCredentials"; - } - - UniqueTypeName type() const override { - static UniqueTypeName::Factory kFactory("TestTokenFetcherCredentials"); - return kFactory.Create(); - } - - Mutex mu_; - std::deque>> queue_ - ABSL_GUARDED_BY(&mu_); - - std::atomic num_fetches_{0}; - }; - - void SetUp() override { - grpc_timer_manager_set_start_threaded(false); - grpc_init(); - } - - void TearDown() override { - event_engine_->FuzzingDone(); - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); - creds_.reset(); - grpc_event_engine::experimental::WaitForSingleOwner( - std::move(event_engine_)); - grpc_shutdown_blocking(); - } - - static RefCountedPtr MakeToken( - absl::string_view token, Timestamp expiration = Timestamp::InfFuture()) { - return MakeRefCounted( - Slice::FromCopiedString(token), expiration); - } - - std::shared_ptr event_engine_ = - std::make_shared(FuzzingEventEngine::Options(), - fuzzing_event_engine::Actions()); - RefCountedPtr creds_ = - MakeRefCounted(event_engine_); -}; - -TEST_F(TokenFetcherCredentialsTest, Basic) { - const auto kExpirationTime = Timestamp::Now() + Duration::Hours(1); - ExecCtx exec_ctx; - creds_->AddResult(MakeToken("foo", kExpirationTime)); - // First request will trigger a fetch. - auto state = RequestMetadataState::NewInstance( - absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 1); - // Second request while fetch is still outstanding will be delayed but - // will not trigger a new fetch. - state = RequestMetadataState::NewInstance( - absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 1); - // Now tick to finish the fetch. - event_engine_->TickUntilIdle(); - // Next request will be served from cache with no delay. - state = RequestMetadataState::NewInstance( - absl::OkStatus(), "authorization: foo", /*expect_delay=*/false); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 1); - // Advance time to expiration minus expiration adjustment and prefetch time. - exec_ctx.TestOnlySetNow(kExpirationTime - Duration::Seconds(90)); - // No new fetch yet. - EXPECT_EQ(creds_->num_fetches(), 1); - // Next request will trigger a new fetch but will still use the - // cached token. - creds_->AddResult(MakeToken("bar")); - state = RequestMetadataState::NewInstance( - absl::OkStatus(), "authorization: foo", /*expect_delay=*/false); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 2); - event_engine_->TickUntilIdle(); - // Next request will use the new data. - state = RequestMetadataState::NewInstance( - absl::OkStatus(), "authorization: bar", /*expect_delay=*/false); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 2); -} - -TEST_F(TokenFetcherCredentialsTest, Expires30SecondsEarly) { - const auto kExpirationTime = Timestamp::Now() + Duration::Hours(1); - ExecCtx exec_ctx; - creds_->AddResult(MakeToken("foo", kExpirationTime)); - // First request will trigger a fetch. - auto state = RequestMetadataState::NewInstance( - absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 1); - event_engine_->TickUntilIdle(); - // Advance time to expiration minus 30 seconds. - exec_ctx.TestOnlySetNow(kExpirationTime - Duration::Seconds(30)); - // No new fetch yet. - EXPECT_EQ(creds_->num_fetches(), 1); - // Next request will trigger a new fetch and will delay the call until - // the fetch completes. - creds_->AddResult(MakeToken("bar")); - state = RequestMetadataState::NewInstance( - absl::OkStatus(), "authorization: bar", /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 2); - event_engine_->TickUntilIdle(); -} - -TEST_F(TokenFetcherCredentialsTest, FetchFails) { - const absl::Status kExpectedError = absl::UnavailableError("bummer, dude"); - absl::optional run_after_duration; - event_engine_->SetRunAfterDurationCallback( - [&](FuzzingEventEngine::Duration duration) { - run_after_duration = duration; - }); - ExecCtx exec_ctx; - creds_->AddResult(kExpectedError); - // First request will trigger a fetch, which will fail. - auto state = RequestMetadataState::NewInstance(kExpectedError, "", - /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 1); - while (!run_after_duration.has_value()) event_engine_->Tick(); - // Make sure backoff was set for the right period. - // This is 1 second (initial backoff) minus 1ms for the tick needed above. - EXPECT_EQ(run_after_duration, std::chrono::seconds(1)); - run_after_duration.reset(); - // Start a new call now, which will be queued and then eventually - // resumed when the next fetch happens. - state = RequestMetadataState::NewInstance( - absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - // Tick until the next fetch starts. - creds_->AddResult(MakeToken("foo")); - event_engine_->TickUntilIdle(); - EXPECT_EQ(creds_->num_fetches(), 2); - // A call started now should use the new cached data. - state = RequestMetadataState::NewInstance( - absl::OkStatus(), "authorization: foo", /*expect_delay=*/false); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 2); -} - -TEST_F(TokenFetcherCredentialsTest, Backoff) { - const absl::Status kExpectedError = absl::UnavailableError("bummer, dude"); - absl::optional run_after_duration; - event_engine_->SetRunAfterDurationCallback( - [&](FuzzingEventEngine::Duration duration) { - run_after_duration = duration; - }); - ExecCtx exec_ctx; - creds_->AddResult(kExpectedError); - // First request will trigger a fetch, which will fail. - auto state = RequestMetadataState::NewInstance(kExpectedError, "", - /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 1); - while (!run_after_duration.has_value()) event_engine_->Tick(); - // Make sure backoff was set for the right period. - EXPECT_EQ(run_after_duration, std::chrono::seconds(1)); - run_after_duration.reset(); - // Start a new call now, which will be queued and then eventually - // resumed when the next fetch happens. - state = RequestMetadataState::NewInstance(kExpectedError, "", - /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - // Tick until the next fetch fails and the backoff timer starts again. - creds_->AddResult(kExpectedError); - while (!run_after_duration.has_value()) event_engine_->Tick(); - EXPECT_EQ(creds_->num_fetches(), 2); - // The backoff time should be longer now. We account for jitter here. - EXPECT_EQ(run_after_duration, std::chrono::milliseconds(1600)) - << "actual: " << run_after_duration->count(); - run_after_duration.reset(); - // Start another new call to trigger another new fetch once the - // backoff expires. - state = RequestMetadataState::NewInstance(kExpectedError, "", - /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - // Tick until the next fetch starts. - creds_->AddResult(kExpectedError); - while (!run_after_duration.has_value()) event_engine_->Tick(); - EXPECT_EQ(creds_->num_fetches(), 3); - // Check backoff time again. - EXPECT_EQ(run_after_duration, std::chrono::milliseconds(2560)) - << "actual: " << run_after_duration->count(); -} - -TEST_F(TokenFetcherCredentialsTest, FetchNotStartedAfterBackoffWithoutRpc) { - const absl::Status kExpectedError = absl::UnavailableError("bummer, dude"); - absl::optional run_after_duration; - event_engine_->SetRunAfterDurationCallback( - [&](FuzzingEventEngine::Duration duration) { - run_after_duration = duration; - }); - ExecCtx exec_ctx; - creds_->AddResult(kExpectedError); - // First request will trigger a fetch, which will fail. - auto state = RequestMetadataState::NewInstance(kExpectedError, "", - /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 1); - while (!run_after_duration.has_value()) event_engine_->Tick(); - // Make sure backoff was set for the right period. - EXPECT_EQ(run_after_duration, std::chrono::seconds(1)); - run_after_duration.reset(); - // Tick until the backoff expires. No new fetch should be started. - event_engine_->TickUntilIdle(); - EXPECT_EQ(creds_->num_fetches(), 1); - // Now start a new request, which will trigger a new fetch. - creds_->AddResult(MakeToken("foo")); - state = RequestMetadataState::NewInstance( - absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 2); -} - -TEST_F(TokenFetcherCredentialsTest, ShutdownWhileBackoffTimerPending) { - const absl::Status kExpectedError = absl::UnavailableError("bummer, dude"); - absl::optional run_after_duration; - event_engine_->SetRunAfterDurationCallback( - [&](FuzzingEventEngine::Duration duration) { - run_after_duration = duration; - }); - ExecCtx exec_ctx; - creds_->AddResult(kExpectedError); - // First request will trigger a fetch, which will fail. - auto state = RequestMetadataState::NewInstance(kExpectedError, "", - /*expect_delay=*/true); - state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, - kTestPath); - EXPECT_EQ(creds_->num_fetches(), 1); - while (!run_after_duration.has_value()) event_engine_->Tick(); - // Make sure backoff was set for the right period. - EXPECT_EQ(run_after_duration, std::chrono::seconds(1)); - run_after_duration.reset(); - // Do nothing else. Make sure the creds shut down correctly. -} - // The subclass of ExternalAccountCredentials for testing. // ExternalAccountCredentials is an abstract class so we can't directly test // against it. @@ -2805,6 +2484,8 @@ TEST_F(CredentialsTest, grpc_version_string())); } +using grpc_event_engine::experimental::FuzzingEventEngine; + class ExternalAccountCredentialsTest : public ::testing::Test { protected: void SetUp() override {