[context] Move security context to arena based context (#36783)

Closes #36783

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/36783 from ctiller:ctx9 2ade190f45
PiperOrigin-RevId: 639096344
pull/36775/head
Craig Tiller 6 months ago committed by Copybara-Service
parent f3220d08d2
commit 21a5bb487c
  1. 4
      src/core/lib/channel/context.h
  2. 17
      src/core/lib/security/context/security_context.cc
  3. 33
      src/core/lib/security/context/security_context.h
  4. 19
      src/core/lib/security/transport/client_auth_filter.cc
  5. 6
      src/core/lib/security/transport/server_auth_filter.cc

@ -32,10 +32,6 @@ typedef enum {
/// grpc_call* associated with this context.
GRPC_CONTEXT_CALL = 0,
/// Value is either a \a grpc_client_security_context or a
/// \a grpc_server_security_context.
GRPC_CONTEXT_SECURITY,
/// Value is a \a census_context.
GRPC_CONTEXT_TRACING,

@ -56,12 +56,12 @@ grpc_call_error grpc_call_set_credentials(grpc_call* call,
LOG(ERROR) << "Method is client-side only.";
return GRPC_CALL_ERROR_NOT_ON_SERVER;
}
ctx = static_cast<grpc_client_security_context*>(
grpc_call_context_get(call, GRPC_CONTEXT_SECURITY));
auto* arena = grpc_call_get_arena(call);
ctx = grpc_core::DownCast<grpc_client_security_context*>(
arena->GetContext<grpc_core::SecurityContext>());
if (ctx == nullptr) {
ctx = grpc_client_security_context_create(grpc_call_get_arena(call), creds);
grpc_call_context_set(call, GRPC_CONTEXT_SECURITY, ctx,
grpc_client_security_context_destroy);
ctx = grpc_client_security_context_create(arena, creds);
arena->SetContext<grpc_core::SecurityContext>(ctx);
} else {
ctx->creds = creds != nullptr ? creds->Ref() : nullptr;
}
@ -70,11 +70,12 @@ grpc_call_error grpc_call_set_credentials(grpc_call* call,
}
grpc_auth_context* grpc_call_auth_context(grpc_call* call) {
void* sec_ctx = grpc_call_context_get(call, GRPC_CONTEXT_SECURITY);
auto* sec_ctx =
grpc_call_get_arena(call)->GetContext<grpc_core::SecurityContext>();
GRPC_API_TRACE("grpc_call_auth_context(call=%p)", 1, (call));
if (sec_ctx == nullptr) return nullptr;
if (grpc_call_is_client(call)) {
auto* sc = static_cast<grpc_client_security_context*>(sec_ctx);
auto* sc = grpc_core::DownCast<grpc_client_security_context*>(sec_ctx);
if (sc->auth_context == nullptr) {
return nullptr;
} else {
@ -83,7 +84,7 @@ grpc_auth_context* grpc_call_auth_context(grpc_call* call) {
.release();
}
} else {
auto* sc = static_cast<grpc_server_security_context*>(sec_ctx);
auto* sc = grpc_core::DownCast<grpc_server_security_context*>(sec_ctx);
if (sc->auth_context == nullptr) {
return nullptr;
} else {

@ -136,15 +136,24 @@ struct grpc_security_context_extension {
void (*destroy)(void*) = nullptr;
};
namespace grpc_core {
class SecurityContext {
public:
virtual ~SecurityContext() = default;
};
} // namespace grpc_core
// --- grpc_client_security_context ---
// Internal client-side security context.
struct grpc_client_security_context {
struct grpc_client_security_context final : public grpc_core::SecurityContext {
explicit grpc_client_security_context(
grpc_core::RefCountedPtr<grpc_call_credentials> creds)
: creds(std::move(creds)) {}
~grpc_client_security_context();
~grpc_client_security_context() override;
grpc_core::RefCountedPtr<grpc_call_credentials> creds;
grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
@ -159,9 +168,9 @@ void grpc_client_security_context_destroy(void* ctx);
// Internal server-side security context.
struct grpc_server_security_context {
struct grpc_server_security_context final : public grpc_core::SecurityContext {
grpc_server_security_context() = default;
~grpc_server_security_context();
~grpc_server_security_context() override;
grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
grpc_security_context_extension extension;
@ -178,4 +187,20 @@ grpc_auth_context* grpc_auth_context_from_arg(const grpc_arg* arg);
grpc_auth_context* grpc_find_auth_context_in_args(
const grpc_channel_args* args);
namespace grpc_core {
template <>
struct ArenaContextType<SecurityContext> {
static void Destroy(SecurityContext* p) { p->~SecurityContext(); }
};
template <>
struct ContextSubclass<grpc_client_security_context> {
using Base = SecurityContext;
};
template <>
struct ContextSubclass<grpc_server_security_context> {
using Base = SecurityContext;
};
} // namespace grpc_core
#endif // GRPC_SRC_CORE_LIB_SECURITY_CONTEXT_SECURITY_CONTEXT_H

@ -110,8 +110,7 @@ ClientAuthFilter::ClientAuthFilter(
ArenaPromise<absl::StatusOr<CallArgs>> ClientAuthFilter::GetCallCredsMetadata(
CallArgs call_args) {
auto* ctx = static_cast<grpc_client_security_context*>(
GetContext<grpc_call_context_element>()[GRPC_CONTEXT_SECURITY].value);
auto* ctx = GetContext<grpc_client_security_context>();
grpc_call_credentials* channel_call_creds =
args_.security_connector->mutable_request_metadata_creds();
const bool call_creds_has_md = (ctx != nullptr) && (ctx->creds != nullptr);
@ -178,17 +177,13 @@ ArenaPromise<absl::StatusOr<CallArgs>> ClientAuthFilter::GetCallCredsMetadata(
ArenaPromise<ServerMetadataHandle> ClientAuthFilter::MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
auto* legacy_ctx = GetContext<grpc_call_context_element>();
if (legacy_ctx[GRPC_CONTEXT_SECURITY].value == nullptr) {
legacy_ctx[GRPC_CONTEXT_SECURITY].value =
grpc_client_security_context_create(GetContext<Arena>(),
/*creds=*/nullptr);
legacy_ctx[GRPC_CONTEXT_SECURITY].destroy =
grpc_client_security_context_destroy;
auto* sec_ctx = MaybeGetContext<grpc_client_security_context>();
if (sec_ctx == nullptr) {
sec_ctx = grpc_client_security_context_create(GetContext<Arena>(),
/*creds=*/nullptr);
SetContext<SecurityContext>(sec_ctx);
}
static_cast<grpc_client_security_context*>(
legacy_ctx[GRPC_CONTEXT_SECURITY].value)
->auth_context = args_.auth_context;
sec_ctx->auth_context = args_.auth_context;
auto* host =
call_args.client_initial_metadata->get_pointer(HttpAuthorityMetadata());

@ -203,11 +203,7 @@ ServerAuthFilter::Call::Call(ServerAuthFilter* filter) {
grpc_server_security_context_create(GetContext<Arena>());
server_ctx->auth_context =
filter->auth_context_->Ref(DEBUG_LOCATION, "server_auth_filter");
grpc_call_context_element& context =
GetContext<grpc_call_context_element>()[GRPC_CONTEXT_SECURITY];
if (context.value != nullptr) context.destroy(context.value);
context.value = server_ctx;
context.destroy = grpc_server_security_context_destroy;
SetContext<SecurityContext>(server_ctx);
}
ServerAuthFilter::ServerAuthFilter(

Loading…
Cancel
Save