diff --git a/src/core/BUILD b/src/core/BUILD index dc761e5a8e9..dd9ded49632 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -4292,6 +4292,40 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "token_fetcher_credentials", + srcs = [ + "lib/security/credentials/token_fetcher/token_fetcher_credentials.cc", + ], + hdrs = [ + "lib/security/credentials/token_fetcher/token_fetcher_credentials.h", + ], + external_deps = [ + "absl/container:flat_hash_set", + "absl/functional:any_invocable", + "absl/status:statusor", + "absl/types:variant", + ], + language = "c++", + deps = [ + "arena_promise", + "context", + "metadata", + "poll", + "pollset_set", + "ref_counted", + "time", + "useful", + "//:gpr", + "//:grpc_security_base", + "//:httpcli", + "//:iomgr", + "//:orphanable", + "//:promise", + "//:ref_counted_ptr", + ], +) + grpc_cc_library( name = "grpc_oauth2_credentials", srcs = [ @@ -4329,6 +4363,7 @@ grpc_cc_library( "slice_refcount", "status_helper", "time", + "token_fetcher_credentials", "unique_type_name", "useful", "//:api_trace", diff --git a/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc new file mode 100644 index 00000000000..a911b2cc4b3 --- /dev/null +++ b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc @@ -0,0 +1,113 @@ +// +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include "src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h" + +#include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/promise.h" + +namespace grpc_core { + +namespace { + +// Amount of time before the token's expiration that we consider it +// invalid and start a new fetch. Also determines the timeout for the +// fetch request. +constexpr Duration kTokenRefreshDuration = Duration::Seconds(60); + +} // namespace + +TokenFetcherCredentials::TokenFetcherCredentials( + std::unique_ptr token_fetcher) + : token_fetcher_(std::move(token_fetcher)), + pollent_(grpc_polling_entity_create_from_pollset_set( + grpc_pollset_set_create())) {} + +TokenFetcherCredentials::~TokenFetcherCredentials() { + grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&pollent_)); +} + +ArenaPromise> +TokenFetcherCredentials::GetRequestMetadata( + ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs*) { + RefCountedPtr pending_call; + { + MutexLock lock(&mu_); + // Check if we can use the cached token. + auto* cached_token = absl::get_if>(&token_); + if (cached_token != nullptr && + (*cached_token)->ExpirationTime() < + (Timestamp::Now() - kTokenRefreshDuration)) { + (*cached_token)->AddTokenToClientInitialMetadata(*initial_metadata); + return Immediate(std::move(initial_metadata)); + } + // Couldn't get the token from the cache. + // Add this call to the pending list. + pending_call = MakeRefCounted(); + pending_call->waker = GetContext()->MakeNonOwningWaker(); + pending_call->pollent = GetContext(); + grpc_polling_entity_add_to_pollset_set( + pending_call->pollent, grpc_polling_entity_pollset_set(&pollent_)); + pending_call->md = std::move(initial_metadata); + pending_calls_.insert(pending_call); + // Start a new fetch if needed. + if (!absl::holds_alternative>(token_)) { + token_ = token_fetcher_->FetchToken( + &pollent_, /*deadline=*/Timestamp::Now() + kTokenRefreshDuration, + [self = RefAsSubclass()]( + absl::StatusOr> token) mutable { + self->TokenFetchComplete(std::move(token)); + }); + } + } + return [pending_call = std::move(pending_call)]() + -> Poll> { + if (!pending_call->done.load(std::memory_order_acquire)) { + return Pending{}; + } + if (!pending_call->result.ok()) { + return pending_call->result.status(); + } + (*pending_call->result)->AddTokenToClientInitialMetadata( + *pending_call->md); + return std::move(pending_call->md); + }; +} + +void TokenFetcherCredentials::TokenFetchComplete( + absl::StatusOr> token) { + // Update cache and grab list of pending requests. + absl::flat_hash_set> pending_calls; + { + MutexLock lock(&mu_); + token_ = token.value_or(nullptr); + pending_calls_.swap(pending_calls); + } + // Invoke callbacks for all pending requests. + for (auto& pending_call : pending_calls) { + pending_call->result = token; + pending_call->done.store(true, std::memory_order_release); + pending_call->waker.Wakeup(); + grpc_polling_entity_del_from_pollset_set( + pending_call->pollent, grpc_polling_entity_pollset_set(&pollent_)); + } +} + +} // namespace grpc_core diff --git a/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h new file mode 100644 index 00000000000..ef2da06a2c0 --- /dev/null +++ b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h @@ -0,0 +1,109 @@ +// +// Copyright 2016 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef GRPC_SRC_CORE_LIB_SECURITY_CREDENTIALS_TOKEN_FETCHER_TOKEN_FETCHER_CREDENTIALS_H +#define GRPC_SRC_CORE_LIB_SECURITY_CREDENTIALS_TOKEN_FETCHER_TOKEN_FETCHER_CREDENTIALS_H + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/types/variant.h" + +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/gprpp/time.h" +#include "src/core/lib/iomgr/polling_entity.h" +#include "src/core/lib/promise/arena_promise.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/transport/metadata.h" +#include "src/core/util/http_client/httpcli.h" +#include "src/core/util/useful.h" + +namespace grpc_core { + +// A base class for credentials that fetch tokens via an HTTP request. +class TokenFetcherCredentials : public grpc_call_credentials { + public: + // Represents a token. + class Token : public RefCounted { + public: + virtual ~Token() = default; + + // Returns the token's expiration time. + virtual Timestamp ExpirationTime() = 0; + + // Adds the token to the call's client initial metadata. + virtual void AddTokenToClientInitialMetadata(ClientMetadata& metadata) = 0; + }; + + // A token fetcher interface. + class TokenFetcher { + public: + virtual ~TokenFetcher() = default; + + // Fetches a token. The on_done callback will be invoked when complete. + // The fetch may be cancelled by orphaning the returned HttpRequest. + virtual OrphanablePtr FetchToken( + grpc_polling_entity* pollent, Timestamp deadline, + absl::AnyInvocable>)> + on_done) = 0; + }; + + ~TokenFetcherCredentials() override; + + ArenaPromise> + GetRequestMetadata(ClientMetadataHandle initial_metadata, + const GetRequestMetadataArgs* args) override; + + protected: + explicit TokenFetcherCredentials(std::unique_ptr token_fetcher); + + private: + // A call that is waiting for a token fetch request to complete. + struct PendingCall : public RefCounted { + std::atomic done{false}; + Waker waker; + grpc_polling_entity* pollent; + ClientMetadataHandle md; + absl::StatusOr> result; + }; + + int cmp_impl(const grpc_call_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return QsortCompare(static_cast(this), other); + } + + void TokenFetchComplete(absl::StatusOr> token); + + const std::unique_ptr token_fetcher_; + + Mutex mu_; + absl::variant, OrphanablePtr> token_ + ABSL_GUARDED_BY(&mu_); + absl::flat_hash_set> pending_calls_ + ABSL_GUARDED_BY(&mu_); + grpc_polling_entity pollent_ ABSL_GUARDED_BY(&mu_); +}; + +} // namespace grpc_core + +#endif // GRPC_SRC_CORE_LIB_SECURITY_CREDENTIALS_TOKEN_FETCHER_TOKEN_FETCHER_CREDENTIALS_H