From bf9304ef1710aca5151c603431589dc0acaf8f5c Mon Sep 17 00:00:00 2001 From: "Mark D. Roth" Date: Thu, 1 Sep 2022 15:15:14 -0700 Subject: [PATCH] client_channel, client_auth: rewrite disallowed status codes from the control plane (#30789) * client_channel: rewrite illegal status codes from control plane * rewrite illegal status codes for call creds * move fail_lb policy out of retry_lb_fail test so it can be reused * test resolver and LB policy status rewrites * add test for ConfigSelector status rewriting * attempt to add client_auth filter unit test * fix client_auth_filter test * cleanup test * fix build * fix some memory leaks * Automated change: Fix sanity tests * Update client_auth_filter_test.cc * fix build * code review comments * clang-tidy Co-authored-by: markdroth Co-authored-by: Craig Tiller --- BUILD | 2 + CMakeLists.txt | 36 ++++ build_autogenerated.yaml | 11 + .../filters/client_channel/client_channel.cc | 24 +-- src/core/lib/channel/status_util.cc | 27 +++ src/core/lib/channel/status_util.h | 10 + src/core/lib/promise/context.h | 2 +- .../security/transport/client_auth_filter.cc | 13 +- test/core/end2end/tests/retry_lb_fail.cc | 86 +------- test/core/filters/BUILD | 14 ++ test/core/filters/client_auth_filter_test.cc | 194 ++++++++++++++++++ test/core/util/test_lb_policies.cc | 78 +++++++ test/core/util/test_lb_policies.h | 8 + test/cpp/end2end/client_lb_end2end_test.cc | 93 +++++++++ tools/run_tests/generated/tests.json | 24 +++ 15 files changed, 532 insertions(+), 90 deletions(-) create mode 100644 test/core/filters/client_auth_filter_test.cc diff --git a/BUILD b/BUILD index 6d7ff2d2167..b2e3b413be1 100644 --- a/BUILD +++ b/BUILD @@ -6205,6 +6205,7 @@ grpc_cc_library( "activity", "arena", "arena_promise", + "basic_seq", "channel_args", "channel_fwd", "closure", @@ -6230,6 +6231,7 @@ grpc_cc_library( "ref_counted_ptr", "resource_quota", "resource_quota_trace", + "seq", "slice", "slice_refcount", "try_seq", diff --git a/CMakeLists.txt b/CMakeLists.txt index fb17f1ceafd..63da4c8b20e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -886,6 +886,7 @@ if(gRPC_BUILD_TESTS) add_dependencies(buildtests_cxx check_gcp_environment_windows_test) add_dependencies(buildtests_cxx chunked_vector_test) add_dependencies(buildtests_cxx cli_call_test) + add_dependencies(buildtests_cxx client_auth_filter_test) add_dependencies(buildtests_cxx client_authority_filter_test) add_dependencies(buildtests_cxx client_callback_end2end_test) if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) @@ -7678,6 +7679,41 @@ target_link_libraries(cli_call_test ) +endif() +if(gRPC_BUILD_TESTS) + +add_executable(client_auth_filter_test + test/core/filters/client_auth_filter_test.cc + third_party/googletest/googletest/src/gtest-all.cc + third_party/googletest/googlemock/src/gmock-all.cc +) + +target_include_directories(client_auth_filter_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_XXHASH_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(client_auth_filter_test + ${_gRPC_PROTOBUF_LIBRARIES} + ${_gRPC_ALLTARGETS_LIBRARIES} + grpc +) + + endif() if(gRPC_BUILD_TESTS) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 0af1fb72a80..e3591df5821 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -5000,6 +5000,17 @@ targets: - test/cpp/util/service_describer.cc deps: - grpc++_test_util +- name: client_auth_filter_test + gtest: true + build: test + language: c++ + headers: + - test/core/promise/test_context.h + src: + - test/core/filters/client_auth_filter_test.cc + deps: + - grpc + uses_polling: false - name: client_authority_filter_test gtest: true build: test diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index f31884cf575..39c69e24acf 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -61,6 +61,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/channel_trace.h" +#include "src/core/lib/channel/status_util.h" #include "src/core/lib/config/core_configuration.h" #include "src/core/lib/debug/trace.h" #include "src/core/lib/gpr/useful.h" @@ -1294,7 +1295,8 @@ void ClientChannel::OnResolverErrorLocked(absl::Status status) { { MutexLock lock(&resolution_mu_); // Update resolver transient failure. - resolver_transient_failure_error_ = status; + resolver_transient_failure_error_ = + MaybeRewriteIllegalStatusCode(status, "resolver"); // Process calls that were queued waiting for the resolver result. for (ResolverQueuedCall* call = resolver_queued_calls_; call != nullptr; call = call->next) { @@ -1393,8 +1395,7 @@ void ClientChannel::UpdateServiceConfigInControlPlaneLocked( RefCountedPtr config_selector, std::string lb_policy_name) { std::string service_config_json(service_config->json_string()); if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_trace)) { - gpr_log(GPR_INFO, - "chand=%p: resolver returned updated service config: \"%s\"", this, + gpr_log(GPR_INFO, "chand=%p: using service config: \"%s\"", this, service_config_json.c_str()); } // Save service config. @@ -2165,7 +2166,8 @@ grpc_error_handle ClientChannel::CallData::ApplyServiceConfigToCallLocked( ConfigSelector::CallConfig call_config = config_selector->GetCallConfig({&path_, initial_metadata, arena_}); if (!call_config.status.ok()) { - return absl_status_to_grpc_error(call_config.status); + return absl_status_to_grpc_error(MaybeRewriteIllegalStatusCode( + std::move(call_config.status), "ConfigSelector")); } // Create a ClientChannelServiceConfigCallData for the call. This stores // a ref to the ServiceConfig and caches the right set of parsed configs @@ -3158,11 +3160,8 @@ bool ClientChannel::LoadBalancedCall::PickSubchannelLocked( // attempt's final status. if (!initial_metadata_batch->GetOrCreatePointer(WaitForReady()) ->value) { - grpc_error_handle lb_error = - absl_status_to_grpc_error(fail_pick->status); - *error = GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( - "Failed to pick subchannel", &lb_error, 1); - GRPC_ERROR_UNREF(lb_error); + *error = absl_status_to_grpc_error(MaybeRewriteIllegalStatusCode( + std::move(fail_pick->status), "LB pick")); MaybeRemoveCallFromLbQueuedCallsLocked(); return true; } @@ -3178,9 +3177,10 @@ bool ClientChannel::LoadBalancedCall::PickSubchannelLocked( gpr_log(GPR_INFO, "chand=%p lb_call=%p: LB pick dropped: %s", chand_, this, drop_pick->status.ToString().c_str()); } - *error = - grpc_error_set_int(absl_status_to_grpc_error(drop_pick->status), - GRPC_ERROR_INT_LB_POLICY_DROP, 1); + *error = grpc_error_set_int( + absl_status_to_grpc_error(MaybeRewriteIllegalStatusCode( + std::move(drop_pick->status), "LB drop")), + GRPC_ERROR_INT_LB_POLICY_DROP, 1); MaybeRemoveCallFromLbQueuedCallsLocked(); return true; }); diff --git a/src/core/lib/channel/status_util.cc b/src/core/lib/channel/status_util.cc index 6c4c13ca3b2..b943ab797ac 100644 --- a/src/core/lib/channel/status_util.cc +++ b/src/core/lib/channel/status_util.cc @@ -22,6 +22,8 @@ #include +#include "absl/strings/str_cat.h" + #include "src/core/lib/gpr/useful.h" struct status_string_entry { @@ -109,3 +111,28 @@ bool grpc_status_code_from_int(int status_int, grpc_status_code* status) { *status = static_cast(status_int); return true; } + +namespace grpc_core { + +absl::Status MaybeRewriteIllegalStatusCode(absl::Status status, + absl::string_view source) { + switch (status.code()) { + // The set of disallowed codes, as per + // https://github.com/grpc/proposal/blob/master/A54-restrict-control-plane-status-codes.md. + case absl::StatusCode::kInvalidArgument: + case absl::StatusCode::kNotFound: + case absl::StatusCode::kAlreadyExists: + case absl::StatusCode::kFailedPrecondition: + case absl::StatusCode::kAborted: + case absl::StatusCode::kOutOfRange: + case absl::StatusCode::kDataLoss: { + return absl::InternalError( + absl::StrCat("Illegal status code from ", source, + "; original status: ", status.ToString())); + } + default: + return status; + } +} + +} // namespace grpc_core diff --git a/src/core/lib/channel/status_util.h b/src/core/lib/channel/status_util.h index e7524a09189..df0fc53f8ad 100644 --- a/src/core/lib/channel/status_util.h +++ b/src/core/lib/channel/status_util.h @@ -21,6 +21,9 @@ #include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + #include /// If \a status_str is a valid status string, sets \a status to the @@ -59,6 +62,13 @@ class StatusCodeSet { }; } // namespace internal + +// Optionally rewrites a status as per +// https://github.com/grpc/proposal/blob/master/A54-restrict-control-plane-status-codes.md. +// The source parameter indicates where the status came from. +absl::Status MaybeRewriteIllegalStatusCode(absl::Status status, + absl::string_view source); + } // namespace grpc_core #endif /* GRPC_CORE_LIB_CHANNEL_STATUS_UTIL_H */ diff --git a/src/core/lib/promise/context.h b/src/core/lib/promise/context.h index ad8a0e43b55..29127bbd4a6 100644 --- a/src/core/lib/promise/context.h +++ b/src/core/lib/promise/context.h @@ -77,7 +77,7 @@ T* GetContext() { // Given a promise and a context, return a promise that has that context set. template promise_detail::WithContext WithContext(F f, T* context) { - return promise_detail::WithContext(f, context); + return promise_detail::WithContext(std::move(f), context); } } // namespace grpc_core diff --git a/src/core/lib/security/transport/client_auth_filter.cc b/src/core/lib/security/transport/client_auth_filter.cc index 4ef8ff5a508..11117fc4a8a 100644 --- a/src/core/lib/security/transport/client_auth_filter.cc +++ b/src/core/lib/security/transport/client_auth_filter.cc @@ -37,11 +37,14 @@ #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/context.h" #include "src/core/lib/channel/promise_based_filter.h" +#include "src/core/lib/channel/status_util.h" #include "src/core/lib/gprpp/debug_location.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/detail/basic_seq.h" #include "src/core/lib/promise/promise.h" +#include "src/core/lib/promise/seq.h" #include "src/core/lib/promise/try_seq.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/security/context/security_context.h" @@ -157,7 +160,15 @@ ArenaPromise> ClientAuthFilter::GetCallCredsMetadata( auto client_initial_metadata = std::move(call_args.client_initial_metadata); return TrySeq( - creds->GetRequestMetadata(std::move(client_initial_metadata), &args_), + Seq(creds->GetRequestMetadata(std::move(client_initial_metadata), &args_), + [](absl::StatusOr new_metadata) mutable { + if (!new_metadata.ok()) { + return absl::StatusOr( + MaybeRewriteIllegalStatusCode(new_metadata.status(), + "call credentials")); + } + return new_metadata; + }), [call_args = std::move(call_args)](ClientMetadataHandle new_metadata) mutable { call_args.client_initial_metadata = std::move(new_metadata); diff --git a/test/core/end2end/tests/retry_lb_fail.cc b/test/core/end2end/tests/retry_lb_fail.cc index 610a2a9f174..ac8360b25cc 100644 --- a/test/core/end2end/tests/retry_lb_fail.cc +++ b/test/core/end2end/tests/retry_lb_fail.cc @@ -18,13 +18,8 @@ #include #include -#include -#include -#include "absl/memory/memory.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include #include @@ -36,81 +31,16 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/config/core_configuration.h" #include "src/core/lib/gpr/useful.h" -#include "src/core/lib/gprpp/orphanable.h" -#include "src/core/lib/gprpp/ref_counted_ptr.h" -#include "src/core/lib/json/json.h" -#include "src/core/lib/load_balancing/lb_policy.h" -#include "src/core/lib/load_balancing/lb_policy_factory.h" -#include "src/core/lib/load_balancing/lb_policy_registry.h" #include "test/core/end2end/cq_verifier.h" #include "test/core/end2end/end2end_tests.h" #include "test/core/util/test_config.h" +#include "test/core/util/test_lb_policies.h" -namespace grpc_core { namespace { -constexpr absl::string_view kFailPolicyName = "fail_lb"; - std::atomic g_num_lb_picks; -class FailPolicy : public LoadBalancingPolicy { - public: - explicit FailPolicy(Args args) : LoadBalancingPolicy(std::move(args)) {} - - absl::string_view name() const override { return kFailPolicyName; } - - void UpdateLocked(UpdateArgs) override { - absl::Status status = absl::AbortedError("LB pick failed"); - channel_control_helper()->UpdateState( - GRPC_CHANNEL_TRANSIENT_FAILURE, status, - absl::make_unique(status)); - } - - void ResetBackoffLocked() override {} - void ShutdownLocked() override {} - - private: - class FailPicker : public SubchannelPicker { - public: - explicit FailPicker(absl::Status status) : status_(status) {} - - PickResult Pick(PickArgs /*args*/) override { - g_num_lb_picks.fetch_add(1); - return PickResult::Fail(status_); - } - - private: - absl::Status status_; - }; -}; - -class FailLbConfig : public LoadBalancingPolicy::Config { - public: - absl::string_view name() const override { return kFailPolicyName; } -}; - -class FailPolicyFactory : public LoadBalancingPolicyFactory { - public: - OrphanablePtr CreateLoadBalancingPolicy( - LoadBalancingPolicy::Args args) const override { - return MakeOrphanable(std::move(args)); - } - - absl::string_view name() const override { return kFailPolicyName; } - - absl::StatusOr> - ParseLoadBalancingConfig(const Json& /*json*/) const override { - return MakeRefCounted(); - } -}; - -void RegisterFailPolicy(CoreConfiguration::Builder* builder) { - builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory( - absl::make_unique()); -} - } // namespace -} // namespace grpc_core static void* tag(intptr_t t) { return reinterpret_cast(t); } @@ -188,7 +118,7 @@ static void test_retry_lb_fail(grpc_end2end_test_config config) { grpc_call_error error; grpc_slice details; - grpc_core::g_num_lb_picks.store(0, std::memory_order_relaxed); + g_num_lb_picks.store(0, std::memory_order_relaxed); grpc_arg args[] = { grpc_channel_arg_integer_create( @@ -209,7 +139,7 @@ static void test_retry_lb_fail(grpc_end2end_test_config config) { " \"initialBackoff\": \"1s\",\n" " \"maxBackoff\": \"120s\",\n" " \"backoffMultiplier\": 1.6,\n" - " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " \"retryableStatusCodes\": [ \"UNAVAILABLE\" ]\n" " }\n" " } ]\n" "}")), @@ -255,7 +185,7 @@ static void test_retry_lb_fail(grpc_end2end_test_config config) { cqv.Expect(tag(2), true); cqv.Verify(); - GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); GPR_ASSERT(0 == grpc_slice_str_cmp(details, "LB pick failed")); grpc_slice_unref(details); @@ -266,7 +196,7 @@ static void test_retry_lb_fail(grpc_end2end_test_config config) { grpc_call_unref(c); - int num_picks = grpc_core::g_num_lb_picks.load(std::memory_order_relaxed); + int num_picks = g_num_lb_picks.load(std::memory_order_relaxed); gpr_log(GPR_INFO, "NUM LB PICKS: %d", num_picks); GPR_ASSERT(num_picks == 2); @@ -280,5 +210,9 @@ void retry_lb_fail(grpc_end2end_test_config config) { } void retry_lb_fail_pre_init(void) { - grpc_core::CoreConfiguration::RegisterBuilder(grpc_core::RegisterFailPolicy); + grpc_core::CoreConfiguration::RegisterBuilder( + [](grpc_core::CoreConfiguration::Builder* builder) { + grpc_core::RegisterFailLoadBalancingPolicy( + builder, absl::UnavailableError("LB pick failed"), &g_num_lb_picks); + }); } diff --git a/test/core/filters/BUILD b/test/core/filters/BUILD index eab0ec86450..28654c436a1 100644 --- a/test/core/filters/BUILD +++ b/test/core/filters/BUILD @@ -19,6 +19,20 @@ licenses(["notice"]) grpc_package(name = "test/core/filters") +grpc_cc_test( + name = "client_auth_filter_test", + srcs = ["client_auth_filter_test.cc"], + external_deps = ["gtest"], + language = "c++", + uses_event_engine = False, + uses_polling = False, + deps = [ + "//:grpc", + "//:grpc_security_base", + "//test/core/promise:test_context", + ], +) + grpc_cc_test( name = "client_authority_filter_test", srcs = ["client_authority_filter_test.cc"], diff --git a/test/core/filters/client_auth_filter_test.cc b/test/core/filters/client_auth_filter_test.cc new file mode 100644 index 00000000000..1d3929a07cd --- /dev/null +++ b/test/core/filters/client_auth_filter_test.cc @@ -0,0 +1,194 @@ +// Copyright 2022 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "src/core/lib/promise/promise.h" +#include "src/core/lib/resource_quota/resource_quota.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/core/lib/security/security_connector/fake/fake_security_connector.h" +#include "src/core/lib/security/transport/auth_filters.h" +#include "test/core/promise/test_context.h" + +// TODO(roth): Need to add a lot more tests here. I created this file +// as part of adding a feature, and I added tests only for the feature I +// was adding. When we have time, we need to go back and write +// comprehensive tests for all of the functionality in the filter. + +namespace grpc_core { +namespace { + +class ClientAuthFilterTest : public ::testing::Test { + protected: + class FailCallCreds : public grpc_call_credentials { + public: + explicit FailCallCreds(absl::Status status) + : grpc_call_credentials(GRPC_SECURITY_NONE), + status_(std::move(status)) {} + + UniqueTypeName type() const override { + static UniqueTypeName::Factory kFactory("FailCallCreds"); + return kFactory.Create(); + } + + ArenaPromise> GetRequestMetadata( + ClientMetadataHandle /*initial_metadata*/, + const GetRequestMetadataArgs* /*args*/) override { + return Immediate>(status_); + } + + int cmp_impl(const grpc_call_credentials* other) const override { + return QsortCompare( + status_.ToString(), + static_cast(other)->status_.ToString()); + } + + private: + absl::Status status_; + }; + + ClientAuthFilterTest() + : memory_allocator_( + ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator( + "test")), + arena_(MakeScopedArena(1024, &memory_allocator_)), + initial_metadata_batch_(arena_.get()), + trailing_metadata_batch_(arena_.get()), + target_(Slice::FromStaticString("localhost:1234")), + channel_creds_(grpc_fake_transport_security_credentials_create()) { + initial_metadata_batch_.Set(HttpAuthorityMetadata(), target_.Ref()); + } + + ~ClientAuthFilterTest() override { + for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) { + if (call_context_[i].destroy != nullptr) { + call_context_[i].destroy(call_context_[i].value); + } + } + } + + ChannelArgs MakeChannelArgs(absl::Status status_for_call_creds) { + ChannelArgs args; + auto security_connector = channel_creds_->create_security_connector( + status_for_call_creds.ok() + ? nullptr + : MakeRefCounted(std::move(status_for_call_creds)), + std::string(target_.as_string_view()).c_str(), &args); + auto auth_context = MakeRefCounted(nullptr); + absl::string_view security_level = "TSI_SECURITY_NONE"; + auth_context->add_property(GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME, + security_level.data(), security_level.size()); + return args.SetObject(std::move(security_connector)) + .SetObject(std::move(auth_context)); + } + + MemoryAllocator memory_allocator_; + ScopedArenaPtr arena_; + grpc_metadata_batch initial_metadata_batch_; + grpc_metadata_batch trailing_metadata_batch_; + Slice target_; + RefCountedPtr channel_creds_; + grpc_call_context_element call_context_[GRPC_CONTEXT_COUNT]; +}; + +TEST_F(ClientAuthFilterTest, CreateFailsWithoutRequiredChannelArgs) { + EXPECT_FALSE( + ClientAuthFilter::Create(ChannelArgs(), ChannelFilter::Args()).ok()); +} + +TEST_F(ClientAuthFilterTest, CreateSucceeds) { + auto filter = ClientAuthFilter::Create(MakeChannelArgs(absl::OkStatus()), + ChannelFilter::Args()); + EXPECT_TRUE(filter.ok()) << filter.status(); +} + +TEST_F(ClientAuthFilterTest, CallCredsFails) { + auto filter = ClientAuthFilter::Create( + MakeChannelArgs(absl::UnauthenticatedError("access denied")), + ChannelFilter::Args()); + // TODO(ctiller): use Activity here, once it's ready. + TestContext context(arena_.get()); + TestContext promise_call_context(call_context_); + auto promise = filter->MakeCallPromise( + CallArgs{ + ClientMetadataHandle::TestOnlyWrap(&initial_metadata_batch_), + nullptr, + }, + [&](CallArgs /*call_args*/) { + return ArenaPromise( + [&]() -> Poll { + return ServerMetadataHandle::TestOnlyWrap( + &trailing_metadata_batch_); + }); + }); + auto result = promise(); + ServerMetadataHandle* server_metadata = + absl::get_if(&result); + ASSERT_TRUE(server_metadata != nullptr); + auto status_md = (*server_metadata)->get(GrpcStatusMetadata()); + ASSERT_TRUE(status_md.has_value()); + EXPECT_EQ(*status_md, GRPC_STATUS_UNAUTHENTICATED); + const Slice* message_md = + (*server_metadata)->get_pointer(GrpcMessageMetadata()); + ASSERT_TRUE(message_md != nullptr); + EXPECT_EQ(message_md->as_string_view(), "access denied"); + (*server_metadata)->~ServerMetadata(); +} + +TEST_F(ClientAuthFilterTest, RewritesInvalidStatusFromCallCreds) { + auto filter = ClientAuthFilter::Create( + MakeChannelArgs(absl::AbortedError("nope")), ChannelFilter::Args()); + // TODO(ctiller): use Activity here, once it's ready. + TestContext context(arena_.get()); + TestContext promise_call_context(call_context_); + auto promise = filter->MakeCallPromise( + CallArgs{ + ClientMetadataHandle::TestOnlyWrap(&initial_metadata_batch_), + nullptr, + }, + [&](CallArgs /*call_args*/) { + return ArenaPromise( + [&]() -> Poll { + return ServerMetadataHandle::TestOnlyWrap( + &trailing_metadata_batch_); + }); + }); + auto result = promise(); + ServerMetadataHandle* server_metadata = + absl::get_if(&result); + ASSERT_TRUE(server_metadata != nullptr); + auto status_md = (*server_metadata)->get(GrpcStatusMetadata()); + ASSERT_TRUE(status_md.has_value()); + EXPECT_EQ(*status_md, GRPC_STATUS_INTERNAL); + const Slice* message_md = + (*server_metadata)->get_pointer(GrpcMessageMetadata()); + ASSERT_TRUE(message_md != nullptr); + EXPECT_EQ(message_md->as_string_view(), + "Illegal status code from call credentials; original status: " + "ABORTED: nope"); + (*server_metadata)->~ServerMetadata(); +} + +} // namespace +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int retval = RUN_ALL_TESTS(); + grpc_shutdown(); + return retval; +} diff --git a/test/core/util/test_lb_policies.cc b/test/core/util/test_lb_policies.cc index bd1ac07d324..aad85eb157f 100644 --- a/test/core/util/test_lb_policies.cc +++ b/test/core/util/test_lb_policies.cc @@ -656,6 +656,77 @@ class OobBackendMetricTestFactory : public LoadBalancingPolicyFactory { OobBackendMetricCallback cb_; }; +// +// FailLoadBalancingPolicy +// + +constexpr char kFailPolicyName[] = "fail_lb"; + +class FailPolicy : public LoadBalancingPolicy { + public: + FailPolicy(Args args, absl::Status status, std::atomic* pick_counter) + : LoadBalancingPolicy(std::move(args)), + status_(std::move(status)), + pick_counter_(pick_counter) {} + + absl::string_view name() const override { return kFailPolicyName; } + + void UpdateLocked(UpdateArgs) override { + channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status_, + absl::make_unique(status_, pick_counter_)); + } + + void ResetBackoffLocked() override {} + void ShutdownLocked() override {} + + private: + class FailPicker : public SubchannelPicker { + public: + FailPicker(absl::Status status, std::atomic* pick_counter) + : status_(std::move(status)), pick_counter_(pick_counter) {} + + PickResult Pick(PickArgs /*args*/) override { + if (pick_counter_ != nullptr) pick_counter_->fetch_add(1); + return PickResult::Fail(status_); + } + + private: + absl::Status status_; + std::atomic* pick_counter_; + }; + + absl::Status status_; + std::atomic* pick_counter_; +}; + +class FailLbConfig : public LoadBalancingPolicy::Config { + public: + absl::string_view name() const override { return kFailPolicyName; } +}; + +class FailLbFactory : public LoadBalancingPolicyFactory { + public: + FailLbFactory(absl::Status status, std::atomic* pick_counter) + : status_(std::move(status)), pick_counter_(pick_counter) {} + + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args), status_, pick_counter_); + } + + absl::string_view name() const override { return kFailPolicyName; } + + absl::StatusOr> + ParseLoadBalancingConfig(const Json& /*json*/) const override { + return MakeRefCounted(); + } + + private: + absl::Status status_; + std::atomic* pick_counter_; +}; + } // namespace void RegisterTestPickArgsLoadBalancingPolicy( @@ -691,4 +762,11 @@ void RegisterOobBackendMetricTestLoadBalancingPolicy( absl::make_unique(std::move(cb))); } +void RegisterFailLoadBalancingPolicy(CoreConfiguration::Builder* builder, + absl::Status status, + std::atomic* pick_counter) { + builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory( + absl::make_unique(std::move(status), pick_counter)); +} + } // namespace grpc_core diff --git a/test/core/util/test_lb_policies.h b/test/core/util/test_lb_policies.h index 1cd85f27a1f..afec49b29df 100644 --- a/test/core/util/test_lb_policies.h +++ b/test/core/util/test_lb_policies.h @@ -19,6 +19,7 @@ #include +#include #include #include #include @@ -83,6 +84,13 @@ using OobBackendMetricCallback = void RegisterOobBackendMetricTestLoadBalancingPolicy( CoreConfiguration::Builder* builder, OobBackendMetricCallback cb); +// Registers an LB policy called "fail_lb" that fails all picks with the +// specified status. If pick_counter is non-null, it will be +// incremented for each pick. +void RegisterFailLoadBalancingPolicy(CoreConfiguration::Builder* builder, + absl::Status status, + std::atomic* pick_counter = nullptr); + } // namespace grpc_core #endif // GRPC_TEST_CORE_UTIL_TEST_LB_POLICIES_H diff --git a/test/cpp/end2end/client_lb_end2end_test.cc b/test/cpp/end2end/client_lb_end2end_test.cc index afa997499c9..e7c74ebacd0 100644 --- a/test/cpp/end2end/client_lb_end2end_test.cc +++ b/test/cpp/end2end/client_lb_end2end_test.cc @@ -46,6 +46,7 @@ #include #include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/config_selector.h" #include "src/core/ext/filters/client_channel/global_subchannel_pool.h" #include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" #include "src/core/lib/address_utils/parse_address.h" @@ -176,6 +177,11 @@ class FakeResolverResponseGeneratorWrapper { response_generator_->SetFailureOnReresolution(); } + void SetResponse(grpc_core::Resolver::Result result) { + grpc_core::ExecCtx exec_ctx; + response_generator_->SetResponse(std::move(result)); + } + grpc_core::FakeResolverResponseGenerator* Get() const { return response_generator_.get(); } @@ -2684,6 +2690,93 @@ TEST_F(OobBackendMetricTest, Basic) { } } +// +// tests rewriting of control plane status codes +// + +class ControlPlaneStatusRewritingTest : public ClientLbEnd2endTest { + protected: + static void SetUpTestCase() { + grpc_core::CoreConfiguration::Reset(); + grpc_core::CoreConfiguration::RegisterBuilder( + [](grpc_core::CoreConfiguration::Builder* builder) { + grpc_core::RegisterFailLoadBalancingPolicy( + builder, absl::AbortedError("nope")); + }); + grpc_init(); + } + + static void TearDownTestCase() { + grpc_shutdown(); + grpc_core::CoreConfiguration::Reset(); + } +}; + +TEST_F(ControlPlaneStatusRewritingTest, RewritesFromLb) { + // Start client. + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("fail_lb", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + // Send an RPC, verify that status was rewritten. + CheckRpcSendFailure( + DEBUG_LOCATION, stub, StatusCode::INTERNAL, + "Illegal status code from LB pick; original status: ABORTED: nope"); +} + +TEST_F(ControlPlaneStatusRewritingTest, RewritesFromResolver) { + // Start client. + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + grpc_core::Resolver::Result result; + result.service_config = absl::AbortedError("nope"); + result.addresses.emplace(); + response_generator.SetResponse(std::move(result)); + // Send an RPC, verify that status was rewritten. + CheckRpcSendFailure( + DEBUG_LOCATION, stub, StatusCode::INTERNAL, + "Illegal status code from resolver; original status: ABORTED: nope"); +} + +TEST_F(ControlPlaneStatusRewritingTest, RewritesFromConfigSelector) { + class FailConfigSelector : public grpc_core::ConfigSelector { + public: + explicit FailConfigSelector(absl::Status status) + : status_(std::move(status)) {} + const char* name() const override { return "FailConfigSelector"; } + bool Equals(const ConfigSelector* other) const override { + return status_ == static_cast(other)->status_; + } + CallConfig GetCallConfig(GetCallConfigArgs /*args*/) override { + CallConfig config; + config.status = status_; + return config; + } + + private: + absl::Status status_; + }; + // Start client. + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + auto config_selector = + grpc_core::MakeRefCounted(absl::AbortedError("nope")); + grpc_core::Resolver::Result result; + result.addresses.emplace(); + result.service_config = + grpc_core::ServiceConfigImpl::Create(grpc_core::ChannelArgs(), "{}"); + ASSERT_TRUE(result.service_config.ok()) << result.service_config.status(); + result.args = grpc_core::ChannelArgs().SetObject(config_selector); + response_generator.SetResponse(std::move(result)); + // Send an RPC, verify that status was rewritten. + CheckRpcSendFailure( + DEBUG_LOCATION, stub, StatusCode::INTERNAL, + "Illegal status code from ConfigSelector; original status: " + "ABORTED: nope"); +} + } // namespace } // namespace testing } // namespace grpc diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index 21a059e4a06..9ae6e0597db 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -1825,6 +1825,30 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "client_auth_filter_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": false + }, { "args": [], "benchmark": false,