Add tests that expose a race and lock cycle in alts client, fix TSAN issues

pull/20596/head
Alexander Polcyn 5 years ago
parent 6dde2f43f7
commit 6bb8629879
  1. 63
      CMakeLists.txt
  2. 91
      Makefile
  3. 19
      build.yaml
  4. 8
      src/core/lib/security/credentials/alts/alts_credentials.cc
  5. 21
      src/core/lib/security/security_connector/alts/alts_security_connector.cc
  6. 4
      src/core/lib/security/security_connector/alts/alts_security_connector.h
  7. 9
      src/core/lib/surface/channel.cc
  8. 4
      src/core/lib/surface/channel.h
  9. 26
      src/core/tsi/alts/handshaker/alts_handshaker_client.cc
  10. 164
      src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc
  11. 6
      test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h
  12. 19
      test/core/tsi/alts/handshaker/BUILD
  13. 476
      test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc
  14. 36
      test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc
  15. 18
      tools/run_tests/generated/tests.json

@ -324,6 +324,12 @@ protobuf_generate_grpc_cpp(
protobuf_generate_grpc_cpp(
src/proto/grpc/testing/xds/orca_load_report_for_test.proto
)
protobuf_generate_grpc_cpp(
test/core/tsi/alts/fake_handshaker/handshaker.proto
)
protobuf_generate_grpc_cpp(
test/core/tsi/alts/fake_handshaker/transport_security_common.proto
)
if(gRPC_BUILD_TESTS)
add_custom_target(buildtests_c)
@ -594,6 +600,9 @@ if(gRPC_BUILD_TESTS)
add_custom_target(buildtests_cxx)
add_dependencies(buildtests_cxx alarm_test)
if(_gRPC_PLATFORM_LINUX)
add_dependencies(buildtests_cxx alts_concurrent_connectivity_test)
endif()
add_dependencies(buildtests_cxx alts_counter_test)
add_dependencies(buildtests_cxx alts_crypt_test)
add_dependencies(buildtests_cxx alts_crypter_test)
@ -10190,6 +10199,60 @@ target_link_libraries(alarm_test
)
endif()
if(gRPC_BUILD_TESTS)
if(_gRPC_PLATFORM_LINUX)
add_executable(alts_concurrent_connectivity_test
${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/handshaker.pb.h
${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.h
${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.h
${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.h
test/core/tsi/alts/fake_handshaker/fake_handshaker_server.cc
test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc
third_party/googletest/googletest/src/gtest-all.cc
third_party/googletest/googlemock/src/gmock-all.cc
)
target_include_directories(alts_concurrent_connectivity_test
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/include
${_gRPC_ADDRESS_SORTING_INCLUDE_DIR}
${_gRPC_BENCHMARK_INCLUDE_DIR}
${_gRPC_CARES_INCLUDE_DIR}
${_gRPC_GFLAGS_INCLUDE_DIR}
${_gRPC_PROTOBUF_INCLUDE_DIR}
${_gRPC_SSL_INCLUDE_DIR}
${_gRPC_UPB_GENERATED_DIR}
${_gRPC_UPB_GRPC_GENERATED_DIR}
${_gRPC_UPB_INCLUDE_DIR}
${_gRPC_ZLIB_INCLUDE_DIR}
third_party/googletest/googletest/include
third_party/googletest/googletest
third_party/googletest/googlemock/include
third_party/googletest/googlemock
${_gRPC_PROTO_GENS_DIR}
)
target_link_libraries(alts_concurrent_connectivity_test
${_gRPC_PROTOBUF_LIBRARIES}
${_gRPC_ALLTARGETS_LIBRARIES}
grpc++_test_util
grpc_test_util
grpc++
grpc
gpr
grpc++_test_config
${_gRPC_GFLAGS_LIBRARIES}
)
endif()
endif()
if(gRPC_BUILD_TESTS)

@ -1143,6 +1143,7 @@ udp_server_test: $(BINDIR)/$(CONFIG)/udp_server_test
uri_fuzzer_test: $(BINDIR)/$(CONFIG)/uri_fuzzer_test
uri_parser_test: $(BINDIR)/$(CONFIG)/uri_parser_test
alarm_test: $(BINDIR)/$(CONFIG)/alarm_test
alts_concurrent_connectivity_test: $(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test
alts_counter_test: $(BINDIR)/$(CONFIG)/alts_counter_test
alts_crypt_test: $(BINDIR)/$(CONFIG)/alts_crypt_test
alts_crypter_test: $(BINDIR)/$(CONFIG)/alts_crypter_test
@ -1625,6 +1626,7 @@ buildtests_c: privatelibs_c \
ifeq ($(EMBED_OPENSSL),true)
buildtests_cxx: privatelibs_cxx \
$(BINDIR)/$(CONFIG)/alarm_test \
$(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test \
$(BINDIR)/$(CONFIG)/alts_counter_test \
$(BINDIR)/$(CONFIG)/alts_crypt_test \
$(BINDIR)/$(CONFIG)/alts_crypter_test \
@ -1796,6 +1798,7 @@ buildtests_cxx: privatelibs_cxx \
else
buildtests_cxx: privatelibs_cxx \
$(BINDIR)/$(CONFIG)/alarm_test \
$(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test \
$(BINDIR)/$(CONFIG)/alts_counter_test \
$(BINDIR)/$(CONFIG)/alts_crypt_test \
$(BINDIR)/$(CONFIG)/alts_crypter_test \
@ -2234,6 +2237,8 @@ flaky_test_c: buildtests_c
test_cxx: buildtests_cxx
$(E) "[RUN] Testing alarm_test"
$(Q) $(BINDIR)/$(CONFIG)/alarm_test || ( echo test alarm_test failed ; exit 1 )
$(E) "[RUN] Testing alts_concurrent_connectivity_test"
$(Q) $(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test || ( echo test alts_concurrent_connectivity_test failed ; exit 1 )
$(E) "[RUN] Testing alts_counter_test"
$(Q) $(BINDIR)/$(CONFIG)/alts_counter_test || ( echo test alts_counter_test failed ; exit 1 )
$(E) "[RUN] Testing alts_crypt_test"
@ -3053,6 +3058,38 @@ $(GENDIR)/src/proto/grpc/testing/xds/orca_load_report_for_test.grpc.pb.cc: src/p
$(Q) $(PROTOC) -Ithird_party/protobuf/src -I. --grpc_out=$(GENDIR) --plugin=protoc-gen-grpc=$(PROTOC_PLUGINS_DIR)/grpc_cpp_plugin$(EXECUTABLE_SUFFIX) $<
endif
ifeq ($(NO_PROTOC),true)
$(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc: protoc_dep_error
$(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc: protoc_dep_error
else
$(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc: test/core/tsi/alts/fake_handshaker/handshaker.proto $(PROTOBUF_DEP) $(PROTOC_PLUGINS)
$(E) "[PROTOC] Generating protobuf CC file from $<"
$(Q) mkdir -p `dirname $@`
$(Q) $(PROTOC) -Ithird_party/protobuf/src -I. --cpp_out=$(GENDIR) $<
$(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc: test/core/tsi/alts/fake_handshaker/handshaker.proto $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc $(PROTOBUF_DEP) $(PROTOC_PLUGINS)
$(E) "[GRPC] Generating gRPC's protobuf service CC file from $<"
$(Q) mkdir -p `dirname $@`
$(Q) $(PROTOC) -Ithird_party/protobuf/src -I. --grpc_out=$(GENDIR) --plugin=protoc-gen-grpc=$(PROTOC_PLUGINS_DIR)/grpc_cpp_plugin$(EXECUTABLE_SUFFIX) $<
endif
ifeq ($(NO_PROTOC),true)
$(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc: protoc_dep_error
$(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc: protoc_dep_error
else
$(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc: test/core/tsi/alts/fake_handshaker/transport_security_common.proto $(PROTOBUF_DEP) $(PROTOC_PLUGINS)
$(E) "[PROTOC] Generating protobuf CC file from $<"
$(Q) mkdir -p `dirname $@`
$(Q) $(PROTOC) -Ithird_party/protobuf/src -I. --cpp_out=$(GENDIR) $<
$(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc: test/core/tsi/alts/fake_handshaker/transport_security_common.proto $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc $(PROTOBUF_DEP) $(PROTOC_PLUGINS)
$(E) "[GRPC] Generating gRPC's protobuf service CC file from $<"
$(Q) mkdir -p `dirname $@`
$(Q) $(PROTOC) -Ithird_party/protobuf/src -I. --grpc_out=$(GENDIR) --plugin=protoc-gen-grpc=$(PROTOC_PLUGINS_DIR)/grpc_cpp_plugin$(EXECUTABLE_SUFFIX) $<
endif
ifeq ($(CONFIG),stapprof)
src/core/profiling/stap_timers.c: $(GENDIR)/src/core/profiling/stap_probes.h
@ -13375,6 +13412,60 @@ endif
endif
ALTS_CONCURRENT_CONNECTIVITY_TEST_SRC = \
$(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc \
$(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc \
test/core/tsi/alts/fake_handshaker/fake_handshaker_server.cc \
test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc \
ALTS_CONCURRENT_CONNECTIVITY_TEST_OBJS = $(addprefix $(OBJDIR)/$(CONFIG)/, $(addsuffix .o, $(basename $(ALTS_CONCURRENT_CONNECTIVITY_TEST_SRC))))
ifeq ($(NO_SECURE),true)
# You can't build secure targets if you don't have OpenSSL.
$(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test: openssl_dep_error
else
ifeq ($(NO_PROTOBUF),true)
# You can't build the protoc plugins or protobuf-enabled targets if you don't have protobuf 3.5.0+.
$(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test: protobuf_dep_error
else
$(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test: $(PROTOBUF_DEP) $(ALTS_CONCURRENT_CONNECTIVITY_TEST_OBJS) $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a
$(E) "[LD] Linking $@"
$(Q) mkdir -p `dirname $@`
$(Q) $(LDXX) $(LDFLAGS) $(ALTS_CONCURRENT_CONNECTIVITY_TEST_OBJS) $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a $(LDLIBSXX) $(LDLIBS_PROTOBUF) $(LDLIBS) $(LDLIBS_SECURE) $(GTEST_LIB) -o $(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test
endif
endif
$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/fake_handshaker/handshaker.o: $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a
$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/fake_handshaker/transport_security_common.o: $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a
$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/fake_handshaker/fake_handshaker_server.o: $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a
$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.o: $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a
deps_alts_concurrent_connectivity_test: $(ALTS_CONCURRENT_CONNECTIVITY_TEST_OBJS:.o=.dep)
ifneq ($(NO_SECURE),true)
ifneq ($(NO_DEPS),true)
-include $(ALTS_CONCURRENT_CONNECTIVITY_TEST_OBJS:.o=.dep)
endif
endif
$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/fake_handshaker/fake_handshaker_server.o: $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc
$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.o: $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc
ALTS_COUNTER_TEST_SRC = \
test/core/tsi/alts/frame_protector/alts_counter_test.cc \

@ -3946,6 +3946,25 @@ targets:
- grpc++_unsecure
- grpc_unsecure
- gpr
- name: alts_concurrent_connectivity_test
build: test
language: c++
headers:
- test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h
src:
- test/core/tsi/alts/fake_handshaker/handshaker.proto
- test/core/tsi/alts/fake_handshaker/transport_security_common.proto
- test/core/tsi/alts/fake_handshaker/fake_handshaker_server.cc
- test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc
deps:
- grpc++_test_util
- grpc_test_util
- grpc++
- grpc
- gpr
- grpc++_test_config
platforms:
- linux
- name: alts_counter_test
build: test
language: c++

@ -40,7 +40,9 @@ grpc_alts_credentials::grpc_alts_credentials(
options_(grpc_alts_credentials_options_copy(options)),
handshaker_service_url_(handshaker_service_url == nullptr
? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL)
: gpr_strdup(handshaker_service_url)) {}
: gpr_strdup(handshaker_service_url)) {
grpc_alts_set_rpc_protocol_versions(&options_->rpc_versions);
}
grpc_alts_credentials::~grpc_alts_credentials() {
grpc_alts_credentials_options_destroy(options_);
@ -63,7 +65,9 @@ grpc_alts_server_credentials::grpc_alts_server_credentials(
options_(grpc_alts_credentials_options_copy(options)),
handshaker_service_url_(handshaker_service_url == nullptr
? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL)
: gpr_strdup(handshaker_service_url)) {}
: gpr_strdup(handshaker_service_url)) {
grpc_alts_set_rpc_protocol_versions(&options_->rpc_versions);
}
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_alts_server_credentials::create_security_connector() {

@ -36,9 +36,7 @@
#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h"
#include "src/core/tsi/transport_security.h"
namespace {
void alts_set_rpc_protocol_versions(
void grpc_alts_set_rpc_protocol_versions(
grpc_gcp_rpc_protocol_versions* rpc_versions) {
grpc_gcp_rpc_protocol_versions_set_max(rpc_versions,
GRPC_PROTOCOL_VERSION_MAX_MAJOR,
@ -48,6 +46,8 @@ void alts_set_rpc_protocol_versions(
GRPC_PROTOCOL_VERSION_MIN_MINOR);
}
namespace {
void alts_check_peer(tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) {
@ -72,11 +72,7 @@ class grpc_alts_channel_security_connector final
: grpc_channel_security_connector(/*url_scheme=*/nullptr,
std::move(channel_creds),
std::move(request_metadata_creds)),
target_name_(gpr_strdup(target_name)) {
grpc_alts_credentials* creds =
static_cast<grpc_alts_credentials*>(mutable_channel_creds());
alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions);
}
target_name_(gpr_strdup(target_name)) {}
~grpc_alts_channel_security_connector() override { gpr_free(target_name_); }
@ -134,11 +130,8 @@ class grpc_alts_server_security_connector final
grpc_alts_server_security_connector(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
: grpc_server_security_connector(/*url_scheme=*/nullptr,
std::move(server_creds)) {
grpc_alts_server_credentials* creds =
reinterpret_cast<grpc_alts_server_credentials*>(mutable_server_creds());
alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions);
}
std::move(server_creds)) {}
~grpc_alts_server_security_connector() override = default;
void add_handshakers(
@ -193,7 +186,7 @@ grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer) {
return nullptr;
}
grpc_gcp_rpc_protocol_versions local_versions, peer_versions;
alts_set_rpc_protocol_versions(&local_versions);
grpc_alts_set_rpc_protocol_versions(&local_versions);
grpc_slice slice = grpc_slice_from_copied_buffer(
rpc_versions_prop->value.data, rpc_versions_prop->value.length);
bool decode_result =

@ -57,6 +57,10 @@ grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_alts_server_security_connector_create(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds);
/* Initializes rpc_versions. */
void grpc_alts_set_rpc_protocol_versions(
grpc_gcp_rpc_protocol_versions* rpc_versions);
namespace grpc_core {
namespace internal {

@ -500,15 +500,18 @@ static void destroy_channel(void* arg, grpc_error* /*error*/) {
grpc_shutdown();
}
void grpc_channel_destroy(grpc_channel* channel) {
void grpc_channel_destroy_internal(grpc_channel* channel) {
grpc_transport_op* op = grpc_make_transport_op(nullptr);
grpc_channel_element* elem;
grpc_core::ExecCtx exec_ctx;
GRPC_API_TRACE("grpc_channel_destroy(channel=%p)", 1, (channel));
op->disconnect_with_error =
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Channel Destroyed");
elem = grpc_channel_stack_element(CHANNEL_STACK_FROM_CHANNEL(channel), 0);
elem->filter->start_transport_op(elem, op);
GRPC_CHANNEL_INTERNAL_UNREF(channel, "channel");
}
void grpc_channel_destroy(grpc_channel* channel) {
grpc_core::ExecCtx exec_ctx;
grpc_channel_destroy_internal(channel);
}

@ -32,6 +32,10 @@ grpc_channel* grpc_channel_create(const char* target,
grpc_transport* optional_transport,
grpc_resource_user* resource_user = nullptr);
/** The same as grpc_channel_destroy, but doesn't create an ExecCtx, and so
* is safe to use from within core. */
void grpc_channel_destroy_internal(grpc_channel* channel);
grpc_channel* grpc_channel_create_with_builder(
grpc_channel_stack_builder* builder,
grpc_channel_stack_type channel_stack_type);

@ -49,12 +49,8 @@ typedef struct alts_grpc_handshaker_client {
* that validates the data to be sent to handshaker service in a testing use
* case. */
alts_grpc_caller grpc_caller;
/* A callback function provided by gRPC to handle the response returned from
* handshaker service. It also serves to bring the control safely back to
* application when dedicated CQ and thread are used. */
grpc_iomgr_cb_func grpc_cb;
/* A gRPC closure to be scheduled when the response from handshaker service
* is received. It will be initialized with grpc_cb. */
* is received. It will be initialized with the injected grpc RPC callback. */
grpc_closure on_handshaker_service_resp_recv;
/* Buffers containing information to be sent (or received) to (or from) the
* handshaker service. */
@ -415,6 +411,11 @@ static void handshaker_client_shutdown(alts_handshaker_client* c) {
}
}
static void handshaker_call_unref(void* arg, grpc_error* error) {
grpc_call* call = static_cast<grpc_call*>(arg);
grpc_call_unref(call);
}
static void handshaker_client_destruct(alts_handshaker_client* c) {
if (c == nullptr) {
return;
@ -422,7 +423,15 @@ static void handshaker_client_destruct(alts_handshaker_client* c) {
alts_grpc_handshaker_client* client =
reinterpret_cast<alts_grpc_handshaker_client*>(c);
if (client->call != nullptr) {
grpc_call_unref(client->call);
// Throw this grpc_call_unref over to the ExecCtx so that
// we invoke it at the bottom of the call stack and
// prevent lock inversion problems due to nested ExecCtx flushing.
// TODO(apolcyn): we could remove this indirection and call
// grpc_call_unref inline if there was an internal variant of
// grpc_call_unref that didn't need to flush an ExecCtx.
GRPC_CLOSURE_SCHED(GRPC_CLOSURE_CREATE(handshaker_call_unref, client->call,
grpc_schedule_on_exec_ctx),
GRPC_ERROR_NONE);
}
}
@ -454,7 +463,6 @@ alts_handshaker_client* alts_grpc_handshaker_client_create(
client->target_name = grpc_slice_copy(target_name);
client->recv_bytes = grpc_empty_slice();
grpc_metadata_array_init(&client->recv_initial_metadata);
client->grpc_cb = grpc_cb;
client->is_client = is_client;
client->buffer_size = TSI_ALTS_INITIAL_BUFFER_SIZE;
client->buffer = static_cast<unsigned char*>(gpr_zalloc(client->buffer_size));
@ -469,8 +477,8 @@ alts_handshaker_client* alts_grpc_handshaker_client_create(
GRPC_MILLIS_INF_FUTURE, nullptr);
client->base.vtable =
vtable_for_testing == nullptr ? &vtable : vtable_for_testing;
GRPC_CLOSURE_INIT(&client->on_handshaker_service_resp_recv, client->grpc_cb,
client, grpc_schedule_on_exec_ctx);
GRPC_CLOSURE_INIT(&client->on_handshaker_service_resp_recv, grpc_cb, client,
grpc_schedule_on_exec_ctx);
grpc_slice_unref_internal(slice);
return &client->base;
}

@ -30,9 +30,11 @@
#include <grpc/support/sync.h>
#include <grpc/support/thd_id.h>
#include "src/core/lib/gprpp/sync.h"
#include "src/core/lib/gprpp/thd.h"
#include "src/core/lib/iomgr/closure.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/lib/surface/channel.h"
#include "src/core/tsi/alts/frame_protector/alts_frame_protector.h"
#include "src/core/tsi/alts/handshaker/alts_handshaker_client.h"
#include "src/core/tsi/alts/handshaker/alts_shared_resource.h"
@ -42,7 +44,6 @@
/* Main struct for ALTS TSI handshaker. */
struct alts_tsi_handshaker {
tsi_handshaker base;
alts_handshaker_client* client;
grpc_slice target_name;
bool is_client;
bool has_sent_start_message;
@ -52,6 +53,16 @@ struct alts_tsi_handshaker {
grpc_alts_credentials_options* options;
alts_handshaker_client_vtable* client_vtable_for_testing;
grpc_channel* channel;
bool use_dedicated_cq;
// mu synchronizes all fields below. Note these are the
// only fields that can be concurrently accessed (due to
// potential concurrency of tsi_handshaker_shutdown and
// tsi_handshaker_next).
gpr_mu mu;
alts_handshaker_client* client;
// shutdown effectively follows base.handshake_shutdown,
// but is synchronized by the mutex of this object.
bool shutdown;
};
/* Main struct for ALTS TSI handshaker result. */
@ -272,22 +283,11 @@ static void on_handshaker_service_resp_recv_dedicated(void* arg,
nullptr, &resource->storage);
}
static tsi_result handshaker_next(
tsi_handshaker* self, const unsigned char* received_bytes,
size_t received_bytes_size, const unsigned char** /*bytes_to_send*/,
size_t* /*bytes_to_send_size*/, tsi_handshaker_result** /*result*/,
tsi_handshaker_on_next_done_cb cb, void* user_data) {
if (self == nullptr || cb == nullptr) {
gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()");
return TSI_INVALID_ARGUMENT;
}
if (self->handshake_shutdown) {
gpr_log(GPR_ERROR, "TSI handshake shutdown");
return TSI_HANDSHAKE_SHUTDOWN;
}
alts_tsi_handshaker* handshaker =
reinterpret_cast<alts_tsi_handshaker*>(self);
tsi_result ok = TSI_OK;
/* Returns TSI_OK if and only if no error is encountered. */
static tsi_result alts_tsi_handshaker_continue_handshaker_next(
alts_tsi_handshaker* handshaker, const unsigned char* received_bytes,
size_t received_bytes_size, tsi_handshaker_on_next_done_cb cb,
void* user_data) {
if (!handshaker->has_created_handshaker_client) {
if (handshaker->channel == nullptr) {
grpc_alts_shared_resource_dedicated_start(
@ -303,15 +303,24 @@ static tsi_result handshaker_next(
handshaker->channel == nullptr
? grpc_alts_get_shared_resource_dedicated()->channel
: handshaker->channel;
handshaker->client = alts_grpc_handshaker_client_create(
alts_handshaker_client* client = alts_grpc_handshaker_client_create(
handshaker, channel, handshaker->handshaker_service_url,
handshaker->interested_parties, handshaker->options,
handshaker->target_name, grpc_cb, cb, user_data,
handshaker->client_vtable_for_testing, handshaker->is_client);
if (handshaker->client == nullptr) {
if (client == nullptr) {
gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client");
return TSI_FAILED_PRECONDITION;
}
{
grpc_core::MutexLock lock(&handshaker->mu);
GPR_ASSERT(handshaker->client == nullptr);
handshaker->client = client;
if (handshaker->shutdown) {
gpr_log(GPR_ERROR, "TSI handshake shutdown");
return TSI_HANDSHAKE_SHUTDOWN;
}
}
handshaker->has_created_handshaker_client = true;
}
if (handshaker->channel == nullptr &&
@ -324,18 +333,100 @@ static tsi_result handshaker_next(
: grpc_slice_from_copied_buffer(
reinterpret_cast<const char*>(received_bytes),
received_bytes_size);
tsi_result ok = TSI_OK;
if (!handshaker->has_sent_start_message) {
handshaker->has_sent_start_message = true;
ok = handshaker->is_client
? alts_handshaker_client_start_client(handshaker->client)
: alts_handshaker_client_start_server(handshaker->client, &slice);
handshaker->has_sent_start_message = true;
// It's unsafe for the current thread to access any state in handshaker
// at this point, since alts_handshaker_client_start_client/server
// have potentially just started an op batch on the handshake call.
// The completion callback for that batch is unsynchronized and so
// can invoke the TSI next API callback from any thread, at which point
// there is nothing taking ownership of this handshaker to prevent it
// from being destroyed.
} else {
ok = alts_handshaker_client_next(handshaker->client, &slice);
}
grpc_slice_unref_internal(slice);
if (ok != TSI_OK) {
gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests");
return ok;
return ok;
}
struct alts_tsi_handshaker_continue_handshaker_next_args {
alts_tsi_handshaker* handshaker;
grpc_core::UniquePtr<unsigned char> received_bytes;
size_t received_bytes_size;
tsi_handshaker_on_next_done_cb cb;
void* user_data;
grpc_closure closure;
};
static void alts_tsi_handshaker_create_channel(void* arg,
grpc_error* unused_error) {
alts_tsi_handshaker_continue_handshaker_next_args* next_args =
static_cast<alts_tsi_handshaker_continue_handshaker_next_args*>(arg);
alts_tsi_handshaker* handshaker = next_args->handshaker;
GPR_ASSERT(handshaker->channel == nullptr);
handshaker->channel = grpc_insecure_channel_create(
next_args->handshaker->handshaker_service_url, nullptr, nullptr);
tsi_result continue_next_result =
alts_tsi_handshaker_continue_handshaker_next(
handshaker, next_args->received_bytes.get(),
next_args->received_bytes_size, next_args->cb, next_args->user_data);
if (continue_next_result != TSI_OK) {
next_args->cb(continue_next_result, next_args->user_data, nullptr, 0,
nullptr);
}
grpc_core::Delete(next_args);
}
static tsi_result handshaker_next(
tsi_handshaker* self, const unsigned char* received_bytes,
size_t received_bytes_size, const unsigned char** /*bytes_to_send*/,
size_t* /*bytes_to_send_size*/, tsi_handshaker_result** /*result*/,
tsi_handshaker_on_next_done_cb cb, void* user_data) {
if (self == nullptr || cb == nullptr) {
gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()");
return TSI_INVALID_ARGUMENT;
}
alts_tsi_handshaker* handshaker =
reinterpret_cast<alts_tsi_handshaker*>(self);
{
grpc_core::MutexLock lock(&handshaker->mu);
if (handshaker->shutdown) {
gpr_log(GPR_ERROR, "TSI handshake shutdown");
return TSI_HANDSHAKE_SHUTDOWN;
}
}
if (handshaker->channel == nullptr && !handshaker->use_dedicated_cq) {
alts_tsi_handshaker_continue_handshaker_next_args* args =
grpc_core::New<alts_tsi_handshaker_continue_handshaker_next_args>();
args->handshaker = handshaker;
args->received_bytes = nullptr;
args->received_bytes_size = received_bytes_size;
if (received_bytes_size > 0) {
args->received_bytes = grpc_core::UniquePtr<unsigned char>(
static_cast<unsigned char*>(gpr_zalloc(received_bytes_size)));
memcpy(args->received_bytes.get(), received_bytes, received_bytes_size);
}
args->cb = cb;
args->user_data = user_data;
GRPC_CLOSURE_INIT(&args->closure, alts_tsi_handshaker_create_channel, args,
grpc_schedule_on_exec_ctx);
// We continue this handshaker_next call at the bottom of the ExecCtx just
// so that we can invoke grpc_channel_create at the bottom of the call
// stack. Doing so avoids potential lock cycles between g_init_mu and other
// mutexes within core that might be held on the current call stack
// (note that g_init_mu gets acquired during channel creation).
GRPC_CLOSURE_SCHED(&args->closure, GRPC_ERROR_NONE);
} else {
tsi_result ok = alts_tsi_handshaker_continue_handshaker_next(
handshaker, received_bytes, received_bytes_size, cb, user_data);
if (ok != TSI_OK) {
gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests");
return ok;
}
}
return TSI_ASYNC;
}
@ -358,12 +449,14 @@ static tsi_result handshaker_next_dedicated(
static void handshaker_shutdown(tsi_handshaker* self) {
GPR_ASSERT(self != nullptr);
if (self->handshake_shutdown) {
return;
}
alts_tsi_handshaker* handshaker =
reinterpret_cast<alts_tsi_handshaker*>(self);
grpc_core::MutexLock lock(&handshaker->mu);
if (handshaker->shutdown) {
return;
}
alts_handshaker_client_shutdown(handshaker->client);
handshaker->shutdown = true;
}
static void handshaker_destroy(tsi_handshaker* self) {
@ -376,9 +469,10 @@ static void handshaker_destroy(tsi_handshaker* self) {
grpc_slice_unref_internal(handshaker->target_name);
grpc_alts_credentials_options_destroy(handshaker->options);
if (handshaker->channel != nullptr) {
grpc_channel_destroy(handshaker->channel);
grpc_channel_destroy_internal(handshaker->channel);
}
gpr_free(handshaker->handshaker_service_url);
gpr_mu_destroy(&handshaker->mu);
gpr_free(handshaker);
}
@ -400,7 +494,8 @@ static const tsi_handshaker_vtable handshaker_vtable_dedicated = {
bool alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker* handshaker) {
GPR_ASSERT(handshaker != nullptr);
return handshaker->base.handshake_shutdown;
grpc_core::MutexLock lock(&handshaker->mu);
return handshaker->shutdown;
}
tsi_result alts_tsi_handshaker_create(
@ -414,7 +509,8 @@ tsi_result alts_tsi_handshaker_create(
}
alts_tsi_handshaker* handshaker =
static_cast<alts_tsi_handshaker*>(gpr_zalloc(sizeof(*handshaker)));
bool use_dedicated_cq = interested_parties == nullptr;
gpr_mu_init(&handshaker->mu);
handshaker->use_dedicated_cq = interested_parties == nullptr;
handshaker->client = nullptr;
handshaker->is_client = is_client;
handshaker->has_sent_start_message = false;
@ -425,13 +521,9 @@ tsi_result alts_tsi_handshaker_create(
handshaker->has_created_handshaker_client = false;
handshaker->handshaker_service_url = gpr_strdup(handshaker_service_url);
handshaker->options = grpc_alts_credentials_options_copy(options);
handshaker->base.vtable =
use_dedicated_cq ? &handshaker_vtable_dedicated : &handshaker_vtable;
handshaker->channel =
use_dedicated_cq
? nullptr
: grpc_insecure_channel_create(handshaker->handshaker_service_url,
nullptr, nullptr);
handshaker->base.vtable = handshaker->use_dedicated_cq
? &handshaker_vtable_dedicated
: &handshaker_vtable;
*self = &handshaker->base;
return TSI_OK;
}

@ -15,6 +15,10 @@
* limitations under the License.
*
*/
#ifndef TEST_CORE_TSI_ALTS_FAKE_HANDSHAKER_FAKE_HANDSHAKER_SERVER_H
#define TEST_CORE_TSI_ALTS_FAKE_HANDSHAKER_FAKE_HANDSHAKER_SERVER_H
#include <memory>
#include <string>
@ -27,3 +31,5 @@ std::unique_ptr<grpc::Service> CreateFakeHandshakerService();
} // namespace gcp
} // namespace grpc
#endif // TEST_CORE_TSI_ALTS_FAKE_HANDSHAKER_FAKE_HANDSHAKER_SERVER_H

@ -77,3 +77,22 @@ grpc_cc_test(
"//test/core/util:grpc_test_util",
],
)
grpc_cc_test(
name = "alts_concurrent_connectivity_test",
srcs = [
"alts_concurrent_connectivity_test.cc",
],
language = "C++",
deps = [
"//:alts_util",
"//:grpc",
"//test/core/util:grpc_test_util",
"//test/core/tsi/alts/fake_handshaker:fake_handshaker_lib",
"//test/core/end2end:cq_verifier",
],
external_deps = ["gtest"],
# TODO(apolcyn): make the fake TCP server used in this
# test portable to Windows.
tags = ["no_windows"],
)

@ -0,0 +1,476 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#include <grpc/support/port_platform.h>
#include <fcntl.h>
#include <gmock/gmock.h>
#include <netinet/in.h>
#include <pthread.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <functional>
#include <set>
#include <thread>
#include <grpc/grpc.h>
#include <grpc/grpc_security.h>
#include <grpc/slice.h>
#include <grpc/support/alloc.h>
#include <grpc/support/log.h>
#include <grpc/support/string_util.h>
#include <grpc/support/time.h>
#include <grpcpp/impl/codegen/service_type.h>
#include <grpcpp/server_builder.h>
#include "src/core/lib/gpr/useful.h"
#include "src/core/lib/gprpp/host_port.h"
#include "src/core/lib/gprpp/thd.h"
#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/security/credentials/alts/alts_credentials.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/security_connector/alts/alts_security_connector.h"
#include "src/core/lib/slice/slice_string_helpers.h"
#include "test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h"
#include "test/core/util/memory_counters.h"
#include "test/core/util/port.h"
#include "test/core/util/test_config.h"
#include "test/core/end2end/cq_verifier.h"
namespace {
void drain_cq(grpc_completion_queue* cq) {
grpc_event ev;
do {
ev = grpc_completion_queue_next(
cq, grpc_timeout_milliseconds_to_deadline(5000), nullptr);
} while (ev.type != GRPC_QUEUE_SHUTDOWN);
}
grpc_channel* create_secure_channel_for_test(
const char* server_addr, const char* fake_handshake_server_addr) {
grpc_alts_credentials_options* alts_options =
grpc_alts_credentials_client_options_create();
grpc_channel_credentials* channel_creds =
grpc_alts_credentials_create_customized(alts_options,
fake_handshake_server_addr,
true /* enable_untrusted_alts */);
grpc_alts_credentials_options_destroy(alts_options);
// The main goal of these tests are to stress concurrent ALTS handshakes,
// so we prevent subchnannel sharing.
grpc_arg disable_subchannel_sharing_arg =
grpc_channel_arg_integer_create(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
grpc_channel_args channel_args = {1, &disable_subchannel_sharing_arg};
grpc_channel* channel = grpc_secure_channel_create(channel_creds, server_addr,
&channel_args, nullptr);
grpc_channel_credentials_release(channel_creds);
return channel;
}
class FakeHandshakeServer {
public:
FakeHandshakeServer() {
int port = grpc_pick_unused_port_or_die();
grpc_core::JoinHostPort(&address_, "localhost", port);
service_ = grpc::gcp::CreateFakeHandshakerService();
grpc::ServerBuilder builder;
builder.AddListeningPort(address_.get(), grpc::InsecureServerCredentials());
builder.RegisterService(service_.get());
server_ = builder.BuildAndStart();
gpr_log(GPR_INFO, "Fake handshaker server listening on %s", address_.get());
}
~FakeHandshakeServer() {
server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0));
}
const char* address() { return address_.get(); }
private:
grpc_core::UniquePtr<char> address_;
std::unique_ptr<grpc::Service> service_;
std::unique_ptr<grpc::Server> server_;
};
class TestServer {
public:
explicit TestServer(const char* fake_handshake_server_address) {
grpc_alts_credentials_options* alts_options =
grpc_alts_credentials_server_options_create();
grpc_server_credentials* server_creds =
grpc_alts_server_credentials_create_customized(
alts_options, fake_handshake_server_address,
true /* enable_untrusted_alts */);
grpc_alts_credentials_options_destroy(alts_options);
server_ = grpc_server_create(nullptr, nullptr);
server_cq_ = grpc_completion_queue_create_for_next(nullptr);
grpc_server_register_completion_queue(server_, server_cq_, nullptr);
int port = grpc_pick_unused_port_or_die();
GPR_ASSERT(grpc_core::JoinHostPort(&server_addr_, "localhost", port));
GPR_ASSERT(grpc_server_add_secure_http2_port(server_, server_addr_.get(),
server_creds));
grpc_server_credentials_release(server_creds);
grpc_server_start(server_);
gpr_log(GPR_DEBUG, "Start TestServer %p. listen on %s", this,
server_addr_.get());
server_thd_ =
std::unique_ptr<std::thread>(new std::thread(PollUntilShutdown, this));
}
~TestServer() {
gpr_log(GPR_DEBUG, "Begin dtor of TestServer %p", this);
grpc_server_shutdown_and_notify(server_, server_cq_, this);
server_thd_->join();
grpc_server_destroy(server_);
grpc_completion_queue_shutdown(server_cq_);
drain_cq(server_cq_);
grpc_completion_queue_destroy(server_cq_);
}
const char* address() { return server_addr_.get(); }
static void PollUntilShutdown(const TestServer* self) {
grpc_event ev = grpc_completion_queue_next(
self->server_cq_, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr);
GPR_ASSERT(ev.type == GRPC_OP_COMPLETE);
GPR_ASSERT(ev.tag == self);
gpr_log(GPR_DEBUG, "TestServer %p stop polling", self);
}
private:
grpc_server* server_;
grpc_completion_queue* server_cq_;
std::unique_ptr<std::thread> server_thd_;
grpc_core::UniquePtr<char> server_addr_;
};
class ConnectLoopRunner {
public:
explicit ConnectLoopRunner(
const char* server_address, const char* fake_handshake_server_addr,
int per_connect_deadline_seconds, size_t loops,
grpc_connectivity_state expected_connectivity_states)
: server_address_(std::unique_ptr<char>(gpr_strdup(server_address))),
fake_handshake_server_addr_(
std::unique_ptr<char>(gpr_strdup(fake_handshake_server_addr))),
per_connect_deadline_seconds_(per_connect_deadline_seconds),
loops_(loops),
expected_connectivity_states_(expected_connectivity_states) {
thd_ = std::unique_ptr<std::thread>(new std::thread(ConnectLoop, this));
}
~ConnectLoopRunner() { thd_->join(); }
static void ConnectLoop(const ConnectLoopRunner* self) {
for (size_t i = 0; i < self->loops_; i++) {
gpr_log(GPR_DEBUG, "runner:%p connect_loop begin loop %ld", self, i);
grpc_completion_queue* cq =
grpc_completion_queue_create_for_next(nullptr);
grpc_channel* channel = create_secure_channel_for_test(
self->server_address_.get(), self->fake_handshake_server_addr_.get());
// Connect, forcing an ALTS handshake
gpr_timespec connect_deadline =
grpc_timeout_seconds_to_deadline(self->per_connect_deadline_seconds_);
grpc_connectivity_state state =
grpc_channel_check_connectivity_state(channel, 1);
ASSERT_EQ(state, GRPC_CHANNEL_IDLE);
while (state != self->expected_connectivity_states_) {
if (self->expected_connectivity_states_ ==
GRPC_CHANNEL_TRANSIENT_FAILURE) {
ASSERT_NE(state, GRPC_CHANNEL_READY); // sanity check
} else {
ASSERT_EQ(self->expected_connectivity_states_, GRPC_CHANNEL_READY);
}
grpc_channel_watch_connectivity_state(
channel, state, gpr_inf_future(GPR_CLOCK_REALTIME), cq, nullptr);
grpc_event ev =
grpc_completion_queue_next(cq, connect_deadline, nullptr);
ASSERT_EQ(ev.type, GRPC_OP_COMPLETE)
<< "connect_loop runner:" << std::hex << self
<< " got ev.type:" << ev.type << " i:" << i;
ASSERT_TRUE(ev.success);
state = grpc_channel_check_connectivity_state(channel, 1);
}
grpc_channel_destroy(channel);
grpc_completion_queue_shutdown(cq);
drain_cq(cq);
grpc_completion_queue_destroy(cq);
gpr_log(GPR_DEBUG, "runner:%p connect_loop finished loop %ld", self, i);
}
}
private:
std::unique_ptr<char> server_address_;
std::unique_ptr<char> fake_handshake_server_addr_;
int per_connect_deadline_seconds_;
size_t loops_;
grpc_connectivity_state expected_connectivity_states_;
std::unique_ptr<std::thread> thd_;
};
// Perform a few ALTS handshakes sequentially (using the fake, in-process ALTS
// handshake server).
TEST(AltsConcurrentConnectivityTest, TestBasicClientServerHandshakes) {
FakeHandshakeServer fake_handshake_server;
TestServer test_server(fake_handshake_server.address());
{
ConnectLoopRunner runner(
test_server.address(), fake_handshake_server.address(),
5 /* per connect deadline seconds */, 10 /* loops */,
GRPC_CHANNEL_READY /* expected connectivity states */);
}
}
/* Run a bunch of concurrent ALTS handshakes on concurrent channels
* (using the fake, in-process handshake server). */
TEST(AltsConcurrentConnectivityTest, TestConcurrentClientServerHandshakes) {
FakeHandshakeServer fake_handshake_server;
// Test
{
TestServer test_server(fake_handshake_server.address());
gpr_timespec test_deadline = grpc_timeout_seconds_to_deadline(20);
size_t num_concurrent_connects = 50;
std::vector<std::unique_ptr<ConnectLoopRunner>> connect_loop_runners;
gpr_log(GPR_DEBUG,
"start performing concurrent expected-to-succeed connects");
for (size_t i = 0; i < num_concurrent_connects; i++) {
connect_loop_runners.push_back(
std::unique_ptr<ConnectLoopRunner>(new ConnectLoopRunner(
test_server.address(), fake_handshake_server.address(),
15 /* per connect deadline seconds */, 5 /* loops */,
GRPC_CHANNEL_READY /* expected connectivity states */)));
}
connect_loop_runners.clear();
gpr_log(GPR_DEBUG,
"done performing concurrent expected-to-succeed connects");
if (gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), test_deadline) > 0) {
gpr_log(GPR_DEBUG, "Test took longer than expected.");
abort();
}
}
}
class FakeTcpServer {
public:
enum ProcessReadResult {
CONTINUE_READING,
CLOSE_SOCKET,
};
FakeTcpServer(
const std::function<ProcessReadResult(int, int, int)>& process_read_cb)
: process_read_cb_(process_read_cb) {
port_ = grpc_pick_unused_port_or_die();
accept_socket_ = socket(AF_INET6, SOCK_STREAM, 0);
char* addr_str;
GPR_ASSERT(gpr_asprintf(&addr_str, "[::]:%d", port_));
address_ = std::unique_ptr<char>(addr_str);
GPR_ASSERT(accept_socket_ != -1);
if (accept_socket_ == -1) {
gpr_log(GPR_ERROR, "Failed to create socket: %d", errno);
abort();
}
int val = 1;
if (setsockopt(accept_socket_, SOL_SOCKET, SO_REUSEADDR, &val,
sizeof(val)) != 0) {
gpr_log(GPR_ERROR,
"Failed to set SO_REUSEADDR on socket bound to [::1]:%d : %d",
port_, errno);
abort();
}
if (fcntl(accept_socket_, F_SETFL, O_NONBLOCK) != 0) {
gpr_log(GPR_ERROR, "Failed to set O_NONBLOCK on socket: %d", errno);
abort();
}
sockaddr_in6 addr;
memset(&addr, 0, sizeof(addr));
addr.sin6_family = AF_INET6;
addr.sin6_port = htons(port_);
((char*)&addr.sin6_addr)[15] = 1;
if (bind(accept_socket_, (const sockaddr*)&addr, sizeof(addr)) != 0) {
gpr_log(GPR_ERROR, "Failed to bind socket to [::1]:%d : %d", port_,
errno);
abort();
}
if (listen(accept_socket_, 100)) {
gpr_log(GPR_ERROR, "Failed to listen on socket bound to [::1]:%d : %d",
port_, errno);
abort();
}
gpr_event_init(&stop_ev_);
run_server_loop_thd_ =
std::unique_ptr<std::thread>(new std::thread(RunServerLoop, this));
}
~FakeTcpServer() {
gpr_log(GPR_DEBUG,
"FakeTcpServer stop and "
"join server thread");
gpr_event_set(&stop_ev_, (void*)1);
run_server_loop_thd_->join();
gpr_log(GPR_DEBUG,
"FakeTcpServer join server "
"thread complete");
}
const char* address() { return address_.get(); }
static ProcessReadResult CloseSocketUponReceivingBytesFromPeer(
int bytes_received_size, int read_error, int s) {
if (bytes_received_size < 0 && read_error != EAGAIN &&
read_error != EWOULDBLOCK) {
gpr_log(GPR_ERROR, "Failed to receive from peer socket: %d. errno: %d", s,
errno);
abort();
}
if (bytes_received_size >= 0) {
gpr_log(GPR_DEBUG,
"Fake TCP server received %d bytes from peer socket: %d. Close "
"the "
"connection.",
bytes_received_size, s);
return CLOSE_SOCKET;
}
return CONTINUE_READING;
}
static ProcessReadResult CloseSocketUponCloseFromPeer(int bytes_received_size,
int read_error, int s) {
if (bytes_received_size < 0 && read_error != EAGAIN &&
read_error != EWOULDBLOCK) {
gpr_log(GPR_ERROR, "Failed to receive from peer socket: %d. errno: %d", s,
errno);
abort();
}
if (bytes_received_size == 0) {
// The peer has shut down the connection.
gpr_log(GPR_DEBUG,
"Fake TCP server received 0 bytes from peer socket: %d. Close "
"the "
"connection.",
s);
return CLOSE_SOCKET;
}
return CONTINUE_READING;
}
// Run a loop that periodically, every 10 ms:
// 1) Checks if there are any new TCP connections to accept.
// 2) Checks if any data has arrived yet on established connections,
// and reads from them if so, processing the sockets as configured.
static void RunServerLoop(FakeTcpServer* self) {
std::set<int> peers;
while (!gpr_event_get(&self->stop_ev_)) {
int p = accept(self->accept_socket_, nullptr, nullptr);
if (p == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
gpr_log(GPR_ERROR, "Failed to accept connection: %d", errno);
abort();
}
if (p != -1) {
gpr_log(GPR_DEBUG, "accepted peer socket: %d", p);
if (fcntl(p, F_SETFL, O_NONBLOCK) != 0) {
gpr_log(GPR_ERROR,
"Failed to set O_NONBLOCK on peer socket:%d errno:%d", p,
errno);
abort();
}
peers.insert(p);
}
auto it = peers.begin();
while (it != peers.end()) {
int p = *it;
char buf[100];
int bytes_received_size = recv(p, buf, 100, 0);
ProcessReadResult r =
self->process_read_cb_(bytes_received_size, errno, p);
if (r == CLOSE_SOCKET) {
close(p);
it = peers.erase(it);
} else {
GPR_ASSERT(r == CONTINUE_READING);
it++;
}
}
gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
gpr_time_from_millis(10, GPR_TIMESPAN)));
}
for (auto it = peers.begin(); it != peers.end(); it++) {
close(*it);
}
close(self->accept_socket_);
}
private:
int accept_socket_;
int port_;
gpr_event stop_ev_;
std::unique_ptr<char> address_;
std::unique_ptr<std::thread> run_server_loop_thd_;
std::function<ProcessReadResult(int, int, int)> process_read_cb_;
};
/* This test is intended to make sure that ALTS handshakes we correctly
* fail fast when the security handshaker gets an error while reading
* from the remote peer, after having earlier sent the first bytes of the
* ALTS handshake to the peer, i.e. after getting into the middle of a
* handshake. */
TEST(AltsConcurrentConnectivityTest,
TestHandshakeFailsFastWhenPeerEndpointClosesConnectionAfterAccepting) {
FakeHandshakeServer fake_handshake_server;
FakeTcpServer fake_tcp_server(
FakeTcpServer::CloseSocketUponReceivingBytesFromPeer);
{
gpr_timespec test_deadline = grpc_timeout_seconds_to_deadline(20);
std::vector<std::unique_ptr<ConnectLoopRunner>> connect_loop_runners;
size_t num_concurrent_connects = 100;
gpr_log(GPR_DEBUG, "start performing concurrent expected-to-fail connects");
for (size_t i = 0; i < num_concurrent_connects; i++) {
connect_loop_runners.push_back(std::unique_ptr<
ConnectLoopRunner>(new ConnectLoopRunner(
fake_tcp_server.address(), fake_handshake_server.address(),
10 /* per connect deadline seconds */, 3 /* loops */,
GRPC_CHANNEL_TRANSIENT_FAILURE /* expected connectivity states */)));
}
connect_loop_runners.clear();
gpr_log(GPR_DEBUG, "done performing concurrent expected-to-fail connects");
if (gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), test_deadline) > 0) {
gpr_log(GPR_ERROR,
"Exceeded test deadline. ALTS handshakes might not be failing "
"fast when the peer endpoint closes the connection abruptly");
abort();
}
}
}
} // namespace
int main(int argc, char** argv) {
grpc_init();
grpc::testing::TestEnvironment env(argc, argv);
::testing::InitGoogleTest(&argc, argv);
auto result = RUN_ALL_TESTS();
grpc_shutdown();
return result;
}

@ -320,8 +320,12 @@ static tsi_result mock_client_start(alts_handshaker_client* client) {
if (!should_handshaker_client_api_succeed) {
return TSI_INTERNAL_ERROR;
}
/* Note that the alts_tsi_handshaker needs to set its
* has_sent_start_message field field to true
* before the call to alts_handshaker_client_start is made because
* because it's unsafe to access it afterwards. */
alts_handshaker_client_check_fields_for_testing(
client, on_client_start_success_cb, nullptr, false, nullptr);
client, on_client_start_success_cb, nullptr, true, nullptr);
/* Populate handshaker response for client_start request. */
grpc_byte_buffer** recv_buffer_ptr =
alts_handshaker_client_get_recv_buffer_addr_for_testing(client);
@ -339,7 +343,7 @@ static tsi_result mock_server_start(alts_handshaker_client* client,
return TSI_INTERNAL_ERROR;
}
alts_handshaker_client_check_fields_for_testing(
client, on_server_start_success_cb, nullptr, false, nullptr);
client, on_server_start_success_cb, nullptr, true, nullptr);
grpc_slice slice = grpc_empty_slice();
GPR_ASSERT(grpc_slice_cmp(*bytes_received, slice) == 0);
/* Populate handshaker response for server_start request. */
@ -404,6 +408,12 @@ static tsi_handshaker* create_test_handshaker(bool is_client) {
return handshaker;
}
static void run_tsi_handshaker_destroy_with_exec_ctx(
tsi_handshaker* handshaker) {
grpc_core::ExecCtx exec_ctx;
tsi_handshaker_destroy(handshaker);
}
static void check_handshaker_next_invalid_input() {
/* Initialization. */
tsi_handshaker* handshaker = create_test_handshaker(true);
@ -416,7 +426,7 @@ static void check_handshaker_next_invalid_input() {
nullptr, nullptr,
nullptr) == TSI_INVALID_ARGUMENT);
/* Cleanup. */
tsi_handshaker_destroy(handshaker);
run_tsi_handshaker_destroy_with_exec_ctx(handshaker);
}
static void check_handshaker_shutdown_invalid_input() {
@ -425,7 +435,7 @@ static void check_handshaker_shutdown_invalid_input() {
/* Check nullptr handshaker. */
tsi_handshaker_shutdown(nullptr);
/* Cleanup. */
tsi_handshaker_destroy(handshaker);
run_tsi_handshaker_destroy_with_exec_ctx(handshaker);
}
static void check_handshaker_next_success() {
@ -462,8 +472,8 @@ static void check_handshaker_next_success() {
nullptr, on_server_next_success_cb, nullptr) == TSI_ASYNC);
wait(&tsi_to_caller_notification);
/* Cleanup. */
tsi_handshaker_destroy(server_handshaker);
tsi_handshaker_destroy(client_handshaker);
run_tsi_handshaker_destroy_with_exec_ctx(server_handshaker);
run_tsi_handshaker_destroy_with_exec_ctx(client_handshaker);
}
static void check_handshaker_next_with_shutdown() {
@ -481,7 +491,7 @@ static void check_handshaker_next_with_shutdown() {
nullptr, on_client_next_success_cb,
nullptr) == TSI_HANDSHAKE_SHUTDOWN);
/* Cleanup. */
tsi_handshaker_destroy(handshaker);
run_tsi_handshaker_destroy_with_exec_ctx(handshaker);
}
static void check_handle_response_with_shutdown(void* /*unused*/) {
@ -520,8 +530,8 @@ static void check_handshaker_next_failure() {
nullptr, check_must_not_be_called,
nullptr) == TSI_INTERNAL_ERROR);
/* Cleanup. */
tsi_handshaker_destroy(server_handshaker);
tsi_handshaker_destroy(client_handshaker);
run_tsi_handshaker_destroy_with_exec_ctx(server_handshaker);
run_tsi_handshaker_destroy_with_exec_ctx(client_handshaker);
}
static void on_invalid_input_cb(tsi_result status, void* user_data,
@ -584,7 +594,7 @@ static void check_handle_response_invalid_input() {
alts_handshaker_client_handle_response(client, false);
/* Cleanup. */
grpc_slice_unref(slice);
tsi_handshaker_destroy(handshaker);
run_tsi_handshaker_destroy_with_exec_ctx(handshaker);
notification_destroy(&caller_to_tsi_notification);
notification_destroy(&tsi_to_caller_notification);
}
@ -622,7 +632,7 @@ static void check_handle_response_invalid_resp() {
recv_buffer, GRPC_STATUS_OK);
alts_handshaker_client_handle_response(client, true);
/* Cleanup. */
tsi_handshaker_destroy(handshaker);
run_tsi_handshaker_destroy_with_exec_ctx(handshaker);
notification_destroy(&caller_to_tsi_notification);
notification_destroy(&tsi_to_caller_notification);
}
@ -675,7 +685,7 @@ static void check_handle_response_failure() {
recv_buffer, GRPC_STATUS_OK);
alts_handshaker_client_handle_response(client, true /* is_ok*/);
/* Cleanup. */
tsi_handshaker_destroy(handshaker);
run_tsi_handshaker_destroy_with_exec_ctx(handshaker);
notification_destroy(&caller_to_tsi_notification);
notification_destroy(&tsi_to_caller_notification);
}
@ -714,7 +724,7 @@ static void check_handle_response_after_shutdown() {
recv_buffer, GRPC_STATUS_OK);
alts_handshaker_client_handle_response(client, true);
/* Cleanup. */
tsi_handshaker_destroy(handshaker);
run_tsi_handshaker_destroy_with_exec_ctx(handshaker);
notification_destroy(&caller_to_tsi_notification);
notification_destroy(&tsi_to_caller_notification);
}

@ -3035,6 +3035,24 @@
],
"uses_polling": true
},
{
"args": [],
"benchmark": false,
"ci_platforms": [
"linux"
],
"cpu_cost": 1.0,
"exclude_configs": [],
"exclude_iomgrs": [],
"flaky": false,
"gtest": false,
"language": "c++",
"name": "alts_concurrent_connectivity_test",
"platforms": [
"linux"
],
"uses_polling": true
},
{
"args": [],
"benchmark": false,

Loading…
Cancel
Save