diff --git a/src/core/lib/channel/context.h b/src/core/lib/channel/context.h index cd3dadf30e2..4e3cdae4bfd 100644 --- a/src/core/lib/channel/context.h +++ b/src/core/lib/channel/context.h @@ -32,10 +32,6 @@ typedef enum { /// grpc_call* associated with this context. GRPC_CONTEXT_CALL = 0, - /// Value is either a \a grpc_client_security_context or a - /// \a grpc_server_security_context. - GRPC_CONTEXT_SECURITY, - /// Value is a \a census_context. GRPC_CONTEXT_TRACING, diff --git a/src/core/lib/security/context/security_context.cc b/src/core/lib/security/context/security_context.cc index 0fc65da3f42..13ccb6f48b7 100644 --- a/src/core/lib/security/context/security_context.cc +++ b/src/core/lib/security/context/security_context.cc @@ -56,12 +56,12 @@ grpc_call_error grpc_call_set_credentials(grpc_call* call, LOG(ERROR) << "Method is client-side only."; return GRPC_CALL_ERROR_NOT_ON_SERVER; } - ctx = static_cast( - grpc_call_context_get(call, GRPC_CONTEXT_SECURITY)); + auto* arena = grpc_call_get_arena(call); + ctx = grpc_core::DownCast( + arena->GetContext()); if (ctx == nullptr) { - ctx = grpc_client_security_context_create(grpc_call_get_arena(call), creds); - grpc_call_context_set(call, GRPC_CONTEXT_SECURITY, ctx, - grpc_client_security_context_destroy); + ctx = grpc_client_security_context_create(arena, creds); + arena->SetContext(ctx); } else { ctx->creds = creds != nullptr ? creds->Ref() : nullptr; } @@ -70,11 +70,12 @@ grpc_call_error grpc_call_set_credentials(grpc_call* call, } grpc_auth_context* grpc_call_auth_context(grpc_call* call) { - void* sec_ctx = grpc_call_context_get(call, GRPC_CONTEXT_SECURITY); + auto* sec_ctx = + grpc_call_get_arena(call)->GetContext(); GRPC_API_TRACE("grpc_call_auth_context(call=%p)", 1, (call)); if (sec_ctx == nullptr) return nullptr; if (grpc_call_is_client(call)) { - auto* sc = static_cast(sec_ctx); + auto* sc = grpc_core::DownCast(sec_ctx); if (sc->auth_context == nullptr) { return nullptr; } else { @@ -83,7 +84,7 @@ grpc_auth_context* grpc_call_auth_context(grpc_call* call) { .release(); } } else { - auto* sc = static_cast(sec_ctx); + auto* sc = grpc_core::DownCast(sec_ctx); if (sc->auth_context == nullptr) { return nullptr; } else { diff --git a/src/core/lib/security/context/security_context.h b/src/core/lib/security/context/security_context.h index f07b5b7cc19..b583c4b36ab 100644 --- a/src/core/lib/security/context/security_context.h +++ b/src/core/lib/security/context/security_context.h @@ -136,15 +136,24 @@ struct grpc_security_context_extension { void (*destroy)(void*) = nullptr; }; +namespace grpc_core { + +class SecurityContext { + public: + virtual ~SecurityContext() = default; +}; + +} // namespace grpc_core + // --- grpc_client_security_context --- // Internal client-side security context. -struct grpc_client_security_context { +struct grpc_client_security_context final : public grpc_core::SecurityContext { explicit grpc_client_security_context( grpc_core::RefCountedPtr creds) : creds(std::move(creds)) {} - ~grpc_client_security_context(); + ~grpc_client_security_context() override; grpc_core::RefCountedPtr creds; grpc_core::RefCountedPtr auth_context; @@ -159,9 +168,9 @@ void grpc_client_security_context_destroy(void* ctx); // Internal server-side security context. -struct grpc_server_security_context { +struct grpc_server_security_context final : public grpc_core::SecurityContext { grpc_server_security_context() = default; - ~grpc_server_security_context(); + ~grpc_server_security_context() override; grpc_core::RefCountedPtr auth_context; grpc_security_context_extension extension; @@ -178,4 +187,20 @@ grpc_auth_context* grpc_auth_context_from_arg(const grpc_arg* arg); grpc_auth_context* grpc_find_auth_context_in_args( const grpc_channel_args* args); +namespace grpc_core { +template <> +struct ArenaContextType { + static void Destroy(SecurityContext* p) { p->~SecurityContext(); } +}; + +template <> +struct ContextSubclass { + using Base = SecurityContext; +}; +template <> +struct ContextSubclass { + using Base = SecurityContext; +}; +} // namespace grpc_core + #endif // GRPC_SRC_CORE_LIB_SECURITY_CONTEXT_SECURITY_CONTEXT_H diff --git a/src/core/lib/security/transport/client_auth_filter.cc b/src/core/lib/security/transport/client_auth_filter.cc index 1314b7d523b..be6fcdae3d8 100644 --- a/src/core/lib/security/transport/client_auth_filter.cc +++ b/src/core/lib/security/transport/client_auth_filter.cc @@ -110,8 +110,7 @@ ClientAuthFilter::ClientAuthFilter( ArenaPromise> ClientAuthFilter::GetCallCredsMetadata( CallArgs call_args) { - auto* ctx = static_cast( - GetContext()[GRPC_CONTEXT_SECURITY].value); + auto* ctx = GetContext(); grpc_call_credentials* channel_call_creds = args_.security_connector->mutable_request_metadata_creds(); const bool call_creds_has_md = (ctx != nullptr) && (ctx->creds != nullptr); @@ -178,17 +177,13 @@ ArenaPromise> ClientAuthFilter::GetCallCredsMetadata( ArenaPromise ClientAuthFilter::MakeCallPromise( CallArgs call_args, NextPromiseFactory next_promise_factory) { - auto* legacy_ctx = GetContext(); - if (legacy_ctx[GRPC_CONTEXT_SECURITY].value == nullptr) { - legacy_ctx[GRPC_CONTEXT_SECURITY].value = - grpc_client_security_context_create(GetContext(), - /*creds=*/nullptr); - legacy_ctx[GRPC_CONTEXT_SECURITY].destroy = - grpc_client_security_context_destroy; + auto* sec_ctx = MaybeGetContext(); + if (sec_ctx == nullptr) { + sec_ctx = grpc_client_security_context_create(GetContext(), + /*creds=*/nullptr); + SetContext(sec_ctx); } - static_cast( - legacy_ctx[GRPC_CONTEXT_SECURITY].value) - ->auth_context = args_.auth_context; + sec_ctx->auth_context = args_.auth_context; auto* host = call_args.client_initial_metadata->get_pointer(HttpAuthorityMetadata()); diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc index bfbfeb2d8ae..bf7aa2560c2 100644 --- a/src/core/lib/security/transport/server_auth_filter.cc +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -203,11 +203,7 @@ ServerAuthFilter::Call::Call(ServerAuthFilter* filter) { grpc_server_security_context_create(GetContext()); server_ctx->auth_context = filter->auth_context_->Ref(DEBUG_LOCATION, "server_auth_filter"); - grpc_call_context_element& context = - GetContext()[GRPC_CONTEXT_SECURITY]; - if (context.value != nullptr) context.destroy(context.value); - context.value = server_ctx; - context.destroy = grpc_server_security_context_destroy; + SetContext(server_ctx); } ServerAuthFilter::ServerAuthFilter(