client_channel, client_auth: rewrite disallowed status codes from the control plane (#30789)

* client_channel: rewrite illegal status codes from control plane

* rewrite illegal status codes for call creds

* move fail_lb policy out of retry_lb_fail test so it can be reused

* test resolver and LB policy status rewrites

* add test for ConfigSelector status rewriting

* attempt to add client_auth filter unit test

* fix client_auth_filter test

* cleanup test

* fix build

* fix some memory leaks

* Automated change: Fix sanity tests

* Update client_auth_filter_test.cc

* fix build

* code review comments

* clang-tidy

Co-authored-by: markdroth <markdroth@users.noreply.github.com>
Co-authored-by: Craig Tiller <ctiller@google.com>
pull/30820/head
Mark D. Roth 2 years ago committed by GitHub
parent 2142183ef4
commit bf9304ef17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      BUILD
  2. 36
      CMakeLists.txt
  3. 11
      build_autogenerated.yaml
  4. 24
      src/core/ext/filters/client_channel/client_channel.cc
  5. 27
      src/core/lib/channel/status_util.cc
  6. 10
      src/core/lib/channel/status_util.h
  7. 2
      src/core/lib/promise/context.h
  8. 13
      src/core/lib/security/transport/client_auth_filter.cc
  9. 86
      test/core/end2end/tests/retry_lb_fail.cc
  10. 14
      test/core/filters/BUILD
  11. 194
      test/core/filters/client_auth_filter_test.cc
  12. 78
      test/core/util/test_lb_policies.cc
  13. 8
      test/core/util/test_lb_policies.h
  14. 93
      test/cpp/end2end/client_lb_end2end_test.cc
  15. 24
      tools/run_tests/generated/tests.json

@ -6205,6 +6205,7 @@ grpc_cc_library(
"activity",
"arena",
"arena_promise",
"basic_seq",
"channel_args",
"channel_fwd",
"closure",
@ -6230,6 +6231,7 @@ grpc_cc_library(
"ref_counted_ptr",
"resource_quota",
"resource_quota_trace",
"seq",
"slice",
"slice_refcount",
"try_seq",

36
CMakeLists.txt generated

@ -886,6 +886,7 @@ if(gRPC_BUILD_TESTS)
add_dependencies(buildtests_cxx check_gcp_environment_windows_test)
add_dependencies(buildtests_cxx chunked_vector_test)
add_dependencies(buildtests_cxx cli_call_test)
add_dependencies(buildtests_cxx client_auth_filter_test)
add_dependencies(buildtests_cxx client_authority_filter_test)
add_dependencies(buildtests_cxx client_callback_end2end_test)
if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX)
@ -7678,6 +7679,41 @@ target_link_libraries(cli_call_test
)
endif()
if(gRPC_BUILD_TESTS)
add_executable(client_auth_filter_test
test/core/filters/client_auth_filter_test.cc
third_party/googletest/googletest/src/gtest-all.cc
third_party/googletest/googlemock/src/gmock-all.cc
)
target_include_directories(client_auth_filter_test
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/include
${_gRPC_ADDRESS_SORTING_INCLUDE_DIR}
${_gRPC_RE2_INCLUDE_DIR}
${_gRPC_SSL_INCLUDE_DIR}
${_gRPC_UPB_GENERATED_DIR}
${_gRPC_UPB_GRPC_GENERATED_DIR}
${_gRPC_UPB_INCLUDE_DIR}
${_gRPC_XXHASH_INCLUDE_DIR}
${_gRPC_ZLIB_INCLUDE_DIR}
third_party/googletest/googletest/include
third_party/googletest/googletest
third_party/googletest/googlemock/include
third_party/googletest/googlemock
${_gRPC_PROTO_GENS_DIR}
)
target_link_libraries(client_auth_filter_test
${_gRPC_PROTOBUF_LIBRARIES}
${_gRPC_ALLTARGETS_LIBRARIES}
grpc
)
endif()
if(gRPC_BUILD_TESTS)

@ -5000,6 +5000,17 @@ targets:
- test/cpp/util/service_describer.cc
deps:
- grpc++_test_util
- name: client_auth_filter_test
gtest: true
build: test
language: c++
headers:
- test/core/promise/test_context.h
src:
- test/core/filters/client_auth_filter_test.cc
deps:
- grpc
uses_polling: false
- name: client_authority_filter_test
gtest: true
build: test

@ -61,6 +61,7 @@
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/channel/channel_trace.h"
#include "src/core/lib/channel/status_util.h"
#include "src/core/lib/config/core_configuration.h"
#include "src/core/lib/debug/trace.h"
#include "src/core/lib/gpr/useful.h"
@ -1294,7 +1295,8 @@ void ClientChannel::OnResolverErrorLocked(absl::Status status) {
{
MutexLock lock(&resolution_mu_);
// Update resolver transient failure.
resolver_transient_failure_error_ = status;
resolver_transient_failure_error_ =
MaybeRewriteIllegalStatusCode(status, "resolver");
// Process calls that were queued waiting for the resolver result.
for (ResolverQueuedCall* call = resolver_queued_calls_; call != nullptr;
call = call->next) {
@ -1393,8 +1395,7 @@ void ClientChannel::UpdateServiceConfigInControlPlaneLocked(
RefCountedPtr<ConfigSelector> config_selector, std::string lb_policy_name) {
std::string service_config_json(service_config->json_string());
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_trace)) {
gpr_log(GPR_INFO,
"chand=%p: resolver returned updated service config: \"%s\"", this,
gpr_log(GPR_INFO, "chand=%p: using service config: \"%s\"", this,
service_config_json.c_str());
}
// Save service config.
@ -2165,7 +2166,8 @@ grpc_error_handle ClientChannel::CallData::ApplyServiceConfigToCallLocked(
ConfigSelector::CallConfig call_config =
config_selector->GetCallConfig({&path_, initial_metadata, arena_});
if (!call_config.status.ok()) {
return absl_status_to_grpc_error(call_config.status);
return absl_status_to_grpc_error(MaybeRewriteIllegalStatusCode(
std::move(call_config.status), "ConfigSelector"));
}
// Create a ClientChannelServiceConfigCallData for the call. This stores
// a ref to the ServiceConfig and caches the right set of parsed configs
@ -3158,11 +3160,8 @@ bool ClientChannel::LoadBalancedCall::PickSubchannelLocked(
// attempt's final status.
if (!initial_metadata_batch->GetOrCreatePointer(WaitForReady())
->value) {
grpc_error_handle lb_error =
absl_status_to_grpc_error(fail_pick->status);
*error = GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
"Failed to pick subchannel", &lb_error, 1);
GRPC_ERROR_UNREF(lb_error);
*error = absl_status_to_grpc_error(MaybeRewriteIllegalStatusCode(
std::move(fail_pick->status), "LB pick"));
MaybeRemoveCallFromLbQueuedCallsLocked();
return true;
}
@ -3178,9 +3177,10 @@ bool ClientChannel::LoadBalancedCall::PickSubchannelLocked(
gpr_log(GPR_INFO, "chand=%p lb_call=%p: LB pick dropped: %s",
chand_, this, drop_pick->status.ToString().c_str());
}
*error =
grpc_error_set_int(absl_status_to_grpc_error(drop_pick->status),
GRPC_ERROR_INT_LB_POLICY_DROP, 1);
*error = grpc_error_set_int(
absl_status_to_grpc_error(MaybeRewriteIllegalStatusCode(
std::move(drop_pick->status), "LB drop")),
GRPC_ERROR_INT_LB_POLICY_DROP, 1);
MaybeRemoveCallFromLbQueuedCallsLocked();
return true;
});

@ -22,6 +22,8 @@
#include <string.h>
#include "absl/strings/str_cat.h"
#include "src/core/lib/gpr/useful.h"
struct status_string_entry {
@ -109,3 +111,28 @@ bool grpc_status_code_from_int(int status_int, grpc_status_code* status) {
*status = static_cast<grpc_status_code>(status_int);
return true;
}
namespace grpc_core {
absl::Status MaybeRewriteIllegalStatusCode(absl::Status status,
absl::string_view source) {
switch (status.code()) {
// The set of disallowed codes, as per
// https://github.com/grpc/proposal/blob/master/A54-restrict-control-plane-status-codes.md.
case absl::StatusCode::kInvalidArgument:
case absl::StatusCode::kNotFound:
case absl::StatusCode::kAlreadyExists:
case absl::StatusCode::kFailedPrecondition:
case absl::StatusCode::kAborted:
case absl::StatusCode::kOutOfRange:
case absl::StatusCode::kDataLoss: {
return absl::InternalError(
absl::StrCat("Illegal status code from ", source,
"; original status: ", status.ToString()));
}
default:
return status;
}
}
} // namespace grpc_core

@ -21,6 +21,9 @@
#include <grpc/support/port_platform.h>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include <grpc/status.h>
/// If \a status_str is a valid status string, sets \a status to the
@ -59,6 +62,13 @@ class StatusCodeSet {
};
} // namespace internal
// Optionally rewrites a status as per
// https://github.com/grpc/proposal/blob/master/A54-restrict-control-plane-status-codes.md.
// The source parameter indicates where the status came from.
absl::Status MaybeRewriteIllegalStatusCode(absl::Status status,
absl::string_view source);
} // namespace grpc_core
#endif /* GRPC_CORE_LIB_CHANNEL_STATUS_UTIL_H */

@ -77,7 +77,7 @@ T* GetContext() {
// Given a promise and a context, return a promise that has that context set.
template <typename T, typename F>
promise_detail::WithContext<T, F> WithContext(F f, T* context) {
return promise_detail::WithContext<T, F>(f, context);
return promise_detail::WithContext<T, F>(std::move(f), context);
}
} // namespace grpc_core

@ -37,11 +37,14 @@
#include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/channel/context.h"
#include "src/core/lib/channel/promise_based_filter.h"
#include "src/core/lib/channel/status_util.h"
#include "src/core/lib/gprpp/debug_location.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/detail/basic_seq.h"
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/promise/try_seq.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/security/context/security_context.h"
@ -157,7 +160,15 @@ ArenaPromise<absl::StatusOr<CallArgs>> ClientAuthFilter::GetCallCredsMetadata(
auto client_initial_metadata = std::move(call_args.client_initial_metadata);
return TrySeq(
creds->GetRequestMetadata(std::move(client_initial_metadata), &args_),
Seq(creds->GetRequestMetadata(std::move(client_initial_metadata), &args_),
[](absl::StatusOr<ClientMetadataHandle> new_metadata) mutable {
if (!new_metadata.ok()) {
return absl::StatusOr<ClientMetadataHandle>(
MaybeRewriteIllegalStatusCode(new_metadata.status(),
"call credentials"));
}
return new_metadata;
}),
[call_args =
std::move(call_args)](ClientMetadataHandle new_metadata) mutable {
call_args.client_initial_metadata = std::move(new_metadata);

@ -18,13 +18,8 @@
#include <string.h>
#include <atomic>
#include <memory>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include <grpc/byte_buffer.h>
#include <grpc/grpc.h>
@ -36,81 +31,16 @@
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/config/core_configuration.h"
#include "src/core/lib/gpr/useful.h"
#include "src/core/lib/gprpp/orphanable.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/json/json.h"
#include "src/core/lib/load_balancing/lb_policy.h"
#include "src/core/lib/load_balancing/lb_policy_factory.h"
#include "src/core/lib/load_balancing/lb_policy_registry.h"
#include "test/core/end2end/cq_verifier.h"
#include "test/core/end2end/end2end_tests.h"
#include "test/core/util/test_config.h"
#include "test/core/util/test_lb_policies.h"
namespace grpc_core {
namespace {
constexpr absl::string_view kFailPolicyName = "fail_lb";
std::atomic<int> g_num_lb_picks;
class FailPolicy : public LoadBalancingPolicy {
public:
explicit FailPolicy(Args args) : LoadBalancingPolicy(std::move(args)) {}
absl::string_view name() const override { return kFailPolicyName; }
void UpdateLocked(UpdateArgs) override {
absl::Status status = absl::AbortedError("LB pick failed");
channel_control_helper()->UpdateState(
GRPC_CHANNEL_TRANSIENT_FAILURE, status,
absl::make_unique<FailPicker>(status));
}
void ResetBackoffLocked() override {}
void ShutdownLocked() override {}
private:
class FailPicker : public SubchannelPicker {
public:
explicit FailPicker(absl::Status status) : status_(status) {}
PickResult Pick(PickArgs /*args*/) override {
g_num_lb_picks.fetch_add(1);
return PickResult::Fail(status_);
}
private:
absl::Status status_;
};
};
class FailLbConfig : public LoadBalancingPolicy::Config {
public:
absl::string_view name() const override { return kFailPolicyName; }
};
class FailPolicyFactory : public LoadBalancingPolicyFactory {
public:
OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
LoadBalancingPolicy::Args args) const override {
return MakeOrphanable<FailPolicy>(std::move(args));
}
absl::string_view name() const override { return kFailPolicyName; }
absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json& /*json*/) const override {
return MakeRefCounted<FailLbConfig>();
}
};
void RegisterFailPolicy(CoreConfiguration::Builder* builder) {
builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
absl::make_unique<FailPolicyFactory>());
}
} // namespace
} // namespace grpc_core
static void* tag(intptr_t t) { return reinterpret_cast<void*>(t); }
@ -188,7 +118,7 @@ static void test_retry_lb_fail(grpc_end2end_test_config config) {
grpc_call_error error;
grpc_slice details;
grpc_core::g_num_lb_picks.store(0, std::memory_order_relaxed);
g_num_lb_picks.store(0, std::memory_order_relaxed);
grpc_arg args[] = {
grpc_channel_arg_integer_create(
@ -209,7 +139,7 @@ static void test_retry_lb_fail(grpc_end2end_test_config config) {
" \"initialBackoff\": \"1s\",\n"
" \"maxBackoff\": \"120s\",\n"
" \"backoffMultiplier\": 1.6,\n"
" \"retryableStatusCodes\": [ \"ABORTED\" ]\n"
" \"retryableStatusCodes\": [ \"UNAVAILABLE\" ]\n"
" }\n"
" } ]\n"
"}")),
@ -255,7 +185,7 @@ static void test_retry_lb_fail(grpc_end2end_test_config config) {
cqv.Expect(tag(2), true);
cqv.Verify();
GPR_ASSERT(status == GRPC_STATUS_ABORTED);
GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE);
GPR_ASSERT(0 == grpc_slice_str_cmp(details, "LB pick failed"));
grpc_slice_unref(details);
@ -266,7 +196,7 @@ static void test_retry_lb_fail(grpc_end2end_test_config config) {
grpc_call_unref(c);
int num_picks = grpc_core::g_num_lb_picks.load(std::memory_order_relaxed);
int num_picks = g_num_lb_picks.load(std::memory_order_relaxed);
gpr_log(GPR_INFO, "NUM LB PICKS: %d", num_picks);
GPR_ASSERT(num_picks == 2);
@ -280,5 +210,9 @@ void retry_lb_fail(grpc_end2end_test_config config) {
}
void retry_lb_fail_pre_init(void) {
grpc_core::CoreConfiguration::RegisterBuilder(grpc_core::RegisterFailPolicy);
grpc_core::CoreConfiguration::RegisterBuilder(
[](grpc_core::CoreConfiguration::Builder* builder) {
grpc_core::RegisterFailLoadBalancingPolicy(
builder, absl::UnavailableError("LB pick failed"), &g_num_lb_picks);
});
}

@ -19,6 +19,20 @@ licenses(["notice"])
grpc_package(name = "test/core/filters")
grpc_cc_test(
name = "client_auth_filter_test",
srcs = ["client_auth_filter_test.cc"],
external_deps = ["gtest"],
language = "c++",
uses_event_engine = False,
uses_polling = False,
deps = [
"//:grpc",
"//:grpc_security_base",
"//test/core/promise:test_context",
],
)
grpc_cc_test(
name = "client_authority_filter_test",
srcs = ["client_authority_filter_test.cc"],

@ -0,0 +1,194 @@
// Copyright 2022 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/resource_quota/resource_quota.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/credentials/fake/fake_credentials.h"
#include "src/core/lib/security/security_connector/fake/fake_security_connector.h"
#include "src/core/lib/security/transport/auth_filters.h"
#include "test/core/promise/test_context.h"
// TODO(roth): Need to add a lot more tests here. I created this file
// as part of adding a feature, and I added tests only for the feature I
// was adding. When we have time, we need to go back and write
// comprehensive tests for all of the functionality in the filter.
namespace grpc_core {
namespace {
class ClientAuthFilterTest : public ::testing::Test {
protected:
class FailCallCreds : public grpc_call_credentials {
public:
explicit FailCallCreds(absl::Status status)
: grpc_call_credentials(GRPC_SECURITY_NONE),
status_(std::move(status)) {}
UniqueTypeName type() const override {
static UniqueTypeName::Factory kFactory("FailCallCreds");
return kFactory.Create();
}
ArenaPromise<absl::StatusOr<ClientMetadataHandle>> GetRequestMetadata(
ClientMetadataHandle /*initial_metadata*/,
const GetRequestMetadataArgs* /*args*/) override {
return Immediate<absl::StatusOr<ClientMetadataHandle>>(status_);
}
int cmp_impl(const grpc_call_credentials* other) const override {
return QsortCompare(
status_.ToString(),
static_cast<const FailCallCreds*>(other)->status_.ToString());
}
private:
absl::Status status_;
};
ClientAuthFilterTest()
: memory_allocator_(
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator(
"test")),
arena_(MakeScopedArena(1024, &memory_allocator_)),
initial_metadata_batch_(arena_.get()),
trailing_metadata_batch_(arena_.get()),
target_(Slice::FromStaticString("localhost:1234")),
channel_creds_(grpc_fake_transport_security_credentials_create()) {
initial_metadata_batch_.Set(HttpAuthorityMetadata(), target_.Ref());
}
~ClientAuthFilterTest() override {
for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) {
if (call_context_[i].destroy != nullptr) {
call_context_[i].destroy(call_context_[i].value);
}
}
}
ChannelArgs MakeChannelArgs(absl::Status status_for_call_creds) {
ChannelArgs args;
auto security_connector = channel_creds_->create_security_connector(
status_for_call_creds.ok()
? nullptr
: MakeRefCounted<FailCallCreds>(std::move(status_for_call_creds)),
std::string(target_.as_string_view()).c_str(), &args);
auto auth_context = MakeRefCounted<grpc_auth_context>(nullptr);
absl::string_view security_level = "TSI_SECURITY_NONE";
auth_context->add_property(GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME,
security_level.data(), security_level.size());
return args.SetObject(std::move(security_connector))
.SetObject(std::move(auth_context));
}
MemoryAllocator memory_allocator_;
ScopedArenaPtr arena_;
grpc_metadata_batch initial_metadata_batch_;
grpc_metadata_batch trailing_metadata_batch_;
Slice target_;
RefCountedPtr<grpc_channel_credentials> channel_creds_;
grpc_call_context_element call_context_[GRPC_CONTEXT_COUNT];
};
TEST_F(ClientAuthFilterTest, CreateFailsWithoutRequiredChannelArgs) {
EXPECT_FALSE(
ClientAuthFilter::Create(ChannelArgs(), ChannelFilter::Args()).ok());
}
TEST_F(ClientAuthFilterTest, CreateSucceeds) {
auto filter = ClientAuthFilter::Create(MakeChannelArgs(absl::OkStatus()),
ChannelFilter::Args());
EXPECT_TRUE(filter.ok()) << filter.status();
}
TEST_F(ClientAuthFilterTest, CallCredsFails) {
auto filter = ClientAuthFilter::Create(
MakeChannelArgs(absl::UnauthenticatedError("access denied")),
ChannelFilter::Args());
// TODO(ctiller): use Activity here, once it's ready.
TestContext<Arena> context(arena_.get());
TestContext<grpc_call_context_element> promise_call_context(call_context_);
auto promise = filter->MakeCallPromise(
CallArgs{
ClientMetadataHandle::TestOnlyWrap(&initial_metadata_batch_),
nullptr,
},
[&](CallArgs /*call_args*/) {
return ArenaPromise<ServerMetadataHandle>(
[&]() -> Poll<ServerMetadataHandle> {
return ServerMetadataHandle::TestOnlyWrap(
&trailing_metadata_batch_);
});
});
auto result = promise();
ServerMetadataHandle* server_metadata =
absl::get_if<ServerMetadataHandle>(&result);
ASSERT_TRUE(server_metadata != nullptr);
auto status_md = (*server_metadata)->get(GrpcStatusMetadata());
ASSERT_TRUE(status_md.has_value());
EXPECT_EQ(*status_md, GRPC_STATUS_UNAUTHENTICATED);
const Slice* message_md =
(*server_metadata)->get_pointer(GrpcMessageMetadata());
ASSERT_TRUE(message_md != nullptr);
EXPECT_EQ(message_md->as_string_view(), "access denied");
(*server_metadata)->~ServerMetadata();
}
TEST_F(ClientAuthFilterTest, RewritesInvalidStatusFromCallCreds) {
auto filter = ClientAuthFilter::Create(
MakeChannelArgs(absl::AbortedError("nope")), ChannelFilter::Args());
// TODO(ctiller): use Activity here, once it's ready.
TestContext<Arena> context(arena_.get());
TestContext<grpc_call_context_element> promise_call_context(call_context_);
auto promise = filter->MakeCallPromise(
CallArgs{
ClientMetadataHandle::TestOnlyWrap(&initial_metadata_batch_),
nullptr,
},
[&](CallArgs /*call_args*/) {
return ArenaPromise<ServerMetadataHandle>(
[&]() -> Poll<ServerMetadataHandle> {
return ServerMetadataHandle::TestOnlyWrap(
&trailing_metadata_batch_);
});
});
auto result = promise();
ServerMetadataHandle* server_metadata =
absl::get_if<ServerMetadataHandle>(&result);
ASSERT_TRUE(server_metadata != nullptr);
auto status_md = (*server_metadata)->get(GrpcStatusMetadata());
ASSERT_TRUE(status_md.has_value());
EXPECT_EQ(*status_md, GRPC_STATUS_INTERNAL);
const Slice* message_md =
(*server_metadata)->get_pointer(GrpcMessageMetadata());
ASSERT_TRUE(message_md != nullptr);
EXPECT_EQ(message_md->as_string_view(),
"Illegal status code from call credentials; original status: "
"ABORTED: nope");
(*server_metadata)->~ServerMetadata();
}
} // namespace
} // namespace grpc_core
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
grpc_init();
int retval = RUN_ALL_TESTS();
grpc_shutdown();
return retval;
}

@ -656,6 +656,77 @@ class OobBackendMetricTestFactory : public LoadBalancingPolicyFactory {
OobBackendMetricCallback cb_;
};
//
// FailLoadBalancingPolicy
//
constexpr char kFailPolicyName[] = "fail_lb";
class FailPolicy : public LoadBalancingPolicy {
public:
FailPolicy(Args args, absl::Status status, std::atomic<int>* pick_counter)
: LoadBalancingPolicy(std::move(args)),
status_(std::move(status)),
pick_counter_(pick_counter) {}
absl::string_view name() const override { return kFailPolicyName; }
void UpdateLocked(UpdateArgs) override {
channel_control_helper()->UpdateState(
GRPC_CHANNEL_TRANSIENT_FAILURE, status_,
absl::make_unique<FailPicker>(status_, pick_counter_));
}
void ResetBackoffLocked() override {}
void ShutdownLocked() override {}
private:
class FailPicker : public SubchannelPicker {
public:
FailPicker(absl::Status status, std::atomic<int>* pick_counter)
: status_(std::move(status)), pick_counter_(pick_counter) {}
PickResult Pick(PickArgs /*args*/) override {
if (pick_counter_ != nullptr) pick_counter_->fetch_add(1);
return PickResult::Fail(status_);
}
private:
absl::Status status_;
std::atomic<int>* pick_counter_;
};
absl::Status status_;
std::atomic<int>* pick_counter_;
};
class FailLbConfig : public LoadBalancingPolicy::Config {
public:
absl::string_view name() const override { return kFailPolicyName; }
};
class FailLbFactory : public LoadBalancingPolicyFactory {
public:
FailLbFactory(absl::Status status, std::atomic<int>* pick_counter)
: status_(std::move(status)), pick_counter_(pick_counter) {}
OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
LoadBalancingPolicy::Args args) const override {
return MakeOrphanable<FailPolicy>(std::move(args), status_, pick_counter_);
}
absl::string_view name() const override { return kFailPolicyName; }
absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json& /*json*/) const override {
return MakeRefCounted<FailLbConfig>();
}
private:
absl::Status status_;
std::atomic<int>* pick_counter_;
};
} // namespace
void RegisterTestPickArgsLoadBalancingPolicy(
@ -691,4 +762,11 @@ void RegisterOobBackendMetricTestLoadBalancingPolicy(
absl::make_unique<OobBackendMetricTestFactory>(std::move(cb)));
}
void RegisterFailLoadBalancingPolicy(CoreConfiguration::Builder* builder,
absl::Status status,
std::atomic<int>* pick_counter) {
builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
absl::make_unique<FailLbFactory>(std::move(status), pick_counter));
}
} // namespace grpc_core

@ -19,6 +19,7 @@
#include <grpc/support/port_platform.h>
#include <atomic>
#include <functional>
#include <string>
#include <utility>
@ -83,6 +84,13 @@ using OobBackendMetricCallback =
void RegisterOobBackendMetricTestLoadBalancingPolicy(
CoreConfiguration::Builder* builder, OobBackendMetricCallback cb);
// Registers an LB policy called "fail_lb" that fails all picks with the
// specified status. If pick_counter is non-null, it will be
// incremented for each pick.
void RegisterFailLoadBalancingPolicy(CoreConfiguration::Builder* builder,
absl::Status status,
std::atomic<int>* pick_counter = nullptr);
} // namespace grpc_core
#endif // GRPC_TEST_CORE_UTIL_TEST_LB_POLICIES_H

@ -46,6 +46,7 @@
#include <grpcpp/server_builder.h>
#include "src/core/ext/filters/client_channel/backup_poller.h"
#include "src/core/ext/filters/client_channel/config_selector.h"
#include "src/core/ext/filters/client_channel/global_subchannel_pool.h"
#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h"
#include "src/core/lib/address_utils/parse_address.h"
@ -176,6 +177,11 @@ class FakeResolverResponseGeneratorWrapper {
response_generator_->SetFailureOnReresolution();
}
void SetResponse(grpc_core::Resolver::Result result) {
grpc_core::ExecCtx exec_ctx;
response_generator_->SetResponse(std::move(result));
}
grpc_core::FakeResolverResponseGenerator* Get() const {
return response_generator_.get();
}
@ -2684,6 +2690,93 @@ TEST_F(OobBackendMetricTest, Basic) {
}
}
//
// tests rewriting of control plane status codes
//
class ControlPlaneStatusRewritingTest : public ClientLbEnd2endTest {
protected:
static void SetUpTestCase() {
grpc_core::CoreConfiguration::Reset();
grpc_core::CoreConfiguration::RegisterBuilder(
[](grpc_core::CoreConfiguration::Builder* builder) {
grpc_core::RegisterFailLoadBalancingPolicy(
builder, absl::AbortedError("nope"));
});
grpc_init();
}
static void TearDownTestCase() {
grpc_shutdown();
grpc_core::CoreConfiguration::Reset();
}
};
TEST_F(ControlPlaneStatusRewritingTest, RewritesFromLb) {
// Start client.
auto response_generator = BuildResolverResponseGenerator();
auto channel = BuildChannel("fail_lb", response_generator);
auto stub = BuildStub(channel);
response_generator.SetNextResolution(GetServersPorts());
// Send an RPC, verify that status was rewritten.
CheckRpcSendFailure(
DEBUG_LOCATION, stub, StatusCode::INTERNAL,
"Illegal status code from LB pick; original status: ABORTED: nope");
}
TEST_F(ControlPlaneStatusRewritingTest, RewritesFromResolver) {
// Start client.
auto response_generator = BuildResolverResponseGenerator();
auto channel = BuildChannel("pick_first", response_generator);
auto stub = BuildStub(channel);
grpc_core::Resolver::Result result;
result.service_config = absl::AbortedError("nope");
result.addresses.emplace();
response_generator.SetResponse(std::move(result));
// Send an RPC, verify that status was rewritten.
CheckRpcSendFailure(
DEBUG_LOCATION, stub, StatusCode::INTERNAL,
"Illegal status code from resolver; original status: ABORTED: nope");
}
TEST_F(ControlPlaneStatusRewritingTest, RewritesFromConfigSelector) {
class FailConfigSelector : public grpc_core::ConfigSelector {
public:
explicit FailConfigSelector(absl::Status status)
: status_(std::move(status)) {}
const char* name() const override { return "FailConfigSelector"; }
bool Equals(const ConfigSelector* other) const override {
return status_ == static_cast<const FailConfigSelector*>(other)->status_;
}
CallConfig GetCallConfig(GetCallConfigArgs /*args*/) override {
CallConfig config;
config.status = status_;
return config;
}
private:
absl::Status status_;
};
// Start client.
auto response_generator = BuildResolverResponseGenerator();
auto channel = BuildChannel("pick_first", response_generator);
auto stub = BuildStub(channel);
auto config_selector =
grpc_core::MakeRefCounted<FailConfigSelector>(absl::AbortedError("nope"));
grpc_core::Resolver::Result result;
result.addresses.emplace();
result.service_config =
grpc_core::ServiceConfigImpl::Create(grpc_core::ChannelArgs(), "{}");
ASSERT_TRUE(result.service_config.ok()) << result.service_config.status();
result.args = grpc_core::ChannelArgs().SetObject(config_selector);
response_generator.SetResponse(std::move(result));
// Send an RPC, verify that status was rewritten.
CheckRpcSendFailure(
DEBUG_LOCATION, stub, StatusCode::INTERNAL,
"Illegal status code from ConfigSelector; original status: "
"ABORTED: nope");
}
} // namespace
} // namespace testing
} // namespace grpc

@ -1825,6 +1825,30 @@
],
"uses_polling": true
},
{
"args": [],
"benchmark": false,
"ci_platforms": [
"linux",
"mac",
"posix",
"windows"
],
"cpu_cost": 1.0,
"exclude_configs": [],
"exclude_iomgrs": [],
"flaky": false,
"gtest": true,
"language": "c++",
"name": "client_auth_filter_test",
"platforms": [
"linux",
"mac",
"posix",
"windows"
],
"uses_polling": false
},
{
"args": [],
"benchmark": false,

Loading…
Cancel
Save