Explicit method for comparing channel credentials (#28844)

* Use an explicit cmp method on grpc_channel_credentials

* Add testing

* Add TODOs to improve cmp methods

* Automated change: Fix sanity tests

* clang format

* Reviewer comments

* clang-format

* Add cmp method for grpc_call_credentials

* s/overriden/overridden

Co-authored-by: yashykt <yashykt@users.noreply.github.com>
pull/28894/head
Yash Tibrewal 3 years ago committed by GitHub
parent 35708ff6b4
commit 5f67cd07f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      src/core/lib/http/httpcli_security_connector.cc
  2. 6
      src/core/lib/security/credentials/alts/alts_credentials.h
  3. 12
      src/core/lib/security/credentials/composite/composite_credentials.h
  4. 3
      src/core/lib/security/credentials/credentials.cc
  5. 33
      src/core/lib/security/credentials/credentials.h
  6. 7
      src/core/lib/security/credentials/fake/fake_credentials.cc
  7. 6
      src/core/lib/security/credentials/fake/fake_credentials.h
  8. 6
      src/core/lib/security/credentials/google_default/google_default_credentials.h
  9. 6
      src/core/lib/security/credentials/iam/iam_credentials.h
  10. 7
      src/core/lib/security/credentials/insecure/insecure_credentials.cc
  11. 6
      src/core/lib/security/credentials/jwt/jwt_credentials.h
  12. 6
      src/core/lib/security/credentials/local/local_credentials.h
  13. 12
      src/core/lib/security/credentials/oauth2/oauth2_credentials.h
  14. 6
      src/core/lib/security/credentials/plugin/plugin_credentials.h
  15. 6
      src/core/lib/security/credentials/ssl/ssl_credentials.h
  16. 6
      src/core/lib/security/credentials/tls/tls_credentials.h
  17. 6
      src/core/lib/security/credentials/xds/xds_credentials.h
  18. 2
      src/core/lib/security/security_connector/security_connector.cc
  19. 41
      test/core/security/credentials_test.cc

@ -190,6 +190,13 @@ class HttpRequestSSLCredentials : public grpc_channel_credentials {
grpc_channel_args* update_arguments(grpc_channel_args* args) override {
return args;
}
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return QsortCompare(static_cast<const grpc_channel_credentials*>(this),
other);
}
};
} // namespace

@ -44,6 +44,12 @@ class grpc_alts_credentials final : public grpc_channel_credentials {
const char* handshaker_service_url() const { return handshaker_service_url_; }
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_channel_credentials*>(this), other);
}
grpc_alts_credentials_options* options_;
char* handshaker_service_url_;
};

@ -63,6 +63,12 @@ class grpc_composite_channel_credentials : public grpc_channel_credentials {
grpc_call_credentials* mutable_call_creds() { return call_creds_.get(); }
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_channel_credentials*>(this), other);
}
grpc_core::RefCountedPtr<grpc_channel_credentials> inner_creds_;
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds_;
};
@ -97,6 +103,12 @@ class grpc_composite_call_credentials : public grpc_call_credentials {
std::string debug_string() override;
private:
int cmp_impl(const grpc_call_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_call_credentials*>(this), other);
}
void push_to_inner(grpc_core::RefCountedPtr<grpc_call_credentials> creds,
bool is_composite);
grpc_security_level min_security_level_;

@ -58,7 +58,8 @@ static void* credentials_pointer_arg_copy(void* p) {
}
static int credentials_pointer_cmp(void* a, void* b) {
return grpc_core::QsortCompare(a, b);
return static_cast<const grpc_channel_credentials*>(a)->cmp(
static_cast<const grpc_channel_credentials*>(b));
}
static const grpc_arg_pointer_vtable credentials_pointer_vtable = {

@ -130,9 +130,28 @@ struct grpc_channel_credentials
return args;
}
// Compares this grpc_channel_credentials object with \a other.
// If this method returns 0, it means that gRPC can treat the two channel
// credentials as effectively the same. This method is used to compare
// `grpc_channel_credentials` objects when they are present in channel_args.
// One important usage of this is when channel args are used in SubchannelKey,
// which leads to a useful property that allows subchannels to be reused when
// two different `grpc_channel_credentials` objects are used but they compare
// as equal (assuming other channel args match).
int cmp(const grpc_channel_credentials* other) const {
GPR_ASSERT(other != nullptr);
int r = strcmp(type(), other->type());
if (r != 0) return r;
return cmp_impl(other);
}
const char* type() const { return type_; }
private:
// Implementation for `cmp` method intended to be overridden by subclasses.
// Only invoked if `type()` and `other->type()` compare equal as strings.
virtual int cmp_impl(const grpc_channel_credentials* other) const = 0;
const char* type_;
};
@ -193,6 +212,16 @@ struct grpc_call_credentials
return min_security_level_;
}
// Compares this grpc_call_credentials object with \a other.
// If this method returns 0, it means that gRPC can treat the two call
// credentials as effectively the same..
int cmp(const grpc_call_credentials* other) const {
GPR_ASSERT(other != nullptr);
int r = strcmp(type(), other->type());
if (r != 0) return r;
return cmp_impl(other);
}
virtual std::string debug_string() {
return "grpc_call_credentials did not provide debug string";
}
@ -200,6 +229,10 @@ struct grpc_call_credentials
const char* type() const { return type_; }
private:
// Implementation for `cmp` method intended to be overridden by subclasses.
// Only invoked if `type()` and `other->type()` compare equal as strings.
virtual int cmp_impl(const grpc_call_credentials* other) const = 0;
const char* type_;
const grpc_security_level min_security_level_;
};

@ -49,6 +49,13 @@ class grpc_fake_channel_credentials final : public grpc_channel_credentials {
return grpc_fake_channel_security_connector_create(
this->Ref(), std::move(call_creds), target, args);
}
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_channel_credentials*>(this), other);
}
};
class grpc_fake_server_credentials final : public grpc_server_credentials {

@ -80,6 +80,12 @@ class grpc_md_only_test_credentials : public grpc_call_credentials {
std::string debug_string() override { return "MD only Test Credentials"; };
private:
int cmp_impl(const grpc_call_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_call_credentials*>(this), other);
}
grpc_core::Slice key_;
grpc_core::Slice value_;
bool is_async_;

@ -66,6 +66,12 @@ class grpc_google_default_channel_credentials
const grpc_channel_credentials* ssl_creds() const { return ssl_creds_.get(); }
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_channel_credentials*>(this), other);
}
grpc_core::RefCountedPtr<grpc_channel_credentials> alts_creds_;
grpc_core::RefCountedPtr<grpc_channel_credentials> ssl_creds_;
};

@ -42,6 +42,12 @@ class grpc_google_iam_credentials : public grpc_call_credentials {
std::string debug_string() override { return debug_string_; }
private:
int cmp_impl(const grpc_call_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_call_credentials*>(this), other);
}
const absl::optional<grpc_core::Slice> token_;
const grpc_core::Slice authority_selector_;
const std::string debug_string_;

@ -36,6 +36,13 @@ class InsecureCredentials final : public grpc_channel_credentials {
return MakeRefCounted<InsecureChannelSecurityConnector>(
Ref(), std::move(call_creds));
}
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return QsortCompare(static_cast<const grpc_channel_credentials*>(this),
other);
}
};
class InsecureServerCredentials final : public grpc_server_credentials {

@ -59,6 +59,12 @@ class grpc_service_account_jwt_access_credentials
};
private:
int cmp_impl(const grpc_call_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_call_credentials*>(this), other);
}
// Have a simple cache for now with just 1 entry. We could have a map based on
// the service_url for a more sophisticated one.
gpr_mu cache_mu_;

@ -40,6 +40,12 @@ class grpc_local_credentials final : public grpc_channel_credentials {
grpc_local_connect_type connect_type() const { return connect_type_; }
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_channel_credentials*>(this), other);
}
grpc_local_connect_type connect_type_;
};

@ -110,6 +110,12 @@ class grpc_oauth2_token_fetcher_credentials : public grpc_call_credentials {
grpc_millis deadline) = 0;
private:
int cmp_impl(const grpc_call_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_call_credentials*>(this), other);
}
gpr_mu mu_;
absl::optional<grpc_core::Slice> access_token_value_;
gpr_timespec token_expiration_;
@ -161,6 +167,12 @@ class grpc_access_token_credentials final : public grpc_call_credentials {
std::string debug_string() override;
private:
int cmp_impl(const grpc_call_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_call_credentials*>(this), other);
}
const grpc_core::Slice access_token_value_;
};

@ -63,6 +63,12 @@ struct grpc_plugin_credentials final : public grpc_call_credentials {
std::string debug_string() override;
private:
int cmp_impl(const grpc_call_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_call_credentials*>(this), other);
}
void pending_request_remove_locked(pending_request* pending_request);
grpc_metadata_credentials_plugin plugin_;

@ -43,6 +43,12 @@ class grpc_ssl_credentials : public grpc_channel_credentials {
void set_max_tls_version(grpc_tls_version max_tls_version);
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_channel_credentials*>(this), other);
}
void build_config(const char* pem_root_certs,
grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const grpc_ssl_verify_peer_options* verify_options);

@ -41,6 +41,12 @@ class TlsCredentials final : public grpc_channel_credentials {
grpc_tls_credentials_options* options() const { return options_.get(); }
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_channel_credentials*>(this), other);
}
grpc_core::RefCountedPtr<grpc_tls_credentials_options> options_;
};

@ -42,6 +42,12 @@ class XdsCredentials final : public grpc_channel_credentials {
const grpc_channel_args* args, grpc_channel_args** new_args) override;
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return QsortCompare(static_cast<const grpc_channel_credentials*>(this),
other);
}
RefCountedPtr<grpc_channel_credentials> fallback_credentials_;
};

@ -60,7 +60,7 @@ int grpc_channel_security_connector::channel_security_connector_cmp(
static_cast<const grpc_channel_security_connector*>(other);
GPR_ASSERT(channel_creds() != nullptr);
GPR_ASSERT(other_sc->channel_creds() != nullptr);
int c = grpc_core::QsortCompare(channel_creds(), other_sc->channel_creds());
int c = channel_creds()->cmp(other_sc->channel_creds());
if (c != 0) return c;
return grpc_core::QsortCompare(request_metadata_creds(),
other_sc->request_metadata_creds());

@ -555,6 +555,13 @@ class check_channel_oauth2 final : public grpc_channel_credentials {
0);
return nullptr;
}
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_channel_credentials*>(this), other);
}
};
} // namespace
@ -642,6 +649,13 @@ class check_channel_oauth2_google_iam final : public grpc_channel_credentials {
0);
return nullptr;
}
private:
int cmp_impl(const grpc_channel_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_channel_credentials*>(this), other);
}
};
} // namespace
@ -1765,6 +1779,13 @@ struct fake_call_creds : public grpc_call_credentials {
void cancel_get_request_metadata(
grpc_core::CredentialsMetadataArray* /*md_array*/,
grpc_error_handle /*error*/) override {}
private:
int cmp_impl(const grpc_call_credentials* other) const override {
// TODO(yashykt): Check if we can do something better here
return grpc_core::QsortCompare(
static_cast<const grpc_call_credentials*>(this), other);
}
};
static void test_google_default_creds_not_default(void) {
@ -3544,7 +3565,7 @@ test_external_account_credentials_create_failure_invalid_workforce_pool_audience
static void test_insecure_credentials_compare_success(void) {
auto* insecure_creds_1 = grpc_insecure_credentials_create();
auto* insecure_creds_2 = grpc_insecure_credentials_create();
GPR_ASSERT(grpc_core::QsortCompare(insecure_creds_1, insecure_creds_2) == 0);
GPR_ASSERT(insecure_creds_1->cmp(insecure_creds_2) == 0);
grpc_arg arg_1 = grpc_channel_credentials_to_arg(insecure_creds_1);
grpc_channel_args args_1 = {1, &arg_1};
grpc_arg arg_2 = grpc_channel_credentials_to_arg(insecure_creds_2);
@ -3557,7 +3578,8 @@ static void test_insecure_credentials_compare_success(void) {
static void test_insecure_credentials_compare_failure(void) {
auto* insecure_creds = grpc_insecure_credentials_create();
auto* fake_creds = grpc_fake_transport_security_credentials_create();
GPR_ASSERT(grpc_core::QsortCompare(insecure_creds, fake_creds) != 0);
GPR_ASSERT(insecure_creds->cmp(fake_creds) != 0);
GPR_ASSERT(fake_creds->cmp(insecure_creds) != 0);
grpc_arg arg_1 = grpc_channel_credentials_to_arg(insecure_creds);
grpc_channel_args args_1 = {1, &arg_1};
grpc_arg arg_2 = grpc_channel_credentials_to_arg(fake_creds);
@ -3567,6 +3589,19 @@ static void test_insecure_credentials_compare_failure(void) {
grpc_channel_credentials_release(fake_creds);
}
static void test_fake_call_credentials_compare_success(void) {
auto call_creds = grpc_core::MakeRefCounted<fake_call_creds>();
GPR_ASSERT(call_creds->cmp(call_creds.get()) == 0);
}
static void test_fake_call_credentials_compare_failure(void) {
auto fake_creds = grpc_core::MakeRefCounted<fake_call_creds>();
auto* md_creds = grpc_md_only_test_credentials_create("key", "value", false);
GPR_ASSERT(fake_creds->cmp(md_creds) != 0);
GPR_ASSERT(md_creds->cmp(fake_creds.get()) != 0);
grpc_call_credentials_release(md_creds);
}
int main(int argc, char** argv) {
grpc::testing::TestEnvironment env(argc, argv);
grpc_init();
@ -3647,6 +3682,8 @@ int main(int argc, char** argv) {
test_external_account_credentials_create_failure_invalid_workforce_pool_audience();
test_insecure_credentials_compare_success();
test_insecure_credentials_compare_failure();
test_fake_call_credentials_compare_success();
test_fake_call_credentials_compare_failure();
grpc_shutdown();
return 0;
}

Loading…
Cancel
Save