diff --git a/src/core/lib/security/credentials/fake/fake_credentials.cc b/src/core/lib/security/credentials/fake/fake_credentials.cc index 46311fa122c..858ab6b41bb 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.cc +++ b/src/core/lib/security/credentials/fake/fake_credentials.cc @@ -32,9 +32,6 @@ /* -- Fake transport security credentials. -- */ -#define GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS \ - "grpc.fake_security.expected_targets" - static grpc_security_status fake_transport_security_create_security_connector( grpc_channel_credentials* c, grpc_call_credentials* call_creds, const char* target, const grpc_channel_args* args, diff --git a/src/core/lib/security/credentials/fake/fake_credentials.h b/src/core/lib/security/credentials/fake/fake_credentials.h index 5166e43167f..e89e6e24cca 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.h +++ b/src/core/lib/security/credentials/fake/fake_credentials.h @@ -23,6 +23,9 @@ #include "src/core/lib/security/credentials/credentials.h" +#define GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS \ + "grpc.fake_security.expected_targets" + /* -- Fake transport security credentials. -- */ /* Creates a fake transport security credentials object for testing. */ diff --git a/src/core/lib/security/security_connector/security_connector.cc b/src/core/lib/security/security_connector/security_connector.cc index a57c8953742..3cc151bec73 100644 --- a/src/core/lib/security/security_connector/security_connector.cc +++ b/src/core/lib/security/security_connector/security_connector.cc @@ -463,6 +463,15 @@ static bool fake_channel_check_call_host(grpc_channel_security_connector* sc, grpc_auth_context* auth_context, grpc_closure* on_call_host_checked, grpc_error** error) { + grpc_fake_channel_security_connector* c = + reinterpret_cast(sc); + if (c->is_lb_channel) { + // TODO(dgq): verify that the host (ie, authority header) matches that of + // the LB, as opposed to that of the backends. + } else { + // TODO(dgq): verify that the host (ie, authority header) matches that of + // the backend, not the LB's. + } return true; } diff --git a/test/cpp/end2end/grpclb_end2end_test.cc b/test/cpp/end2end/grpclb_end2end_test.cc index eb354907f6d..fcfe860b1cc 100644 --- a/test/cpp/end2end/grpclb_end2end_test.cc +++ b/test/cpp/end2end/grpclb_end2end_test.cc @@ -37,6 +37,10 @@ #include "src/core/lib/gpr/thd.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/cpp/server/secure_server_credentials.h" + +#include "src/cpp/client/secure_credentials.h" #include "test/core/util/port.h" #include "test/core/util/test_config.h" @@ -380,15 +384,21 @@ class GrpclbEnd2endTest : public ::testing::Test { SetNextResolution(addresses); } - void ResetStub(int fallback_timeout = 0) { + void ResetStub(int fallback_timeout = 0, grpc::string expected_targets = "") { ChannelArguments args; args.SetGrpclbFallbackTimeout(fallback_timeout); args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, response_generator_.get()); + if (!expected_targets.empty()) { + args.SetString(GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS, expected_targets); + } std::ostringstream uri; - uri << "fake:///servername_not_used"; - channel_ = - CreateCustomChannel(uri.str(), InsecureChannelCredentials(), args); + uri << "fake:///" << kApplicationTargetName_; + // TODO(dgq): templatize tests to run everything using both secure and + // insecure channel credentials. + std::shared_ptr creds(new SecureChannelCredentials( + grpc_fake_transport_security_credentials_create())); + channel_ = CreateCustomChannel(uri.str(), creds, args); stub_ = grpc::testing::EchoTestService::NewStub(channel_); } @@ -566,8 +576,9 @@ class GrpclbEnd2endTest : public ::testing::Test { std::ostringstream server_address; server_address << server_host << ":" << port_; ServerBuilder builder; - builder.AddListeningPort(server_address.str(), - InsecureServerCredentials()); + std::shared_ptr creds(new SecureServerCredentials( + grpc_fake_transport_security_server_credentials_create())); + builder.AddListeningPort(server_address.str(), creds); builder.RegisterService(service_); server_ = builder.BuildAndStart(); cond->notify_one(); @@ -600,6 +611,7 @@ class GrpclbEnd2endTest : public ::testing::Test { grpc_core::RefCountedPtr response_generator_; const grpc::string kRequestMessage_ = "Live long and prosper."; + const grpc::string kApplicationTargetName_ = "application_target_name"; }; class SingleBalancerTest : public GrpclbEnd2endTest { @@ -635,6 +647,48 @@ TEST_F(SingleBalancerTest, Vanilla) { EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); } +TEST_F(SingleBalancerTest, SecureNaming) { + ResetStub(0, kApplicationTargetName_ + ";lb"); + SetNextResolution({AddressData{balancer_servers_[0].port_, true, "lb"}}); + const size_t kNumRpcsPerAddress = 100; + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + 0); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, + backend_servers_[i].service_->request_count()); + } + balancers_[0]->NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancer_servers_[0].service_->request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancer_servers_[0].service_->response_count()); + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, SecureNamingDeathTest) { + ::testing::FLAGS_gtest_death_test_style = "threadsafe"; + // Make sure that we blow up (via abort() from the security connector) when + // the name from the balancer doesn't match expectations. + ASSERT_DEATH( + { + ResetStub(0, kApplicationTargetName_ + ";lb"); + SetNextResolution( + {AddressData{balancer_servers_[0].port_, true, "woops"}}); + channel_->WaitForConnected(grpc_timeout_seconds_to_deadline(1)); + }, + ""); +} + TEST_F(SingleBalancerTest, InitiallyEmptyServerlist) { SetNextResolutionAllBalancers(); const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor();