[promises] Convert server auth filter (#32094)

* [promises] Convert server auth filter

* Automated change: Fix sanity tests

* fix compile

* fix

* fix

* fix

* Automated change: Fix sanity tests

* implement-wakeup

* fix-error

* Automated change: Fix sanity tests

* fix-forwarding

---------

Co-authored-by: ctiller <ctiller@users.noreply.github.com>
pull/32250/head
Craig Tiller 2 years ago committed by GitHub
parent 125141c9d2
commit d769da7229
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 34
      src/core/lib/channel/promise_based_filter.cc
  2. 26
      src/core/lib/security/transport/auth_filters.h
  3. 371
      src/core/lib/security/transport/server_auth_filter.cc
  4. 2
      src/core/lib/surface/init.cc

@ -20,6 +20,7 @@
#include <initializer_list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/base/attributes.h"
@ -49,14 +50,21 @@ namespace promise_filter_detail {
namespace {
class FakeActivity final : public Activity {
public:
explicit FakeActivity(Activity* wake_activity)
: wake_activity_(wake_activity) {}
void Orphan() override {}
void ForceImmediateRepoll() override {}
Waker MakeOwningWaker() override { abort(); }
Waker MakeNonOwningWaker() override { abort(); }
Waker MakeOwningWaker() override { return wake_activity_->MakeOwningWaker(); }
Waker MakeNonOwningWaker() override {
return wake_activity_->MakeNonOwningWaker();
}
void Run(absl::FunctionRef<void()> f) {
ScopedActivity activity(this);
f();
}
private:
Activity* const wake_activity_;
};
absl::Status StatusFromMetadata(const ServerMetadata& md) {
@ -103,7 +111,7 @@ BaseCallData::BaseCallData(
}
BaseCallData::~BaseCallData() {
FakeActivity().Run([this] {
FakeActivity(this).Run([this] {
if (send_message_ != nullptr) {
send_message_->~SendMessage();
}
@ -2279,7 +2287,7 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) {
ScopedContext context(this);
// Construct the promise.
ChannelFilter* filter = static_cast<ChannelFilter*>(elem()->channel_data);
FakeActivity().Run([this, filter] {
FakeActivity(this).Run([this, filter] {
promise_ = filter->MakeCallPromise(
CallArgs{WrapMetadata(recv_initial_metadata_),
server_initial_metadata_pipe() == nullptr
@ -2297,11 +2305,6 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) {
});
// Poll once.
WakeInsideCombiner(&flusher);
if (auto* closure =
std::exchange(original_recv_initial_metadata_ready_, nullptr)) {
flusher.AddClosure(closure, absl::OkStatus(),
"original_recv_initial_metadata");
}
}
std::string ServerCallData::DebugString() const {
@ -2481,9 +2484,20 @@ void ServerCallData::WakeInsideCombiner(Flusher* flusher) {
}
}
}
if (std::exchange(forward_recv_initial_metadata_callback_, false)) {
if (auto* closure =
std::exchange(original_recv_initial_metadata_ready_, nullptr)) {
flusher->AddClosure(closure, absl::OkStatus(),
"original_recv_initial_metadata");
}
}
}
void ServerCallData::OnWakeup() { abort(); } // not implemented
void ServerCallData::OnWakeup() {
Flusher flusher(this);
ScopedContext context(this);
WakeInsideCombiner(&flusher);
}
} // namespace promise_filter_detail
} // namespace grpc_core

@ -36,8 +36,6 @@
#include "src/core/lib/security/security_connector/security_connector.h"
#include "src/core/lib/transport/transport.h"
extern const grpc_channel_filter grpc_server_auth_filter;
namespace grpc_core {
// Handles calling out to credentials to fill in metadata per call.
@ -64,6 +62,30 @@ class ClientAuthFilter final : public ChannelFilter {
grpc_call_credentials::GetRequestMetadataArgs args_;
};
class ServerAuthFilter final : public ChannelFilter {
public:
static const grpc_channel_filter kFilter;
static absl::StatusOr<ServerAuthFilter> Create(const ChannelArgs& args,
ChannelFilter::Args);
// Construct a promise for one call.
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;
private:
ServerAuthFilter(RefCountedPtr<grpc_server_credentials> server_credentials,
RefCountedPtr<grpc_auth_context> auth_context);
class RunApplicationCode;
ArenaPromise<absl::StatusOr<CallArgs>> GetCallCredsMetadata(
CallArgs call_args);
RefCountedPtr<grpc_server_credentials> server_credentials_;
RefCountedPtr<grpc_auth_context> auth_context_;
};
} // namespace grpc_core
// Exposed for testing purposes only.

@ -21,28 +21,37 @@
#include <string.h>
#include <algorithm>
#include <new>
#include <atomic>
#include <functional>
#include <memory>
#include <utility>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include <grpc/grpc.h>
#include <grpc/grpc_security.h>
#include <grpc/slice.h>
#include <grpc/status.h>
#include <grpc/support/alloc.h>
#include <grpc/support/atm.h>
#include <grpc/support/log.h>
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/channel_fwd.h"
#include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/channel/context.h"
#include "src/core/lib/channel/promise_based_filter.h"
#include "src/core/lib/gprpp/debug_location.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/status_helper.h"
#include "src/core/lib/iomgr/call_combiner.h"
#include "src/core/lib/iomgr/closure.h"
#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/promise/try_seq.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/transport/auth_filters.h" // IWYU pragma: keep
@ -51,88 +60,33 @@
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/transport.h"
static void recv_initial_metadata_ready(void* arg, grpc_error_handle error);
static void recv_trailing_metadata_ready(void* user_data,
grpc_error_handle error);
namespace grpc_core {
namespace {
enum async_state {
STATE_INIT = 0,
STATE_DONE,
STATE_CANCELLED,
};
struct channel_data {
channel_data(grpc_auth_context* auth_context, grpc_server_credentials* creds)
: auth_context(auth_context->Ref()), creds(creds->Ref()) {}
~channel_data() { auth_context.reset(DEBUG_LOCATION, "server_auth_filter"); }
grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
grpc_core::RefCountedPtr<grpc_server_credentials> creds;
};
const grpc_channel_filter ServerAuthFilter::kFilter =
MakePromiseBasedFilter<ServerAuthFilter, FilterEndpoint::kServer>(
"server-auth");
struct call_data {
call_data(grpc_call_element* elem, const grpc_call_element_args& args)
: call_combiner(args.call_combiner), owning_call(args.call_stack) {
GRPC_CLOSURE_INIT(&recv_initial_metadata_ready,
::recv_initial_metadata_ready, elem,
grpc_schedule_on_exec_ctx);
GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready,
::recv_trailing_metadata_ready, elem,
grpc_schedule_on_exec_ctx);
// Create server security context. Set its auth context from channel
// data and save it in the call context.
grpc_server_security_context* server_ctx =
grpc_server_security_context_create(args.arena);
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
server_ctx->auth_context =
chand->auth_context->Ref(DEBUG_LOCATION, "server_auth_filter");
if (args.context[GRPC_CONTEXT_SECURITY].value != nullptr) {
args.context[GRPC_CONTEXT_SECURITY].destroy(
args.context[GRPC_CONTEXT_SECURITY].value);
}
args.context[GRPC_CONTEXT_SECURITY].value = server_ctx;
args.context[GRPC_CONTEXT_SECURITY].destroy =
grpc_server_security_context_destroy;
}
~call_data() {}
grpc_core::CallCombiner* call_combiner;
grpc_call_stack* owning_call;
grpc_transport_stream_op_batch* recv_initial_metadata_batch;
grpc_closure* original_recv_initial_metadata_ready;
grpc_closure recv_initial_metadata_ready;
grpc_error_handle recv_initial_metadata_error;
grpc_closure recv_trailing_metadata_ready;
grpc_closure* original_recv_trailing_metadata_ready;
grpc_error_handle recv_trailing_metadata_error;
bool seen_recv_trailing_metadata_ready = false;
grpc_metadata_array md;
grpc_closure cancel_closure;
gpr_atm state = STATE_INIT; // async_state
};
namespace {
class ArrayEncoder {
public:
explicit ArrayEncoder(grpc_metadata_array* result) : result_(result) {}
void Encode(const grpc_core::Slice& key, const grpc_core::Slice& value) {
void Encode(const Slice& key, const Slice& value) {
Append(key.Ref(), value.Ref());
}
template <typename Which>
void Encode(Which, const typename Which::ValueType& value) {
Append(grpc_core::Slice(
grpc_core::StaticSlice::FromStaticString(Which::key())),
grpc_core::Slice(Which::Encode(value)));
Append(Slice(StaticSlice::FromStaticString(Which::key())),
Slice(Which::Encode(value)));
}
void Encode(grpc_core::HttpMethodMetadata,
const typename grpc_core::HttpMethodMetadata::ValueType&) {}
void Encode(HttpMethodMetadata,
const typename HttpMethodMetadata::ValueType&) {}
private:
void Append(grpc_core::Slice key, grpc_core::Slice value) {
void Append(Slice key, Slice value) {
if (result_->count == result_->capacity) {
result_->capacity =
std::max(result_->capacity + 8, result_->capacity * 2);
@ -147,9 +101,9 @@ class ArrayEncoder {
grpc_metadata_array* result_;
};
} // namespace
static grpc_metadata_array metadata_batch_to_md_array(
// TODO(ctiller): seek out all users of this functionality and change API so
// that this unilateral format conversion IS NOT REQUIRED.
grpc_metadata_array MetadataBatchToMetadataArray(
const grpc_metadata_batch* batch) {
grpc_metadata_array result;
grpc_metadata_array_init(&result);
@ -158,202 +112,117 @@ static grpc_metadata_array metadata_batch_to_md_array(
return result;
}
static void on_md_processing_done_inner(grpc_call_element* elem,
const grpc_metadata* consumed_md,
size_t num_consumed_md,
const grpc_metadata* response_md,
size_t num_response_md,
grpc_error_handle error) {
call_data* calld = static_cast<call_data*>(elem->call_data);
grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch;
// TODO(ZhenLian): Implement support for response_md.
if (response_md != nullptr && num_response_md > 0) {
gpr_log(GPR_ERROR,
"response_md in auth metadata processing not supported for now. "
"Ignoring...");
} // namespace
class ServerAuthFilter::RunApplicationCode {
public:
// TODO(ctiller): Allocate state_ into a pool on the arena to reuse this
// memory later
RunApplicationCode(ServerAuthFilter* filter, CallArgs call_args)
: state_(GetContext<Arena>()->ManagedNew<State>(std::move(call_args))) {
filter->server_credentials_->auth_metadata_processor().process(
filter->server_credentials_->auth_metadata_processor().state,
filter->auth_context_.get(), state_->md.metadata, state_->md.count,
OnMdProcessingDone, state_);
}
if (error.ok()) {
for (size_t i = 0; i < num_consumed_md; i++) {
batch->payload->recv_initial_metadata.recv_initial_metadata->Remove(
grpc_core::StringViewFromSlice(consumed_md[i].key));
Poll<absl::StatusOr<CallArgs>> operator()() {
if (state_->done.load(std::memory_order_acquire)) {
return Poll<absl::StatusOr<CallArgs>>(std::move(state_->call_args));
}
return Pending{};
}
calld->recv_initial_metadata_error = error;
grpc_closure* closure = calld->original_recv_initial_metadata_ready;
calld->original_recv_initial_metadata_ready = nullptr;
if (calld->seen_recv_trailing_metadata_ready) {
GRPC_CALL_COMBINER_START(calld->call_combiner,
&calld->recv_trailing_metadata_ready,
calld->recv_trailing_metadata_error,
"continue recv_trailing_metadata_ready");
}
grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, error);
}
// Called from application code.
static void on_md_processing_done(
void* user_data, const grpc_metadata* consumed_md, size_t num_consumed_md,
const grpc_metadata* response_md, size_t num_response_md,
grpc_status_code status, const char* error_details) {
grpc_call_element* elem = static_cast<grpc_call_element*>(user_data);
call_data* calld = static_cast<call_data*>(elem->call_data);
grpc_core::ApplicationCallbackExecCtx callback_exec_ctx;
grpc_core::ExecCtx exec_ctx;
// If the call was not cancelled while we were in flight, process the result.
if (gpr_atm_full_cas(&calld->state, static_cast<gpr_atm>(STATE_INIT),
static_cast<gpr_atm>(STATE_DONE))) {
grpc_error_handle error;
if (status != GRPC_STATUS_OK) {
private:
struct State {
explicit State(CallArgs call_args) : call_args(std::move(call_args)) {}
Waker waker{Activity::current()->MakeOwningWaker()};
absl::StatusOr<CallArgs> call_args;
grpc_metadata_array md =
MetadataBatchToMetadataArray(call_args->client_initial_metadata.get());
std::atomic<bool> done{false};
};
// Called from application code.
static void OnMdProcessingDone(
void* user_data, const grpc_metadata* consumed_md, size_t num_consumed_md,
const grpc_metadata* response_md, size_t num_response_md,
grpc_status_code status, const char* error_details) {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
auto* state = static_cast<State*>(user_data);
// TODO(ZhenLian): Implement support for response_md.
if (response_md != nullptr && num_response_md > 0) {
gpr_log(GPR_ERROR,
"response_md in auth metadata processing not supported for now. "
"Ignoring...");
}
if (status == GRPC_STATUS_OK) {
ClientMetadataHandle& md = state->call_args->client_initial_metadata;
for (size_t i = 0; i < num_consumed_md; i++) {
md->Remove(StringViewFromSlice(consumed_md[i].key));
}
} else {
if (error_details == nullptr) {
error_details = "Authentication metadata processing failed.";
}
error =
grpc_error_set_int(GRPC_ERROR_CREATE(error_details),
grpc_core::StatusIntProperty::kRpcStatus, status);
state->call_args = grpc_error_set_int(
absl::Status(static_cast<absl::StatusCode>(status), error_details),
StatusIntProperty::kRpcStatus, status);
}
on_md_processing_done_inner(elem, consumed_md, num_consumed_md, response_md,
num_response_md, error);
}
// Clean up.
for (size_t i = 0; i < calld->md.count; i++) {
grpc_core::CSliceUnref(calld->md.metadata[i].key);
grpc_core::CSliceUnref(calld->md.metadata[i].value);
}
grpc_metadata_array_destroy(&calld->md);
GRPC_CALL_STACK_UNREF(calld->owning_call, "server_auth_metadata");
}
static void cancel_call(void* arg, grpc_error_handle error) {
grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
call_data* calld = static_cast<call_data*>(elem->call_data);
// If the result was not already processed, invoke the callback now.
if (!error.ok() &&
gpr_atm_full_cas(&calld->state, static_cast<gpr_atm>(STATE_INIT),
static_cast<gpr_atm>(STATE_CANCELLED))) {
on_md_processing_done_inner(elem, nullptr, 0, nullptr, 0, error);
}
GRPC_CALL_STACK_UNREF(calld->owning_call, "cancel_call");
}
static void recv_initial_metadata_ready(void* arg, grpc_error_handle error) {
grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
call_data* calld = static_cast<call_data*>(elem->call_data);
grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch;
if (error.ok()) {
if (chand->creds != nullptr &&
chand->creds->auth_metadata_processor().process != nullptr) {
// We're calling out to the application, so we need to make sure
// to drop the call combiner early if we get cancelled.
// TODO(yashykt): We would not need this ref if call combiners used
// Closure::Run() instead of ExecCtx::Run()
GRPC_CALL_STACK_REF(calld->owning_call, "cancel_call");
GRPC_CLOSURE_INIT(&calld->cancel_closure, cancel_call, elem,
grpc_schedule_on_exec_ctx);
calld->call_combiner->SetNotifyOnCancel(&calld->cancel_closure);
GRPC_CALL_STACK_REF(calld->owning_call, "server_auth_metadata");
calld->md = metadata_batch_to_md_array(
batch->payload->recv_initial_metadata.recv_initial_metadata);
chand->creds->auth_metadata_processor().process(
chand->creds->auth_metadata_processor().state,
chand->auth_context.get(), calld->md.metadata, calld->md.count,
on_md_processing_done, elem);
return;
// Clean up.
for (size_t i = 0; i < state->md.count; i++) {
CSliceUnref(state->md.metadata[i].key);
CSliceUnref(state->md.metadata[i].value);
}
}
grpc_closure* closure = calld->original_recv_initial_metadata_ready;
calld->original_recv_initial_metadata_ready = nullptr;
if (calld->seen_recv_trailing_metadata_ready) {
GRPC_CALL_COMBINER_START(calld->call_combiner,
&calld->recv_trailing_metadata_ready,
calld->recv_trailing_metadata_error,
"continue recv_trailing_metadata_ready");
}
grpc_core::Closure::Run(DEBUG_LOCATION, closure, error);
}
grpc_metadata_array_destroy(&state->md);
static void recv_trailing_metadata_ready(void* user_data,
grpc_error_handle err) {
grpc_call_element* elem = static_cast<grpc_call_element*>(user_data);
call_data* calld = static_cast<call_data*>(elem->call_data);
if (calld->original_recv_initial_metadata_ready != nullptr) {
calld->recv_trailing_metadata_error = err;
calld->seen_recv_trailing_metadata_ready = true;
GRPC_CALL_COMBINER_STOP(calld->call_combiner,
"deferring recv_trailing_metadata_ready until "
"after recv_initial_metadata_ready");
return;
auto waker = std::move(state->waker);
state->done.store(true, std::memory_order_release);
waker.Wakeup();
}
err = grpc_error_add_child(err, calld->recv_initial_metadata_error);
grpc_core::Closure::Run(DEBUG_LOCATION,
calld->original_recv_trailing_metadata_ready, err);
}
static void server_auth_start_transport_stream_op_batch(
grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
call_data* calld = static_cast<call_data*>(elem->call_data);
if (batch->recv_initial_metadata) {
// Inject our callback.
calld->recv_initial_metadata_batch = batch;
calld->original_recv_initial_metadata_ready =
batch->payload->recv_initial_metadata.recv_initial_metadata_ready;
batch->payload->recv_initial_metadata.recv_initial_metadata_ready =
&calld->recv_initial_metadata_ready;
}
if (batch->recv_trailing_metadata) {
calld->original_recv_trailing_metadata_ready =
batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready;
batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready =
&calld->recv_trailing_metadata_ready;
State* state_;
};
ArenaPromise<ServerMetadataHandle> ServerAuthFilter::MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
// Create server security context. Set its auth context from channel
// data and save it in the call context.
grpc_server_security_context* server_ctx =
grpc_server_security_context_create(GetContext<Arena>());
server_ctx->auth_context =
auth_context_->Ref(DEBUG_LOCATION, "server_auth_filter");
grpc_call_context_element& context =
GetContext<grpc_call_context_element>()[GRPC_CONTEXT_SECURITY];
if (context.value != nullptr) context.destroy(context.value);
context.value = server_ctx;
context.destroy = grpc_server_security_context_destroy;
if (server_credentials_ == nullptr ||
server_credentials_->auth_metadata_processor().process == nullptr) {
return next_promise_factory(std::move(call_args));
}
grpc_call_next_op(elem, batch);
}
// Constructor for call_data
static grpc_error_handle server_auth_init_call_elem(
grpc_call_element* elem, const grpc_call_element_args* args) {
new (elem->call_data) call_data(elem, *args);
return absl::OkStatus();
return TrySeq(RunApplicationCode(this, std::move(call_args)),
std::move(next_promise_factory));
}
// Destructor for call_data
static void server_auth_destroy_call_elem(
grpc_call_element* elem, const grpc_call_final_info* /*final_info*/,
grpc_closure* /*ignored*/) {
call_data* calld = static_cast<call_data*>(elem->call_data);
calld->~call_data();
}
ServerAuthFilter::ServerAuthFilter(
RefCountedPtr<grpc_server_credentials> server_credentials,
RefCountedPtr<grpc_auth_context> auth_context)
: server_credentials_(server_credentials), auth_context_(auth_context) {}
// Constructor for channel_data
static grpc_error_handle server_auth_init_channel_elem(
grpc_channel_element* elem, grpc_channel_element_args* args) {
GPR_ASSERT(!args->is_last);
grpc_auth_context* auth_context =
grpc_find_auth_context_in_args(args->channel_args);
absl::StatusOr<ServerAuthFilter> ServerAuthFilter::Create(
const ChannelArgs& args, ChannelFilter::Args) {
auto auth_context = args.GetObjectRef<grpc_auth_context>();
GPR_ASSERT(auth_context != nullptr);
grpc_server_credentials* creds =
grpc_find_server_credentials_in_args(args->channel_args);
new (elem->channel_data) channel_data(auth_context, creds);
return absl::OkStatus();
}
// Destructor for channel data
static void server_auth_destroy_channel_elem(grpc_channel_element* elem) {
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
chand->~channel_data();
auto creds = args.GetObjectRef<grpc_server_credentials>();
return ServerAuthFilter(std::move(creds), std::move(auth_context));
}
const grpc_channel_filter grpc_server_auth_filter = {
server_auth_start_transport_stream_op_batch,
nullptr,
grpc_channel_next_op,
sizeof(call_data),
server_auth_init_call_elem,
grpc_call_stack_ignore_set_pollset_or_pollset_set,
server_auth_destroy_call_elem,
sizeof(channel_data),
server_auth_init_channel_elem,
grpc_channel_stack_no_post_init,
server_auth_destroy_channel_elem,
grpc_channel_next_get_info,
"server-auth"};
} // namespace grpc_core

@ -80,7 +80,7 @@ static bool maybe_prepend_client_auth_filter(
static bool maybe_prepend_server_auth_filter(
grpc_core::ChannelStackBuilder* builder) {
if (builder->channel_args().Contains(GRPC_SERVER_CREDENTIALS_ARG)) {
builder->PrependFilter(&grpc_server_auth_filter);
builder->PrependFilter(&grpc_core::ServerAuthFilter::kFilter);
}
return true;
}

Loading…
Cancel
Save