Convert promise based filter channel args to new C++ type (#29165)

* begin

* tests

* fix

* http

* Revert "Revert "HTTP Client Filter --> promises (#29031)" (#29181)"

This reverts commit 6ee276f672.

* debug

* minimal reproduction

* debug

* fix state machine for c#

* Revert "minimal reproduction"

This reverts commit 4d02d2e730.

* Revert "debug"

This reverts commit 7960842f48.

* Revert "debug"

This reverts commit a6f224e4a1.

* no-logging

* Revert "Revert "debug""

This reverts commit 951844e857.

* Better int conversion

* debug

* Fix for Cronet

* Revert "debug"

This reverts commit 4d641c4281.

* Revert "Better int conversion"

This reverts commit 4001b957cb.

* Revert "Revert "Revert "debug"""

This reverts commit d135c61043.

* fix, c++ize

* x

Co-authored-by: Jan Tattermusch <jtattermusch@google.com>
pull/29265/head
Craig Tiller 3 years ago committed by GitHub
parent e122c64000
commit 6a13c26cef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      BUILD
  2. 86
      src/core/ext/filters/channel_idle/channel_idle_filter.cc
  3. 63
      src/core/ext/filters/http/client/http_client_filter.cc
  4. 2
      src/core/ext/filters/http/client/http_client_filter.h
  5. 20
      src/core/ext/filters/http/client_authority_filter.cc
  6. 4
      src/core/ext/filters/http/client_authority_filter.h
  7. 4
      src/core/ext/filters/load_reporting/server_load_reporting_filter.cc
  8. 5
      src/core/ext/filters/load_reporting/server_load_reporting_filter.h
  9. 22
      src/core/lib/channel/channel_args.cc
  10. 9
      src/core/lib/channel/channel_args.h
  11. 2
      src/core/lib/channel/channel_stack.h
  12. 5
      src/core/lib/channel/promise_based_filter.h
  13. 3
      src/core/lib/security/authorization/authorization_policy_provider.h
  14. 8
      src/core/lib/security/authorization/grpc_server_authz_filter.cc
  15. 4
      src/core/lib/security/authorization/grpc_server_authz_filter.h
  16. 5
      src/core/lib/security/context/security_context.h
  17. 4
      src/core/lib/security/security_connector/security_connector.h
  18. 2
      src/core/lib/security/transport/auth_filters.h
  19. 8
      src/core/lib/security/transport/client_auth_filter.cc
  20. 2
      src/core/lib/transport/transport.h
  21. 1
      src/core/lib/transport/transport_impl.h
  22. 41
      test/core/filters/client_authority_filter_test.cc

@ -2433,6 +2433,7 @@ grpc_cc_library(
"match",
"ref_counted",
"ref_counted_ptr",
"time",
"useful",
],
)

@ -37,27 +37,22 @@
#include "src/core/lib/promise/try_seq.h"
#include "src/core/lib/transport/http2_errors.h"
// TODO(juanlishen): The idle filter was disabled in client channel by default
namespace grpc_core {
namespace {
// TODO(ctiller): The idle filter was disabled in client channel by default
// due to b/143502997. Now the bug is fixed enable the filter by default.
#define DEFAULT_IDLE_TIMEOUT_MS INT_MAX
// The user input idle timeout smaller than this would be capped to it.
#define MIN_IDLE_TIMEOUT_MS (1 /*second*/ * 1000)
const auto kDefaultIdleTimeout = Duration::Infinity();
// If these settings change, make sure that we are not sending a GOAWAY for
// inproc transport, since a GOAWAY to inproc ends up destroying the transport.
#define DEFAULT_MAX_CONNECTION_AGE_MS INT_MAX
#define DEFAULT_MAX_CONNECTION_AGE_GRACE_MS INT_MAX
#define DEFAULT_MAX_CONNECTION_IDLE_MS INT_MAX
#define MAX_CONNECTION_AGE_JITTER 0.1
#define MAX_CONNECTION_AGE_INTEGER_OPTIONS \
{ DEFAULT_MAX_CONNECTION_AGE_MS, 1, INT_MAX }
#define MAX_CONNECTION_IDLE_INTEGER_OPTIONS \
{ DEFAULT_MAX_CONNECTION_IDLE_MS, 1, INT_MAX }
namespace grpc_core {
const auto kDefaultMaxConnectionAge = Duration::Infinity();
const auto kDefaultMaxConnectionAgeGrace = Duration::Infinity();
const auto kDefaultMaxConnectionIdle = Duration::Infinity();
const auto kMaxConnectionAgeJitter = 0.1;
TraceFlag grpc_trace_client_idle_filter(false, "client_idle_filter");
} // namespace
#define GRPC_IDLE_FILTER_LOG(format, ...) \
do { \
@ -71,13 +66,9 @@ namespace {
using SingleSetActivityPtr =
SingleSetPtr<Activity, typename ActivityPtr::deleter_type>;
Duration GetClientIdleTimeout(const grpc_channel_args* args) {
int ms = std::max(
grpc_channel_arg_get_integer(
grpc_channel_args_find(args, GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS),
{DEFAULT_IDLE_TIMEOUT_MS, 0, INT_MAX}),
MIN_IDLE_TIMEOUT_MS);
return ms == INT_MAX ? Duration::Infinity() : Duration::Milliseconds(ms);
Duration GetClientIdleTimeout(const ChannelArgs& args) {
return args.GetDurationFromIntMillis(GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS)
.value_or(kDefaultIdleTimeout);
}
struct MaxAgeConfig {
@ -95,30 +86,24 @@ struct MaxAgeConfig {
connection storms. Note that the MAX_CONNECTION_AGE option without jitter
would not create connection storms by itself, but if there happened to be a
connection storm it could cause it to repeat at a fixed period. */
MaxAgeConfig GetMaxAgeConfig(const grpc_channel_args* args) {
const int args_max_age = grpc_channel_arg_get_integer(
grpc_channel_args_find(args, GRPC_ARG_MAX_CONNECTION_AGE_MS),
MAX_CONNECTION_AGE_INTEGER_OPTIONS);
const int args_max_idle = grpc_channel_arg_get_integer(
grpc_channel_args_find(args, GRPC_ARG_MAX_CONNECTION_IDLE_MS),
MAX_CONNECTION_IDLE_INTEGER_OPTIONS);
const int args_max_age_grace = grpc_channel_arg_get_integer(
grpc_channel_args_find(args, GRPC_ARG_MAX_CONNECTION_AGE_GRACE_MS),
{DEFAULT_MAX_CONNECTION_AGE_GRACE_MS, 0, INT_MAX});
/* generate a random number between 1 - MAX_CONNECTION_AGE_JITTER and
1 + MAX_CONNECTION_AGE_JITTER */
const double multiplier =
rand() * MAX_CONNECTION_AGE_JITTER * 2.0 / RAND_MAX + 1.0 -
MAX_CONNECTION_AGE_JITTER;
MaxAgeConfig GetMaxAgeConfig(ChannelArgs args) {
const Duration args_max_age =
args.GetDurationFromIntMillis(GRPC_ARG_MAX_CONNECTION_AGE_MS)
.value_or(kDefaultMaxConnectionAge);
const Duration args_max_idle =
args.GetDurationFromIntMillis(GRPC_ARG_MAX_CONNECTION_IDLE_MS)
.value_or(kDefaultMaxConnectionIdle);
const Duration args_max_age_grace =
args.GetDurationFromIntMillis(GRPC_ARG_MAX_CONNECTION_AGE_GRACE_MS)
.value_or(kDefaultMaxConnectionAgeGrace);
/* generate a random number between 1 - kMaxConnectionAgeJitter and
1 + kMaxConnectionAgeJitter */
const double multiplier = rand() * kMaxConnectionAgeJitter * 2.0 / RAND_MAX +
1.0 - kMaxConnectionAgeJitter;
/* GRPC_MILLIS_INF_FUTURE - 0.5 converts the value to float, so that result
will not be cast to int implicitly before the comparison. */
return MaxAgeConfig{
args_max_age == INT_MAX
? Duration::Infinity()
: Duration::FromSecondsAsDouble(multiplier * args_max_age / 1000.0),
args_max_idle == INT_MAX ? Duration::Infinity()
: Duration::Milliseconds(args_max_idle),
Duration::Milliseconds(args_max_age_grace)};
return MaxAgeConfig{args_max_age * multiplier, args_max_idle,
args_max_age_grace};
}
class ChannelIdleFilter : public ChannelFilter {
@ -171,7 +156,7 @@ class ChannelIdleFilter : public ChannelFilter {
class ClientIdleFilter final : public ChannelIdleFilter {
public:
static absl::StatusOr<ClientIdleFilter> Create(
const grpc_channel_args* args, ChannelFilter::Args filter_args);
ChannelArgs args, ChannelFilter::Args filter_args);
private:
using ChannelIdleFilter::ChannelIdleFilter;
@ -179,7 +164,7 @@ class ClientIdleFilter final : public ChannelIdleFilter {
class MaxAgeFilter final : public ChannelIdleFilter {
public:
static absl::StatusOr<MaxAgeFilter> Create(const grpc_channel_args* args,
static absl::StatusOr<MaxAgeFilter> Create(ChannelArgs args,
ChannelFilter::Args filter_args);
void Start();
@ -215,14 +200,14 @@ class MaxAgeFilter final : public ChannelIdleFilter {
};
absl::StatusOr<ClientIdleFilter> ClientIdleFilter::Create(
const grpc_channel_args* args, ChannelFilter::Args filter_args) {
ChannelArgs args, ChannelFilter::Args filter_args) {
ClientIdleFilter filter(filter_args.channel_stack(),
GetClientIdleTimeout(args));
return absl::StatusOr<ClientIdleFilter>(std::move(filter));
}
absl::StatusOr<MaxAgeFilter> MaxAgeFilter::Create(
const grpc_channel_args* args, ChannelFilter::Args filter_args) {
ChannelArgs args, ChannelFilter::Args filter_args) {
const auto config = GetMaxAgeConfig(args);
MaxAgeFilter filter(filter_args.channel_stack(), config);
return absl::StatusOr<MaxAgeFilter>(std::move(filter));
@ -383,7 +368,8 @@ void RegisterChannelIdleFilters(CoreConfiguration::Builder* builder) {
[](ChannelStackBuilder* builder) {
const grpc_channel_args* channel_args = builder->channel_args();
if (!grpc_channel_args_want_minimal_stack(channel_args) &&
GetClientIdleTimeout(channel_args) != Duration::Infinity()) {
GetClientIdleTimeout(ChannelArgs::FromC(channel_args)) !=
Duration::Infinity()) {
builder->PrependFilter(&grpc_client_idle_filter, nullptr);
}
return true;
@ -393,7 +379,7 @@ void RegisterChannelIdleFilters(CoreConfiguration::Builder* builder) {
[](ChannelStackBuilder* builder) {
const grpc_channel_args* channel_args = builder->channel_args();
if (!grpc_channel_args_want_minimal_stack(channel_args) &&
GetMaxAgeConfig(channel_args).enable()) {
GetMaxAgeConfig(ChannelArgs::FromC(channel_args)).enable()) {
builder->PrependFilter(
&grpc_max_age_filter,
[](grpc_channel_stack*, grpc_channel_element* elem) {

@ -71,52 +71,26 @@ absl::Status CheckServerMetadata(ServerMetadata* b) {
return absl::OkStatus();
}
HttpSchemeMetadata::ValueType SchemeFromArgs(const grpc_channel_args* args) {
if (args != nullptr) {
for (size_t i = 0; i < args->num_args; ++i) {
if (args->args[i].type == GRPC_ARG_STRING &&
0 == strcmp(args->args[i].key, GRPC_ARG_HTTP2_SCHEME)) {
HttpSchemeMetadata::ValueType scheme = HttpSchemeMetadata::Parse(
args->args[i].value.string, [](absl::string_view, const Slice&) {});
if (scheme != HttpSchemeMetadata::kInvalid) return scheme;
}
}
}
return HttpSchemeMetadata::kHttp;
HttpSchemeMetadata::ValueType SchemeFromArgs(const ChannelArgs& args) {
HttpSchemeMetadata::ValueType scheme = HttpSchemeMetadata::Parse(
args.GetString(GRPC_ARG_HTTP2_SCHEME).value_or(""),
[](absl::string_view, const Slice&) {});
if (scheme == HttpSchemeMetadata::kInvalid) return HttpSchemeMetadata::kHttp;
return scheme;
}
Slice UserAgentFromArgs(const grpc_channel_args* args,
const char* transport_name) {
std::vector<std::string> user_agent_fields;
for (size_t i = 0; args && i < args->num_args; i++) {
if (0 == strcmp(args->args[i].key, GRPC_ARG_PRIMARY_USER_AGENT_STRING)) {
if (args->args[i].type != GRPC_ARG_STRING) {
gpr_log(GPR_ERROR, "Channel argument '%s' should be a string",
GRPC_ARG_PRIMARY_USER_AGENT_STRING);
} else {
user_agent_fields.push_back(args->args[i].value.string);
}
}
}
Slice UserAgentFromArgs(const ChannelArgs& args, const char* transport_name) {
std::vector<std::string> fields;
auto add = [&fields](absl::string_view x) {
if (!x.empty()) fields.push_back(std::string(x));
};
user_agent_fields.push_back(
absl::StrFormat("grpc-c/%s (%s; %s)", grpc_version_string(),
add(args.GetString(GRPC_ARG_PRIMARY_USER_AGENT_STRING).value_or(""));
add(absl::StrFormat("grpc-c/%s (%s; %s)", grpc_version_string(),
GPR_PLATFORM_STRING, transport_name));
add(args.GetString(GRPC_ARG_SECONDARY_USER_AGENT_STRING).value_or(""));
for (size_t i = 0; args && i < args->num_args; i++) {
if (0 == strcmp(args->args[i].key, GRPC_ARG_SECONDARY_USER_AGENT_STRING)) {
if (args->args[i].type != GRPC_ARG_STRING) {
gpr_log(GPR_ERROR, "Channel argument '%s' should be a string",
GRPC_ARG_SECONDARY_USER_AGENT_STRING);
} else {
user_agent_fields.push_back(args->args[i].value.string);
}
}
}
std::string user_agent_string = absl::StrJoin(user_agent_fields, " ");
return Slice::FromCopiedString(user_agent_string.c_str());
return Slice::FromCopiedString(absl::StrJoin(fields, " "));
}
} // namespace
@ -154,10 +128,9 @@ HttpClientFilter::HttpClientFilter(HttpSchemeMetadata::ValueType scheme,
Slice user_agent)
: scheme_(scheme), user_agent_(std::move(user_agent)) {}
absl::StatusOr<HttpClientFilter> HttpClientFilter::Create(
const grpc_channel_args* args, ChannelFilter::Args) {
auto* transport =
grpc_channel_args_find_pointer<grpc_transport>(args, GRPC_ARG_TRANSPORT);
absl::StatusOr<HttpClientFilter> HttpClientFilter::Create(ChannelArgs args,
ChannelFilter::Args) {
auto* transport = args.GetObject<grpc_transport>();
GPR_ASSERT(transport != nullptr);
return HttpClientFilter(SchemeFromArgs(args),
UserAgentFromArgs(args, transport->vtable->name));

@ -30,7 +30,7 @@ class HttpClientFilter : public ChannelFilter {
static const grpc_channel_filter kFilter;
static absl::StatusOr<HttpClientFilter> Create(
const grpc_channel_args* args, ChannelFilter::Args filter_args);
ChannelArgs args, ChannelFilter::Args filter_args);
// Construct a promise for one call.
ArenaPromise<ServerMetadataHandle> MakeCallPromise(

@ -40,21 +40,15 @@
namespace grpc_core {
absl::StatusOr<ClientAuthorityFilter> ClientAuthorityFilter::Create(
const grpc_channel_args* args, ChannelFilter::Args) {
const grpc_arg* default_authority_arg =
grpc_channel_args_find(args, GRPC_ARG_DEFAULT_AUTHORITY);
if (default_authority_arg == nullptr) {
ChannelArgs args, ChannelFilter::Args) {
absl::optional<absl::string_view> default_authority =
args.GetString(GRPC_ARG_DEFAULT_AUTHORITY);
if (!default_authority.has_value()) {
return absl::InvalidArgumentError(
"GRPC_ARG_DEFAULT_AUTHORITY channel arg. not found. Note that direct "
"channels must explicitly specify a value for this argument.");
"GRPC_ARG_DEFAULT_AUTHORITY string channel arg. not found. Note that "
"direct channels must explicitly specify a value for this argument.");
}
const char* default_authority_str =
grpc_channel_arg_get_string(default_authority_arg);
if (default_authority_str == nullptr) {
return absl::InvalidArgumentError(
"GRPC_ARG_DEFAULT_AUTHORITY channel arg. must be a string");
}
return ClientAuthorityFilter(Slice::FromCopiedString(default_authority_str));
return ClientAuthorityFilter(Slice::FromCopiedString(*default_authority));
}
ArenaPromise<ServerMetadataHandle> ClientAuthorityFilter::MakeCallPromise(

@ -33,8 +33,8 @@ namespace grpc_core {
class ClientAuthorityFilter final : public ChannelFilter {
public:
static absl::StatusOr<ClientAuthorityFilter> Create(
const grpc_channel_args* args, ChannelFilter::Args);
static absl::StatusOr<ClientAuthorityFilter> Create(ChannelArgs args,
ChannelFilter::Args);
// Construct a promise for one call.
ArenaPromise<ServerMetadataHandle> MakeCallPromise(

@ -56,10 +56,10 @@ constexpr char kEncodedIpv6AddressLengthString[] = "32";
constexpr char kEmptyAddressLengthString[] = "00";
absl::StatusOr<ServerLoadReportingFilter> ServerLoadReportingFilter::Create(
const grpc_channel_args* args, grpc_core::ChannelFilter::Args) {
grpc_core::ChannelArgs channel_args, grpc_core::ChannelFilter::Args) {
// Find and record the peer_identity.
ServerLoadReportingFilter filter;
const grpc_auth_context* auth_context = grpc_find_auth_context_in_args(args);
const auto* auth_context = channel_args.GetObject<grpc_auth_context>();
if (auth_context != nullptr &&
grpc_auth_context_peer_is_authenticated(auth_context)) {
grpc_auth_property_iterator auth_it =

@ -31,7 +31,7 @@ namespace grpc {
class ServerLoadReportingFilter : public grpc_core::ChannelFilter {
public:
static absl::StatusOr<ServerLoadReportingFilter> Create(
const grpc_channel_args* args, grpc_core::ChannelFilter::Args);
grpc_core::ChannelArgs args, grpc_core::ChannelFilter::Args);
// Getters.
const char* peer_identity() { return peer_identity_.c_str(); }
@ -49,5 +49,4 @@ class ServerLoadReportingFilter : public grpc_core::ChannelFilter {
} // namespace grpc
#endif /* GRPC_CORE_EXT_FILTERS_LOAD_REPORTING_SERVER_LOAD_REPORTING_FILTER_H \
*/
#endif // GRPC_CORE_EXT_FILTERS_LOAD_REPORTING_SERVER_LOAD_REPORTING_FILTER_H

@ -112,6 +112,19 @@ ChannelArgs ChannelArgs::Set(absl::string_view key, Value value) const {
return ChannelArgs(args_.Add(std::string(key), std::move(value)));
}
ChannelArgs ChannelArgs::Set(absl::string_view key,
absl::string_view value) const {
return Set(key, std::string(value));
}
ChannelArgs ChannelArgs::Set(absl::string_view key, const char* value) const {
return Set(key, std::string(value));
}
ChannelArgs ChannelArgs::Set(absl::string_view key, std::string value) const {
return Set(key, Value(std::move(value)));
}
ChannelArgs ChannelArgs::Remove(absl::string_view key) const {
return ChannelArgs(args_.Remove(key));
}
@ -123,6 +136,15 @@ absl::optional<int> ChannelArgs::GetInt(absl::string_view name) const {
return absl::get<int>(*v);
}
absl::optional<Duration> ChannelArgs::GetDurationFromIntMillis(
absl::string_view name) const {
auto ms = GetInt(name);
if (!ms.has_value()) return absl::nullopt;
if (*ms == INT_MAX) return Duration::Infinity();
if (*ms == INT_MIN) return Duration::NegativeInfinity();
return Duration::Milliseconds(*ms);
}
absl::optional<absl::string_view> ChannelArgs::GetString(
absl::string_view name) const {
auto* v = Get(name);

@ -33,6 +33,7 @@
#include "src/core/lib/gpr/useful.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/surface/channel_stack_type.h"
// Channel args are intentionally immutable, to avoid the need for locking.
@ -131,6 +132,12 @@ class ChannelArgs {
const Value* Get(absl::string_view name) const { return args_.Lookup(name); }
GRPC_MUST_USE_RESULT ChannelArgs Set(absl::string_view name,
Value value) const;
GRPC_MUST_USE_RESULT ChannelArgs Set(absl::string_view name,
absl::string_view value) const;
GRPC_MUST_USE_RESULT ChannelArgs Set(absl::string_view name,
std::string value) const;
GRPC_MUST_USE_RESULT ChannelArgs Set(absl::string_view name,
const char* value) const;
GRPC_MUST_USE_RESULT ChannelArgs Set(grpc_arg arg) const;
template <typename T>
GRPC_MUST_USE_RESULT absl::enable_if_t<
@ -160,6 +167,8 @@ class ChannelArgs {
T* GetPointer(absl::string_view name) const {
return static_cast<T*>(GetVoidPointer(name));
}
absl::optional<Duration> GetDurationFromIntMillis(
absl::string_view name) const;
// Object based get/set.
// Deal with the common case that we set a pointer to an object under

@ -69,8 +69,6 @@ typedef struct grpc_call_element grpc_call_element;
typedef struct grpc_channel_stack grpc_channel_stack;
typedef struct grpc_call_stack grpc_call_stack;
#define GRPC_ARG_TRANSPORT "grpc.internal.transport"
struct grpc_channel_element_args {
grpc_channel_stack* channel_stack;
const grpc_channel_args* channel_args;

@ -27,6 +27,7 @@
#include <grpc/support/log.h>
#include "src/core/lib/channel/call_finalization.h"
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/channel/context.h"
#include "src/core/lib/gprpp/debug_location.h"
@ -360,7 +361,7 @@ class CallData<ChannelFilter, FilterEndpoint::kServer> : public ServerCallData {
// class SomeChannelFilter : public ChannelFilter {
// public:
// static absl::StatusOr<SomeChannelFilter> Create(
// ChannelFilter::Args filter_args);
// ChannelArgs channel_args, ChannelFilter::Args filter_args);
// };
// TODO(ctiller): allow implementing get_channel_info, start_transport_op in
// some way on ChannelFilter.
@ -410,7 +411,7 @@ MakePromiseBasedFilter(const char* name) {
// init_channel_elem
[](grpc_channel_element* elem, grpc_channel_element_args* args) {
GPR_ASSERT(!args->is_last);
auto status = F::Create(args->channel_args,
auto status = F::Create(ChannelArgs::FromC(args->channel_args),
ChannelFilter::Args(args->channel_stack));
if (!status.ok()) return absl_status_to_grpc_error(status.status());
new (elem->channel_data) F(std::move(*status));

@ -23,6 +23,9 @@
struct grpc_authorization_policy_provider
: public grpc_core::DualRefCounted<grpc_authorization_policy_provider> {
public:
static absl::string_view ChannelArgName() {
return GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER;
}
struct AuthorizationEngines {
grpc_core::RefCountedPtr<grpc_core::AuthorizationEngine> allow_engine;
grpc_core::RefCountedPtr<grpc_core::AuthorizationEngine> deny_engine;

@ -32,11 +32,9 @@ GrpcServerAuthzFilter::GrpcServerAuthzFilter(
provider_(std::move(provider)) {}
absl::StatusOr<GrpcServerAuthzFilter> GrpcServerAuthzFilter::Create(
const grpc_channel_args* args, ChannelFilter::Args) {
grpc_auth_context* auth_context = grpc_find_auth_context_in_args(args);
grpc_authorization_policy_provider* provider =
grpc_channel_args_find_pointer<grpc_authorization_policy_provider>(
args, GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER);
ChannelArgs args, ChannelFilter::Args) {
auto* auth_context = args.GetObject<grpc_auth_context>();
auto* provider = args.GetObject<grpc_authorization_policy_provider>();
if (provider == nullptr) {
return absl::InvalidArgumentError("Failed to get authorization provider.");
}

@ -27,8 +27,8 @@ class GrpcServerAuthzFilter final : public ChannelFilter {
public:
static const grpc_channel_filter kFilterVtable;
static absl::StatusOr<GrpcServerAuthzFilter> Create(
const grpc_channel_args* args, ChannelFilter::Args);
static absl::StatusOr<GrpcServerAuthzFilter> Create(ChannelArgs args,
ChannelFilter::Args);
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;

@ -43,6 +43,8 @@ struct grpc_auth_property_array {
void grpc_auth_property_reset(grpc_auth_property* property);
#define GRPC_AUTH_CONTEXT_ARG "grpc.auth_context"
// This type is forward declared as a C struct and we cannot define it as a
// class. Otherwise, compiler will complain about type mismatch due to
// -Wmismatched-tags.
@ -73,6 +75,8 @@ struct grpc_auth_context
}
}
static absl::string_view ChannelArgName() { return GRPC_AUTH_CONTEXT_ARG; }
const grpc_auth_context* chained() const { return chained_.get(); }
const grpc_auth_property_array& properties() const { return properties_; }
@ -142,7 +146,6 @@ grpc_server_security_context* grpc_server_security_context_create(
void grpc_server_security_context_destroy(void* ctx);
/* --- Channel args for auth context --- */
#define GRPC_AUTH_CONTEXT_ARG "grpc.auth_context"
grpc_arg grpc_auth_context_to_arg(grpc_auth_context* c);
grpc_auth_context* grpc_auth_context_from_arg(const grpc_arg* arg);

@ -60,6 +60,10 @@ class grpc_security_connector
url_scheme_(url_scheme) {}
~grpc_security_connector() override = default;
static absl::string_view ChannelArgName() {
return GRPC_ARG_SECURITY_CONNECTOR;
}
// Checks the peer. Callee takes ownership of the peer object.
// When done, sets *auth_context and invokes on_peer_checked.
virtual void check_peer(

@ -37,7 +37,7 @@ namespace grpc_core {
// Handles calling out to credentials to fill in metadata per call.
class ClientAuthFilter final : public ChannelFilter {
public:
static absl::StatusOr<ClientAuthFilter> Create(const grpc_channel_args* args,
static absl::StatusOr<ClientAuthFilter> Create(ChannelArgs args,
ChannelFilter::Args);
// Construct a promise for one call.

@ -185,14 +185,14 @@ ArenaPromise<ServerMetadataHandle> ClientAuthFilter::MakeCallPromise(
next_promise_factory);
}
absl::StatusOr<ClientAuthFilter> ClientAuthFilter::Create(
const grpc_channel_args* args, ChannelFilter::Args) {
grpc_security_connector* sc = grpc_security_connector_find_in_args(args);
absl::StatusOr<ClientAuthFilter> ClientAuthFilter::Create(ChannelArgs args,
ChannelFilter::Args) {
auto* sc = args.GetObject<grpc_security_connector>();
if (sc == nullptr) {
return absl::InvalidArgumentError(
"Security connector missing from client auth filter args");
}
grpc_auth_context* auth_context = grpc_find_auth_context_in_args(args);
auto* auth_context = args.GetObject<grpc_auth_context>();
if (auth_context == nullptr) {
return absl::InvalidArgumentError(
"Auth context missing from client auth filter args");

@ -44,6 +44,8 @@
#define GRPC_PROTOCOL_VERSION_MIN_MAJOR 2
#define GRPC_PROTOCOL_VERSION_MIN_MINOR 1
#define GRPC_ARG_TRANSPORT "grpc.internal.transport"
namespace grpc_core {
// TODO(ctiller): eliminate once MetadataHandle is constructable directly.
namespace promise_filter_detail {

@ -77,6 +77,7 @@ typedef struct grpc_transport_vtable {
/* an instance of a grpc transport */
struct grpc_transport {
static absl::string_view ChannelArgName() { return GRPC_ARG_TRANSPORT; }
/* pointer to a vtable defining operations on this transport */
const grpc_transport_vtable* vtable;
};

@ -25,45 +25,32 @@ namespace {
auto* g_memory_allocator = new MemoryAllocator(
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator("test"));
class TestChannelArgs {
public:
explicit TestChannelArgs(const char* default_authority)
: arg_(grpc_channel_arg_string_create(
const_cast<char*>(GRPC_ARG_DEFAULT_AUTHORITY),
const_cast<char*>(default_authority))),
args_{1, &arg_} {}
const grpc_channel_args* args() const { return &args_; }
private:
grpc_arg arg_;
grpc_channel_args args_;
};
ChannelArgs TestChannelArgs(absl::string_view default_authority) {
return ChannelArgs().Set(GRPC_ARG_DEFAULT_AUTHORITY, default_authority);
}
TEST(ClientAuthorityFilterTest, DefaultFails) {
EXPECT_FALSE(
ClientAuthorityFilter::Create(nullptr, ChannelFilter::Args()).ok());
ClientAuthorityFilter::Create(ChannelArgs(), ChannelFilter::Args()).ok());
}
TEST(ClientAuthorityFilterTest, WithArgSucceeds) {
EXPECT_EQ(
ClientAuthorityFilter::Create(
TestChannelArgs("foo.test.google.au").args(), ChannelFilter::Args())
.status(),
absl::OkStatus());
EXPECT_EQ(ClientAuthorityFilter::Create(TestChannelArgs("foo.test.google.au"),
ChannelFilter::Args())
.status(),
absl::OkStatus());
}
TEST(ClientAuthorityFilterTest, NonStringArgFails) {
grpc_arg arg = grpc_channel_arg_integer_create(
const_cast<char*>(GRPC_ARG_DEFAULT_AUTHORITY), 123);
grpc_channel_args args = {1, &arg};
EXPECT_FALSE(
ClientAuthorityFilter::Create(&args, ChannelFilter::Args()).ok());
EXPECT_FALSE(ClientAuthorityFilter::Create(
ChannelArgs().Set(GRPC_ARG_DEFAULT_AUTHORITY, 123),
ChannelFilter::Args())
.ok());
}
TEST(ClientAuthorityFilterTest, PromiseCompletesImmediatelyAndSetsAuthority) {
auto filter = *ClientAuthorityFilter::Create(
TestChannelArgs("foo.test.google.au").args(), ChannelFilter::Args());
TestChannelArgs("foo.test.google.au"), ChannelFilter::Args());
auto arena = MakeScopedArena(1024, g_memory_allocator);
grpc_metadata_batch initial_metadata_batch(arena.get());
grpc_metadata_batch trailing_metadata_batch(arena.get());
@ -95,7 +82,7 @@ TEST(ClientAuthorityFilterTest, PromiseCompletesImmediatelyAndSetsAuthority) {
TEST(ClientAuthorityFilterTest,
PromiseCompletesImmediatelyAndDoesNotClobberAlreadySetsAuthority) {
auto filter = *ClientAuthorityFilter::Create(
TestChannelArgs("foo.test.google.au").args(), ChannelFilter::Args());
TestChannelArgs("foo.test.google.au"), ChannelFilter::Args());
auto arena = MakeScopedArena(1024, g_memory_allocator);
grpc_metadata_batch initial_metadata_batch(arena.get());
grpc_metadata_batch trailing_metadata_batch(arena.get());

Loading…
Cancel
Save