From 5f67cd07f42338a8ffa3b8f8a5c6532c324a0867 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Tue, 15 Feb 2022 15:03:25 -0800 Subject: [PATCH] Explicit method for comparing channel credentials (#28844) * Use an explicit cmp method on grpc_channel_credentials * Add testing * Add TODOs to improve cmp methods * Automated change: Fix sanity tests * clang format * Reviewer comments * clang-format * Add cmp method for grpc_call_credentials * s/overriden/overridden Co-authored-by: yashykt --- .../lib/http/httpcli_security_connector.cc | 7 ++++ .../credentials/alts/alts_credentials.h | 6 +++ .../composite/composite_credentials.h | 12 ++++++ .../lib/security/credentials/credentials.cc | 3 +- .../lib/security/credentials/credentials.h | 33 +++++++++++++++ .../credentials/fake/fake_credentials.cc | 7 ++++ .../credentials/fake/fake_credentials.h | 6 +++ .../google_default_credentials.h | 6 +++ .../credentials/iam/iam_credentials.h | 6 +++ .../insecure/insecure_credentials.cc | 7 ++++ .../credentials/jwt/jwt_credentials.h | 6 +++ .../credentials/local/local_credentials.h | 6 +++ .../credentials/oauth2/oauth2_credentials.h | 12 ++++++ .../credentials/plugin/plugin_credentials.h | 6 +++ .../credentials/ssl/ssl_credentials.h | 6 +++ .../credentials/tls/tls_credentials.h | 6 +++ .../credentials/xds/xds_credentials.h | 6 +++ .../security_connector/security_connector.cc | 2 +- test/core/security/credentials_test.cc | 41 ++++++++++++++++++- 19 files changed, 180 insertions(+), 4 deletions(-) diff --git a/src/core/lib/http/httpcli_security_connector.cc b/src/core/lib/http/httpcli_security_connector.cc index 0a63b57bc44..c82a09e375a 100644 --- a/src/core/lib/http/httpcli_security_connector.cc +++ b/src/core/lib/http/httpcli_security_connector.cc @@ -190,6 +190,13 @@ class HttpRequestSSLCredentials : public grpc_channel_credentials { grpc_channel_args* update_arguments(grpc_channel_args* args) override { return args; } + + private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return QsortCompare(static_cast(this), + other); + } }; } // namespace diff --git a/src/core/lib/security/credentials/alts/alts_credentials.h b/src/core/lib/security/credentials/alts/alts_credentials.h index 8e1362c0b61..5012d9685bf 100644 --- a/src/core/lib/security/credentials/alts/alts_credentials.h +++ b/src/core/lib/security/credentials/alts/alts_credentials.h @@ -44,6 +44,12 @@ class grpc_alts_credentials final : public grpc_channel_credentials { const char* handshaker_service_url() const { return handshaker_service_url_; } private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + grpc_alts_credentials_options* options_; char* handshaker_service_url_; }; diff --git a/src/core/lib/security/credentials/composite/composite_credentials.h b/src/core/lib/security/credentials/composite/composite_credentials.h index 91e34c1346a..fe2da4bd0ac 100644 --- a/src/core/lib/security/credentials/composite/composite_credentials.h +++ b/src/core/lib/security/credentials/composite/composite_credentials.h @@ -63,6 +63,12 @@ class grpc_composite_channel_credentials : public grpc_channel_credentials { grpc_call_credentials* mutable_call_creds() { return call_creds_.get(); } private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + grpc_core::RefCountedPtr inner_creds_; grpc_core::RefCountedPtr call_creds_; }; @@ -97,6 +103,12 @@ class grpc_composite_call_credentials : public grpc_call_credentials { std::string debug_string() override; private: + int cmp_impl(const grpc_call_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + void push_to_inner(grpc_core::RefCountedPtr creds, bool is_composite); grpc_security_level min_security_level_; diff --git a/src/core/lib/security/credentials/credentials.cc b/src/core/lib/security/credentials/credentials.cc index cd8e066d5cd..0c225794de5 100644 --- a/src/core/lib/security/credentials/credentials.cc +++ b/src/core/lib/security/credentials/credentials.cc @@ -58,7 +58,8 @@ static void* credentials_pointer_arg_copy(void* p) { } static int credentials_pointer_cmp(void* a, void* b) { - return grpc_core::QsortCompare(a, b); + return static_cast(a)->cmp( + static_cast(b)); } static const grpc_arg_pointer_vtable credentials_pointer_vtable = { diff --git a/src/core/lib/security/credentials/credentials.h b/src/core/lib/security/credentials/credentials.h index 6f1a9d4ade7..59ae767586d 100644 --- a/src/core/lib/security/credentials/credentials.h +++ b/src/core/lib/security/credentials/credentials.h @@ -130,9 +130,28 @@ struct grpc_channel_credentials return args; } + // Compares this grpc_channel_credentials object with \a other. + // If this method returns 0, it means that gRPC can treat the two channel + // credentials as effectively the same. This method is used to compare + // `grpc_channel_credentials` objects when they are present in channel_args. + // One important usage of this is when channel args are used in SubchannelKey, + // which leads to a useful property that allows subchannels to be reused when + // two different `grpc_channel_credentials` objects are used but they compare + // as equal (assuming other channel args match). + int cmp(const grpc_channel_credentials* other) const { + GPR_ASSERT(other != nullptr); + int r = strcmp(type(), other->type()); + if (r != 0) return r; + return cmp_impl(other); + } + const char* type() const { return type_; } private: + // Implementation for `cmp` method intended to be overridden by subclasses. + // Only invoked if `type()` and `other->type()` compare equal as strings. + virtual int cmp_impl(const grpc_channel_credentials* other) const = 0; + const char* type_; }; @@ -193,6 +212,16 @@ struct grpc_call_credentials return min_security_level_; } + // Compares this grpc_call_credentials object with \a other. + // If this method returns 0, it means that gRPC can treat the two call + // credentials as effectively the same.. + int cmp(const grpc_call_credentials* other) const { + GPR_ASSERT(other != nullptr); + int r = strcmp(type(), other->type()); + if (r != 0) return r; + return cmp_impl(other); + } + virtual std::string debug_string() { return "grpc_call_credentials did not provide debug string"; } @@ -200,6 +229,10 @@ struct grpc_call_credentials const char* type() const { return type_; } private: + // Implementation for `cmp` method intended to be overridden by subclasses. + // Only invoked if `type()` and `other->type()` compare equal as strings. + virtual int cmp_impl(const grpc_call_credentials* other) const = 0; + const char* type_; const grpc_security_level min_security_level_; }; diff --git a/src/core/lib/security/credentials/fake/fake_credentials.cc b/src/core/lib/security/credentials/fake/fake_credentials.cc index c3c61024553..5a9fbcbb1d8 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.cc +++ b/src/core/lib/security/credentials/fake/fake_credentials.cc @@ -49,6 +49,13 @@ class grpc_fake_channel_credentials final : public grpc_channel_credentials { return grpc_fake_channel_security_connector_create( this->Ref(), std::move(call_creds), target, args); } + + private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } }; class grpc_fake_server_credentials final : public grpc_server_credentials { diff --git a/src/core/lib/security/credentials/fake/fake_credentials.h b/src/core/lib/security/credentials/fake/fake_credentials.h index 1e3ce40a30b..95a96fff1fe 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.h +++ b/src/core/lib/security/credentials/fake/fake_credentials.h @@ -80,6 +80,12 @@ class grpc_md_only_test_credentials : public grpc_call_credentials { std::string debug_string() override { return "MD only Test Credentials"; }; private: + int cmp_impl(const grpc_call_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + grpc_core::Slice key_; grpc_core::Slice value_; bool is_async_; diff --git a/src/core/lib/security/credentials/google_default/google_default_credentials.h b/src/core/lib/security/credentials/google_default/google_default_credentials.h index 8a945da31e2..93171414c2c 100644 --- a/src/core/lib/security/credentials/google_default/google_default_credentials.h +++ b/src/core/lib/security/credentials/google_default/google_default_credentials.h @@ -66,6 +66,12 @@ class grpc_google_default_channel_credentials const grpc_channel_credentials* ssl_creds() const { return ssl_creds_.get(); } private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + grpc_core::RefCountedPtr alts_creds_; grpc_core::RefCountedPtr ssl_creds_; }; diff --git a/src/core/lib/security/credentials/iam/iam_credentials.h b/src/core/lib/security/credentials/iam/iam_credentials.h index f56e011b064..6f6d1fa73ed 100644 --- a/src/core/lib/security/credentials/iam/iam_credentials.h +++ b/src/core/lib/security/credentials/iam/iam_credentials.h @@ -42,6 +42,12 @@ class grpc_google_iam_credentials : public grpc_call_credentials { std::string debug_string() override { return debug_string_; } private: + int cmp_impl(const grpc_call_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + const absl::optional token_; const grpc_core::Slice authority_selector_; const std::string debug_string_; diff --git a/src/core/lib/security/credentials/insecure/insecure_credentials.cc b/src/core/lib/security/credentials/insecure/insecure_credentials.cc index 2cd5c079fe6..878a32caac0 100644 --- a/src/core/lib/security/credentials/insecure/insecure_credentials.cc +++ b/src/core/lib/security/credentials/insecure/insecure_credentials.cc @@ -36,6 +36,13 @@ class InsecureCredentials final : public grpc_channel_credentials { return MakeRefCounted( Ref(), std::move(call_creds)); } + + private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return QsortCompare(static_cast(this), + other); + } }; class InsecureServerCredentials final : public grpc_server_credentials { diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.h b/src/core/lib/security/credentials/jwt/jwt_credentials.h index 77e7161b1e3..1fc1d42f730 100644 --- a/src/core/lib/security/credentials/jwt/jwt_credentials.h +++ b/src/core/lib/security/credentials/jwt/jwt_credentials.h @@ -59,6 +59,12 @@ class grpc_service_account_jwt_access_credentials }; private: + int cmp_impl(const grpc_call_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + // Have a simple cache for now with just 1 entry. We could have a map based on // the service_url for a more sophisticated one. gpr_mu cache_mu_; diff --git a/src/core/lib/security/credentials/local/local_credentials.h b/src/core/lib/security/credentials/local/local_credentials.h index a1857ad8dba..31f32a6d47c 100644 --- a/src/core/lib/security/credentials/local/local_credentials.h +++ b/src/core/lib/security/credentials/local/local_credentials.h @@ -40,6 +40,12 @@ class grpc_local_credentials final : public grpc_channel_credentials { grpc_local_connect_type connect_type() const { return connect_type_; } private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + grpc_local_connect_type connect_type_; }; diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h index beb8168bdf5..1c75130b37a 100644 --- a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h +++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h @@ -110,6 +110,12 @@ class grpc_oauth2_token_fetcher_credentials : public grpc_call_credentials { grpc_millis deadline) = 0; private: + int cmp_impl(const grpc_call_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + gpr_mu mu_; absl::optional access_token_value_; gpr_timespec token_expiration_; @@ -161,6 +167,12 @@ class grpc_access_token_credentials final : public grpc_call_credentials { std::string debug_string() override; private: + int cmp_impl(const grpc_call_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + const grpc_core::Slice access_token_value_; }; diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.h b/src/core/lib/security/credentials/plugin/plugin_credentials.h index c81f792982a..869e817fc69 100644 --- a/src/core/lib/security/credentials/plugin/plugin_credentials.h +++ b/src/core/lib/security/credentials/plugin/plugin_credentials.h @@ -63,6 +63,12 @@ struct grpc_plugin_credentials final : public grpc_call_credentials { std::string debug_string() override; private: + int cmp_impl(const grpc_call_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + void pending_request_remove_locked(pending_request* pending_request); grpc_metadata_credentials_plugin plugin_; diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.h b/src/core/lib/security/credentials/ssl/ssl_credentials.h index 2bd3b7eaf8a..d5b1535f300 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.h +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.h @@ -43,6 +43,12 @@ class grpc_ssl_credentials : public grpc_channel_credentials { void set_max_tls_version(grpc_tls_version max_tls_version); private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + void build_config(const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, const grpc_ssl_verify_peer_options* verify_options); diff --git a/src/core/lib/security/credentials/tls/tls_credentials.h b/src/core/lib/security/credentials/tls/tls_credentials.h index a5e4f486bf9..6dc49227000 100644 --- a/src/core/lib/security/credentials/tls/tls_credentials.h +++ b/src/core/lib/security/credentials/tls/tls_credentials.h @@ -41,6 +41,12 @@ class TlsCredentials final : public grpc_channel_credentials { grpc_tls_credentials_options* options() const { return options_.get(); } private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } + grpc_core::RefCountedPtr options_; }; diff --git a/src/core/lib/security/credentials/xds/xds_credentials.h b/src/core/lib/security/credentials/xds/xds_credentials.h index a7525e69309..23f91fe591c 100644 --- a/src/core/lib/security/credentials/xds/xds_credentials.h +++ b/src/core/lib/security/credentials/xds/xds_credentials.h @@ -42,6 +42,12 @@ class XdsCredentials final : public grpc_channel_credentials { const grpc_channel_args* args, grpc_channel_args** new_args) override; private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return QsortCompare(static_cast(this), + other); + } + RefCountedPtr fallback_credentials_; }; diff --git a/src/core/lib/security/security_connector/security_connector.cc b/src/core/lib/security/security_connector/security_connector.cc index 24f169124ad..d7f16cca386 100644 --- a/src/core/lib/security/security_connector/security_connector.cc +++ b/src/core/lib/security/security_connector/security_connector.cc @@ -60,7 +60,7 @@ int grpc_channel_security_connector::channel_security_connector_cmp( static_cast(other); GPR_ASSERT(channel_creds() != nullptr); GPR_ASSERT(other_sc->channel_creds() != nullptr); - int c = grpc_core::QsortCompare(channel_creds(), other_sc->channel_creds()); + int c = channel_creds()->cmp(other_sc->channel_creds()); if (c != 0) return c; return grpc_core::QsortCompare(request_metadata_creds(), other_sc->request_metadata_creds()); diff --git a/test/core/security/credentials_test.cc b/test/core/security/credentials_test.cc index 49933b505fa..14b78815355 100644 --- a/test/core/security/credentials_test.cc +++ b/test/core/security/credentials_test.cc @@ -555,6 +555,13 @@ class check_channel_oauth2 final : public grpc_channel_credentials { 0); return nullptr; } + + private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } }; } // namespace @@ -642,6 +649,13 @@ class check_channel_oauth2_google_iam final : public grpc_channel_credentials { 0); return nullptr; } + + private: + int cmp_impl(const grpc_channel_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } }; } // namespace @@ -1765,6 +1779,13 @@ struct fake_call_creds : public grpc_call_credentials { void cancel_get_request_metadata( grpc_core::CredentialsMetadataArray* /*md_array*/, grpc_error_handle /*error*/) override {} + + private: + int cmp_impl(const grpc_call_credentials* other) const override { + // TODO(yashykt): Check if we can do something better here + return grpc_core::QsortCompare( + static_cast(this), other); + } }; static void test_google_default_creds_not_default(void) { @@ -3544,7 +3565,7 @@ test_external_account_credentials_create_failure_invalid_workforce_pool_audience static void test_insecure_credentials_compare_success(void) { auto* insecure_creds_1 = grpc_insecure_credentials_create(); auto* insecure_creds_2 = grpc_insecure_credentials_create(); - GPR_ASSERT(grpc_core::QsortCompare(insecure_creds_1, insecure_creds_2) == 0); + GPR_ASSERT(insecure_creds_1->cmp(insecure_creds_2) == 0); grpc_arg arg_1 = grpc_channel_credentials_to_arg(insecure_creds_1); grpc_channel_args args_1 = {1, &arg_1}; grpc_arg arg_2 = grpc_channel_credentials_to_arg(insecure_creds_2); @@ -3557,7 +3578,8 @@ static void test_insecure_credentials_compare_success(void) { static void test_insecure_credentials_compare_failure(void) { auto* insecure_creds = grpc_insecure_credentials_create(); auto* fake_creds = grpc_fake_transport_security_credentials_create(); - GPR_ASSERT(grpc_core::QsortCompare(insecure_creds, fake_creds) != 0); + GPR_ASSERT(insecure_creds->cmp(fake_creds) != 0); + GPR_ASSERT(fake_creds->cmp(insecure_creds) != 0); grpc_arg arg_1 = grpc_channel_credentials_to_arg(insecure_creds); grpc_channel_args args_1 = {1, &arg_1}; grpc_arg arg_2 = grpc_channel_credentials_to_arg(fake_creds); @@ -3567,6 +3589,19 @@ static void test_insecure_credentials_compare_failure(void) { grpc_channel_credentials_release(fake_creds); } +static void test_fake_call_credentials_compare_success(void) { + auto call_creds = grpc_core::MakeRefCounted(); + GPR_ASSERT(call_creds->cmp(call_creds.get()) == 0); +} + +static void test_fake_call_credentials_compare_failure(void) { + auto fake_creds = grpc_core::MakeRefCounted(); + auto* md_creds = grpc_md_only_test_credentials_create("key", "value", false); + GPR_ASSERT(fake_creds->cmp(md_creds) != 0); + GPR_ASSERT(md_creds->cmp(fake_creds.get()) != 0); + grpc_call_credentials_release(md_creds); +} + int main(int argc, char** argv) { grpc::testing::TestEnvironment env(argc, argv); grpc_init(); @@ -3647,6 +3682,8 @@ int main(int argc, char** argv) { test_external_account_credentials_create_failure_invalid_workforce_pool_audience(); test_insecure_credentials_compare_success(); test_insecure_credentials_compare_failure(); + test_fake_call_credentials_compare_success(); + test_fake_call_credentials_compare_failure(); grpc_shutdown(); return 0; }