pull/35296/head
Craig Tiller 1 year ago
parent 41bc454beb
commit e3836bc027
  1. 9
      src/core/lib/security/transport/auth_filters.h
  2. 149
      src/core/lib/security/transport/server_auth_filter.cc

@ -83,6 +83,15 @@ class ServerAuthFilter final : public ImplementChannelFilter<ServerAuthFilter> {
Poll<absl::Status> operator()();
private:
// Called from application code.
static void OnMdProcessingDone(void* user_data,
const grpc_metadata* consumed_md,
size_t num_consumed_md,
const grpc_metadata* response_md,
size_t num_response_md,
grpc_status_code status,
const char* error_details);
struct State;
State* state_;
};

@ -66,6 +66,12 @@ const grpc_channel_filter ServerAuthFilter::kFilter =
MakePromiseBasedFilter<ServerAuthFilter, FilterEndpoint::kServer>(
"server-auth");
const NoInterceptor ServerAuthFilter::Call::OnClientToServerMessage;
const NoInterceptor ServerAuthFilter::Call::OnServerToClientMessage;
const NoInterceptor ServerAuthFilter::Call::OnServerInitialMetadata;
const NoInterceptor ServerAuthFilter::Call::OnServerTrailingMetadata;
const NoInterceptor ServerAuthFilter::Call::OnFinalize;
namespace {
class ArrayEncoder {
@ -114,98 +120,79 @@ grpc_metadata_array MetadataBatchToMetadataArray(
} // namespace
#if 0
class ServerAuthFilter::RunApplicationCode {
public:
// TODO(ctiller): Allocate state_ into a pool on the arena to reuse this
// memory later
RunApplicationCode(ServerAuthFilter* filter, CallArgs call_args)
: state_(GetContext<Arena>()->ManagedNew<State>(std::move(call_args))) {
if (grpc_call_trace.enabled()) {
gpr_log(GPR_ERROR,
"%s[server-auth]: Delegate to application: filter=%p this=%p "
"auth_ctx=%p",
Activity::current()->DebugTag().c_str(), filter, this,
filter->auth_context_.get());
}
filter->server_credentials_->auth_metadata_processor().process(
filter->server_credentials_->auth_metadata_processor().state,
filter->auth_context_.get(), state_->md.metadata, state_->md.count,
OnMdProcessingDone, state_);
}
struct ServerAuthFilter::RunApplicationCode::State {
explicit State(ClientMetadata& client_metadata)
: client_metadata(&client_metadata) {}
Waker waker{Activity::current()->MakeOwningWaker()};
absl::StatusOr<ClientMetadata*> client_metadata;
grpc_metadata_array md = MetadataBatchToMetadataArray(*client_metadata);
std::atomic<bool> done{false};
};
RunApplicationCode(const RunApplicationCode&) = delete;
RunApplicationCode& operator=(const RunApplicationCode&) = delete;
RunApplicationCode(RunApplicationCode&& other) noexcept
: state_(std::exchange(other.state_, nullptr)) {}
RunApplicationCode& operator=(RunApplicationCode&& other) noexcept {
state_ = std::exchange(other.state_, nullptr);
return *this;
ServerAuthFilter::RunApplicationCode::RunApplicationCode(
ServerAuthFilter* filter, ClientMetadata& metadata)
: state_(GetContext<Arena>()->ManagedNew<State>(metadata)) {
if (grpc_call_trace.enabled()) {
gpr_log(GPR_ERROR,
"%s[server-auth]: Delegate to application: filter=%p this=%p "
"auth_ctx=%p",
Activity::current()->DebugTag().c_str(), filter, this,
filter->auth_context_.get());
}
filter->server_credentials_->auth_metadata_processor().process(
filter->server_credentials_->auth_metadata_processor().state,
filter->auth_context_.get(), state_->md.metadata, state_->md.count,
OnMdProcessingDone, state_);
}
Poll<absl::StatusOr<CallArgs>> operator()() {
if (state_->done.load(std::memory_order_acquire)) {
return Poll<absl::StatusOr<CallArgs>>(std::move(state_->call_args));
}
return Pending{};
Poll<absl::Status> ServerAuthFilter::RunApplicationCode::operator()() {
if (state_->done.load(std::memory_order_acquire)) {
return Poll<absl::Status>(std::move(state_->client_metadata).status());
}
return Pending{};
}
private:
struct State {
explicit State(CallArgs call_args) : call_args(std::move(call_args)) {}
Waker waker{Activity::current()->MakeOwningWaker()};
absl::StatusOr<CallArgs> call_args;
grpc_metadata_array md =
MetadataBatchToMetadataArray(call_args->client_initial_metadata.get());
std::atomic<bool> done{false};
};
// Called from application code.
static void OnMdProcessingDone(
void* user_data, const grpc_metadata* consumed_md, size_t num_consumed_md,
const grpc_metadata* response_md, size_t num_response_md,
grpc_status_code status, const char* error_details) {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
auto* state = static_cast<State*>(user_data);
// TODO(ZhenLian): Implement support for response_md.
if (response_md != nullptr && num_response_md > 0) {
gpr_log(GPR_ERROR,
"response_md in auth metadata processing not supported for now. "
"Ignoring...");
}
void ServerAuthFilter::RunApplicationCode::OnMdProcessingDone(
void* user_data, const grpc_metadata* consumed_md, size_t num_consumed_md,
const grpc_metadata* response_md, size_t num_response_md,
grpc_status_code status, const char* error_details) {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
if (status == GRPC_STATUS_OK) {
ClientMetadataHandle& md = state->call_args->client_initial_metadata;
for (size_t i = 0; i < num_consumed_md; i++) {
md->Remove(StringViewFromSlice(consumed_md[i].key));
}
} else {
if (error_details == nullptr) {
error_details = "Authentication metadata processing failed.";
}
state->call_args = grpc_error_set_int(
absl::Status(static_cast<absl::StatusCode>(status), error_details),
StatusIntProperty::kRpcStatus, status);
}
auto* state = static_cast<State*>(user_data);
// TODO(ZhenLian): Implement support for response_md.
if (response_md != nullptr && num_response_md > 0) {
gpr_log(GPR_ERROR,
"response_md in auth metadata processing not supported for now. "
"Ignoring...");
}
// Clean up.
for (size_t i = 0; i < state->md.count; i++) {
CSliceUnref(state->md.metadata[i].key);
CSliceUnref(state->md.metadata[i].value);
if (status == GRPC_STATUS_OK) {
ClientMetadata& md = **state->client_metadata;
for (size_t i = 0; i < num_consumed_md; i++) {
md.Remove(StringViewFromSlice(consumed_md[i].key));
}
grpc_metadata_array_destroy(&state->md);
} else {
if (error_details == nullptr) {
error_details = "Authentication metadata processing failed.";
}
state->client_metadata = grpc_error_set_int(
absl::Status(static_cast<absl::StatusCode>(status), error_details),
StatusIntProperty::kRpcStatus, status);
}
auto waker = std::move(state->waker);
state->done.store(true, std::memory_order_release);
waker.Wakeup();
// Clean up.
for (size_t i = 0; i < state->md.count; i++) {
CSliceUnref(state->md.metadata[i].key);
CSliceUnref(state->md.metadata[i].value);
}
grpc_metadata_array_destroy(&state->md);
State* state_;
};
#endif
auto waker = std::move(state->waker);
state->done.store(true, std::memory_order_release);
waker.Wakeup();
}
ServerAuthFilter::Call::Call(ServerAuthFilter* filter) {
// Create server security context. Set its auth context from channel

Loading…
Cancel
Save