diff --git a/CMakeLists.txt b/CMakeLists.txt index e29d8f47b97..71626d5a7bf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -324,6 +324,12 @@ protobuf_generate_grpc_cpp( protobuf_generate_grpc_cpp( src/proto/grpc/testing/xds/orca_load_report_for_test.proto ) +protobuf_generate_grpc_cpp( + test/core/tsi/alts/fake_handshaker/handshaker.proto +) +protobuf_generate_grpc_cpp( + test/core/tsi/alts/fake_handshaker/transport_security_common.proto +) if(gRPC_BUILD_TESTS) add_custom_target(buildtests_c) @@ -594,6 +600,9 @@ if(gRPC_BUILD_TESTS) add_custom_target(buildtests_cxx) add_dependencies(buildtests_cxx alarm_test) + if(_gRPC_PLATFORM_LINUX) + add_dependencies(buildtests_cxx alts_concurrent_connectivity_test) + endif() add_dependencies(buildtests_cxx alts_counter_test) add_dependencies(buildtests_cxx alts_crypt_test) add_dependencies(buildtests_cxx alts_crypter_test) @@ -10190,6 +10199,60 @@ target_link_libraries(alarm_test ) +endif() +if(gRPC_BUILD_TESTS) +if(_gRPC_PLATFORM_LINUX) + + add_executable(alts_concurrent_connectivity_test + ${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc + ${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc + ${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/handshaker.pb.h + ${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.h + ${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc + ${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc + ${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.h + ${_gRPC_PROTO_GENS_DIR}/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.h + test/core/tsi/alts/fake_handshaker/fake_handshaker_server.cc + test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc + third_party/googletest/googletest/src/gtest-all.cc + third_party/googletest/googlemock/src/gmock-all.cc + ) + + target_include_directories(alts_concurrent_connectivity_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_BENCHMARK_INCLUDE_DIR} + ${_gRPC_CARES_INCLUDE_DIR} + ${_gRPC_GFLAGS_INCLUDE_DIR} + ${_gRPC_PROTOBUF_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} + ) + + target_link_libraries(alts_concurrent_connectivity_test + ${_gRPC_PROTOBUF_LIBRARIES} + ${_gRPC_ALLTARGETS_LIBRARIES} + grpc++_test_util + grpc_test_util + grpc++ + grpc + gpr + grpc++_test_config + ${_gRPC_GFLAGS_LIBRARIES} + ) + + +endif() endif() if(gRPC_BUILD_TESTS) diff --git a/Makefile b/Makefile index 82f6a32b895..ffde7c586db 100644 --- a/Makefile +++ b/Makefile @@ -1143,6 +1143,7 @@ udp_server_test: $(BINDIR)/$(CONFIG)/udp_server_test uri_fuzzer_test: $(BINDIR)/$(CONFIG)/uri_fuzzer_test uri_parser_test: $(BINDIR)/$(CONFIG)/uri_parser_test alarm_test: $(BINDIR)/$(CONFIG)/alarm_test +alts_concurrent_connectivity_test: $(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test alts_counter_test: $(BINDIR)/$(CONFIG)/alts_counter_test alts_crypt_test: $(BINDIR)/$(CONFIG)/alts_crypt_test alts_crypter_test: $(BINDIR)/$(CONFIG)/alts_crypter_test @@ -1625,6 +1626,7 @@ buildtests_c: privatelibs_c \ ifeq ($(EMBED_OPENSSL),true) buildtests_cxx: privatelibs_cxx \ $(BINDIR)/$(CONFIG)/alarm_test \ + $(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test \ $(BINDIR)/$(CONFIG)/alts_counter_test \ $(BINDIR)/$(CONFIG)/alts_crypt_test \ $(BINDIR)/$(CONFIG)/alts_crypter_test \ @@ -1796,6 +1798,7 @@ buildtests_cxx: privatelibs_cxx \ else buildtests_cxx: privatelibs_cxx \ $(BINDIR)/$(CONFIG)/alarm_test \ + $(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test \ $(BINDIR)/$(CONFIG)/alts_counter_test \ $(BINDIR)/$(CONFIG)/alts_crypt_test \ $(BINDIR)/$(CONFIG)/alts_crypter_test \ @@ -2234,6 +2237,8 @@ flaky_test_c: buildtests_c test_cxx: buildtests_cxx $(E) "[RUN] Testing alarm_test" $(Q) $(BINDIR)/$(CONFIG)/alarm_test || ( echo test alarm_test failed ; exit 1 ) + $(E) "[RUN] Testing alts_concurrent_connectivity_test" + $(Q) $(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test || ( echo test alts_concurrent_connectivity_test failed ; exit 1 ) $(E) "[RUN] Testing alts_counter_test" $(Q) $(BINDIR)/$(CONFIG)/alts_counter_test || ( echo test alts_counter_test failed ; exit 1 ) $(E) "[RUN] Testing alts_crypt_test" @@ -3053,6 +3058,38 @@ $(GENDIR)/src/proto/grpc/testing/xds/orca_load_report_for_test.grpc.pb.cc: src/p $(Q) $(PROTOC) -Ithird_party/protobuf/src -I. --grpc_out=$(GENDIR) --plugin=protoc-gen-grpc=$(PROTOC_PLUGINS_DIR)/grpc_cpp_plugin$(EXECUTABLE_SUFFIX) $< endif +ifeq ($(NO_PROTOC),true) +$(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc: protoc_dep_error +$(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc: protoc_dep_error +else + +$(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc: test/core/tsi/alts/fake_handshaker/handshaker.proto $(PROTOBUF_DEP) $(PROTOC_PLUGINS) + $(E) "[PROTOC] Generating protobuf CC file from $<" + $(Q) mkdir -p `dirname $@` + $(Q) $(PROTOC) -Ithird_party/protobuf/src -I. --cpp_out=$(GENDIR) $< + +$(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc: test/core/tsi/alts/fake_handshaker/handshaker.proto $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc $(PROTOBUF_DEP) $(PROTOC_PLUGINS) + $(E) "[GRPC] Generating gRPC's protobuf service CC file from $<" + $(Q) mkdir -p `dirname $@` + $(Q) $(PROTOC) -Ithird_party/protobuf/src -I. --grpc_out=$(GENDIR) --plugin=protoc-gen-grpc=$(PROTOC_PLUGINS_DIR)/grpc_cpp_plugin$(EXECUTABLE_SUFFIX) $< +endif + +ifeq ($(NO_PROTOC),true) +$(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc: protoc_dep_error +$(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc: protoc_dep_error +else + +$(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc: test/core/tsi/alts/fake_handshaker/transport_security_common.proto $(PROTOBUF_DEP) $(PROTOC_PLUGINS) + $(E) "[PROTOC] Generating protobuf CC file from $<" + $(Q) mkdir -p `dirname $@` + $(Q) $(PROTOC) -Ithird_party/protobuf/src -I. --cpp_out=$(GENDIR) $< + +$(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc: test/core/tsi/alts/fake_handshaker/transport_security_common.proto $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc $(PROTOBUF_DEP) $(PROTOC_PLUGINS) + $(E) "[GRPC] Generating gRPC's protobuf service CC file from $<" + $(Q) mkdir -p `dirname $@` + $(Q) $(PROTOC) -Ithird_party/protobuf/src -I. --grpc_out=$(GENDIR) --plugin=protoc-gen-grpc=$(PROTOC_PLUGINS_DIR)/grpc_cpp_plugin$(EXECUTABLE_SUFFIX) $< +endif + ifeq ($(CONFIG),stapprof) src/core/profiling/stap_timers.c: $(GENDIR)/src/core/profiling/stap_probes.h @@ -13375,6 +13412,60 @@ endif endif +ALTS_CONCURRENT_CONNECTIVITY_TEST_SRC = \ + $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc \ + $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc \ + test/core/tsi/alts/fake_handshaker/fake_handshaker_server.cc \ + test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc \ + +ALTS_CONCURRENT_CONNECTIVITY_TEST_OBJS = $(addprefix $(OBJDIR)/$(CONFIG)/, $(addsuffix .o, $(basename $(ALTS_CONCURRENT_CONNECTIVITY_TEST_SRC)))) +ifeq ($(NO_SECURE),true) + +# You can't build secure targets if you don't have OpenSSL. + +$(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test: openssl_dep_error + +else + + + + +ifeq ($(NO_PROTOBUF),true) + +# You can't build the protoc plugins or protobuf-enabled targets if you don't have protobuf 3.5.0+. + +$(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test: protobuf_dep_error + +else + +$(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test: $(PROTOBUF_DEP) $(ALTS_CONCURRENT_CONNECTIVITY_TEST_OBJS) $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a + $(E) "[LD] Linking $@" + $(Q) mkdir -p `dirname $@` + $(Q) $(LDXX) $(LDFLAGS) $(ALTS_CONCURRENT_CONNECTIVITY_TEST_OBJS) $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a $(LDLIBSXX) $(LDLIBS_PROTOBUF) $(LDLIBS) $(LDLIBS_SECURE) $(GTEST_LIB) -o $(BINDIR)/$(CONFIG)/alts_concurrent_connectivity_test + +endif + +endif + +$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/fake_handshaker/handshaker.o: $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a + +$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/fake_handshaker/transport_security_common.o: $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a + +$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/fake_handshaker/fake_handshaker_server.o: $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a + +$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.o: $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LIBDIR)/$(CONFIG)/libgrpc++_test_config.a + +deps_alts_concurrent_connectivity_test: $(ALTS_CONCURRENT_CONNECTIVITY_TEST_OBJS:.o=.dep) + +ifneq ($(NO_SECURE),true) +ifneq ($(NO_DEPS),true) +-include $(ALTS_CONCURRENT_CONNECTIVITY_TEST_OBJS:.o=.dep) +endif +endif +$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/fake_handshaker/fake_handshaker_server.o: $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc +$(OBJDIR)/$(CONFIG)/test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.o: $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.pb.cc $(GENDIR)/test/core/tsi/alts/fake_handshaker/transport_security_common.grpc.pb.cc + + ALTS_COUNTER_TEST_SRC = \ test/core/tsi/alts/frame_protector/alts_counter_test.cc \ diff --git a/build.yaml b/build.yaml index 78a5d32d76a..a341bc18450 100644 --- a/build.yaml +++ b/build.yaml @@ -3946,6 +3946,25 @@ targets: - grpc++_unsecure - grpc_unsecure - gpr +- name: alts_concurrent_connectivity_test + build: test + language: c++ + headers: + - test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h + src: + - test/core/tsi/alts/fake_handshaker/handshaker.proto + - test/core/tsi/alts/fake_handshaker/transport_security_common.proto + - test/core/tsi/alts/fake_handshaker/fake_handshaker_server.cc + - test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc + deps: + - grpc++_test_util + - grpc_test_util + - grpc++ + - grpc + - gpr + - grpc++_test_config + platforms: + - linux - name: alts_counter_test build: test language: c++ diff --git a/src/core/lib/security/credentials/alts/alts_credentials.cc b/src/core/lib/security/credentials/alts/alts_credentials.cc index 4c3c7e113ee..cbc43d565a4 100644 --- a/src/core/lib/security/credentials/alts/alts_credentials.cc +++ b/src/core/lib/security/credentials/alts/alts_credentials.cc @@ -40,7 +40,9 @@ grpc_alts_credentials::grpc_alts_credentials( options_(grpc_alts_credentials_options_copy(options)), handshaker_service_url_(handshaker_service_url == nullptr ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) - : gpr_strdup(handshaker_service_url)) {} + : gpr_strdup(handshaker_service_url)) { + grpc_alts_set_rpc_protocol_versions(&options_->rpc_versions); +} grpc_alts_credentials::~grpc_alts_credentials() { grpc_alts_credentials_options_destroy(options_); @@ -63,7 +65,9 @@ grpc_alts_server_credentials::grpc_alts_server_credentials( options_(grpc_alts_credentials_options_copy(options)), handshaker_service_url_(handshaker_service_url == nullptr ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) - : gpr_strdup(handshaker_service_url)) {} + : gpr_strdup(handshaker_service_url)) { + grpc_alts_set_rpc_protocol_versions(&options_->rpc_versions); +} grpc_core::RefCountedPtr grpc_alts_server_credentials::create_security_connector() { diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.cc b/src/core/lib/security/security_connector/alts/alts_security_connector.cc index a20cf448f8d..5a535110d78 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.cc +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.cc @@ -36,9 +36,7 @@ #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" #include "src/core/tsi/transport_security.h" -namespace { - -void alts_set_rpc_protocol_versions( +void grpc_alts_set_rpc_protocol_versions( grpc_gcp_rpc_protocol_versions* rpc_versions) { grpc_gcp_rpc_protocol_versions_set_max(rpc_versions, GRPC_PROTOCOL_VERSION_MAX_MAJOR, @@ -48,6 +46,8 @@ void alts_set_rpc_protocol_versions( GRPC_PROTOCOL_VERSION_MIN_MINOR); } +namespace { + void alts_check_peer(tsi_peer peer, grpc_core::RefCountedPtr* auth_context, grpc_closure* on_peer_checked) { @@ -72,11 +72,7 @@ class grpc_alts_channel_security_connector final : grpc_channel_security_connector(/*url_scheme=*/nullptr, std::move(channel_creds), std::move(request_metadata_creds)), - target_name_(gpr_strdup(target_name)) { - grpc_alts_credentials* creds = - static_cast(mutable_channel_creds()); - alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions); - } + target_name_(gpr_strdup(target_name)) {} ~grpc_alts_channel_security_connector() override { gpr_free(target_name_); } @@ -134,11 +130,8 @@ class grpc_alts_server_security_connector final grpc_alts_server_security_connector( grpc_core::RefCountedPtr server_creds) : grpc_server_security_connector(/*url_scheme=*/nullptr, - std::move(server_creds)) { - grpc_alts_server_credentials* creds = - reinterpret_cast(mutable_server_creds()); - alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions); - } + std::move(server_creds)) {} + ~grpc_alts_server_security_connector() override = default; void add_handshakers( @@ -193,7 +186,7 @@ grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer) { return nullptr; } grpc_gcp_rpc_protocol_versions local_versions, peer_versions; - alts_set_rpc_protocol_versions(&local_versions); + grpc_alts_set_rpc_protocol_versions(&local_versions); grpc_slice slice = grpc_slice_from_copied_buffer( rpc_versions_prop->value.data, rpc_versions_prop->value.length); bool decode_result = diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.h b/src/core/lib/security/security_connector/alts/alts_security_connector.h index b96dc36b302..b68a4e97781 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.h +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.h @@ -57,6 +57,10 @@ grpc_core::RefCountedPtr grpc_alts_server_security_connector_create( grpc_core::RefCountedPtr server_creds); +/* Initializes rpc_versions. */ +void grpc_alts_set_rpc_protocol_versions( + grpc_gcp_rpc_protocol_versions* rpc_versions); + namespace grpc_core { namespace internal { diff --git a/src/core/lib/surface/channel.cc b/src/core/lib/surface/channel.cc index c6bebf715b2..95f06c6dbab 100644 --- a/src/core/lib/surface/channel.cc +++ b/src/core/lib/surface/channel.cc @@ -500,15 +500,18 @@ static void destroy_channel(void* arg, grpc_error* /*error*/) { grpc_shutdown(); } -void grpc_channel_destroy(grpc_channel* channel) { +void grpc_channel_destroy_internal(grpc_channel* channel) { grpc_transport_op* op = grpc_make_transport_op(nullptr); grpc_channel_element* elem; - grpc_core::ExecCtx exec_ctx; GRPC_API_TRACE("grpc_channel_destroy(channel=%p)", 1, (channel)); op->disconnect_with_error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Channel Destroyed"); elem = grpc_channel_stack_element(CHANNEL_STACK_FROM_CHANNEL(channel), 0); elem->filter->start_transport_op(elem, op); - GRPC_CHANNEL_INTERNAL_UNREF(channel, "channel"); } + +void grpc_channel_destroy(grpc_channel* channel) { + grpc_core::ExecCtx exec_ctx; + grpc_channel_destroy_internal(channel); +} diff --git a/src/core/lib/surface/channel.h b/src/core/lib/surface/channel.h index 5136e93bcd5..280fc96a0bc 100644 --- a/src/core/lib/surface/channel.h +++ b/src/core/lib/surface/channel.h @@ -32,6 +32,10 @@ grpc_channel* grpc_channel_create(const char* target, grpc_transport* optional_transport, grpc_resource_user* resource_user = nullptr); +/** The same as grpc_channel_destroy, but doesn't create an ExecCtx, and so + * is safe to use from within core. */ +void grpc_channel_destroy_internal(grpc_channel* channel); + grpc_channel* grpc_channel_create_with_builder( grpc_channel_stack_builder* builder, grpc_channel_stack_type channel_stack_type); diff --git a/src/core/tsi/alts/handshaker/alts_handshaker_client.cc b/src/core/tsi/alts/handshaker/alts_handshaker_client.cc index 55fd066dba5..7318d6ddeb6 100644 --- a/src/core/tsi/alts/handshaker/alts_handshaker_client.cc +++ b/src/core/tsi/alts/handshaker/alts_handshaker_client.cc @@ -49,12 +49,8 @@ typedef struct alts_grpc_handshaker_client { * that validates the data to be sent to handshaker service in a testing use * case. */ alts_grpc_caller grpc_caller; - /* A callback function provided by gRPC to handle the response returned from - * handshaker service. It also serves to bring the control safely back to - * application when dedicated CQ and thread are used. */ - grpc_iomgr_cb_func grpc_cb; /* A gRPC closure to be scheduled when the response from handshaker service - * is received. It will be initialized with grpc_cb. */ + * is received. It will be initialized with the injected grpc RPC callback. */ grpc_closure on_handshaker_service_resp_recv; /* Buffers containing information to be sent (or received) to (or from) the * handshaker service. */ @@ -415,6 +411,11 @@ static void handshaker_client_shutdown(alts_handshaker_client* c) { } } +static void handshaker_call_unref(void* arg, grpc_error* error) { + grpc_call* call = static_cast(arg); + grpc_call_unref(call); +} + static void handshaker_client_destruct(alts_handshaker_client* c) { if (c == nullptr) { return; @@ -422,7 +423,15 @@ static void handshaker_client_destruct(alts_handshaker_client* c) { alts_grpc_handshaker_client* client = reinterpret_cast(c); if (client->call != nullptr) { - grpc_call_unref(client->call); + // Throw this grpc_call_unref over to the ExecCtx so that + // we invoke it at the bottom of the call stack and + // prevent lock inversion problems due to nested ExecCtx flushing. + // TODO(apolcyn): we could remove this indirection and call + // grpc_call_unref inline if there was an internal variant of + // grpc_call_unref that didn't need to flush an ExecCtx. + GRPC_CLOSURE_SCHED(GRPC_CLOSURE_CREATE(handshaker_call_unref, client->call, + grpc_schedule_on_exec_ctx), + GRPC_ERROR_NONE); } } @@ -454,7 +463,6 @@ alts_handshaker_client* alts_grpc_handshaker_client_create( client->target_name = grpc_slice_copy(target_name); client->recv_bytes = grpc_empty_slice(); grpc_metadata_array_init(&client->recv_initial_metadata); - client->grpc_cb = grpc_cb; client->is_client = is_client; client->buffer_size = TSI_ALTS_INITIAL_BUFFER_SIZE; client->buffer = static_cast(gpr_zalloc(client->buffer_size)); @@ -469,8 +477,8 @@ alts_handshaker_client* alts_grpc_handshaker_client_create( GRPC_MILLIS_INF_FUTURE, nullptr); client->base.vtable = vtable_for_testing == nullptr ? &vtable : vtable_for_testing; - GRPC_CLOSURE_INIT(&client->on_handshaker_service_resp_recv, client->grpc_cb, - client, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&client->on_handshaker_service_resp_recv, grpc_cb, client, + grpc_schedule_on_exec_ctx); grpc_slice_unref_internal(slice); return &client->base; } diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc index c6018af533a..e56448d798c 100644 --- a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc @@ -30,9 +30,11 @@ #include #include +#include "src/core/lib/gprpp/sync.h" #include "src/core/lib/gprpp/thd.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/channel.h" #include "src/core/tsi/alts/frame_protector/alts_frame_protector.h" #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" #include "src/core/tsi/alts/handshaker/alts_shared_resource.h" @@ -42,7 +44,6 @@ /* Main struct for ALTS TSI handshaker. */ struct alts_tsi_handshaker { tsi_handshaker base; - alts_handshaker_client* client; grpc_slice target_name; bool is_client; bool has_sent_start_message; @@ -52,6 +53,16 @@ struct alts_tsi_handshaker { grpc_alts_credentials_options* options; alts_handshaker_client_vtable* client_vtable_for_testing; grpc_channel* channel; + bool use_dedicated_cq; + // mu synchronizes all fields below. Note these are the + // only fields that can be concurrently accessed (due to + // potential concurrency of tsi_handshaker_shutdown and + // tsi_handshaker_next). + gpr_mu mu; + alts_handshaker_client* client; + // shutdown effectively follows base.handshake_shutdown, + // but is synchronized by the mutex of this object. + bool shutdown; }; /* Main struct for ALTS TSI handshaker result. */ @@ -272,22 +283,11 @@ static void on_handshaker_service_resp_recv_dedicated(void* arg, nullptr, &resource->storage); } -static tsi_result handshaker_next( - tsi_handshaker* self, const unsigned char* received_bytes, - size_t received_bytes_size, const unsigned char** /*bytes_to_send*/, - size_t* /*bytes_to_send_size*/, tsi_handshaker_result** /*result*/, - tsi_handshaker_on_next_done_cb cb, void* user_data) { - if (self == nullptr || cb == nullptr) { - gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()"); - return TSI_INVALID_ARGUMENT; - } - if (self->handshake_shutdown) { - gpr_log(GPR_ERROR, "TSI handshake shutdown"); - return TSI_HANDSHAKE_SHUTDOWN; - } - alts_tsi_handshaker* handshaker = - reinterpret_cast(self); - tsi_result ok = TSI_OK; +/* Returns TSI_OK if and only if no error is encountered. */ +static tsi_result alts_tsi_handshaker_continue_handshaker_next( + alts_tsi_handshaker* handshaker, const unsigned char* received_bytes, + size_t received_bytes_size, tsi_handshaker_on_next_done_cb cb, + void* user_data) { if (!handshaker->has_created_handshaker_client) { if (handshaker->channel == nullptr) { grpc_alts_shared_resource_dedicated_start( @@ -303,15 +303,24 @@ static tsi_result handshaker_next( handshaker->channel == nullptr ? grpc_alts_get_shared_resource_dedicated()->channel : handshaker->channel; - handshaker->client = alts_grpc_handshaker_client_create( + alts_handshaker_client* client = alts_grpc_handshaker_client_create( handshaker, channel, handshaker->handshaker_service_url, handshaker->interested_parties, handshaker->options, handshaker->target_name, grpc_cb, cb, user_data, handshaker->client_vtable_for_testing, handshaker->is_client); - if (handshaker->client == nullptr) { + if (client == nullptr) { gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client"); return TSI_FAILED_PRECONDITION; } + { + grpc_core::MutexLock lock(&handshaker->mu); + GPR_ASSERT(handshaker->client == nullptr); + handshaker->client = client; + if (handshaker->shutdown) { + gpr_log(GPR_ERROR, "TSI handshake shutdown"); + return TSI_HANDSHAKE_SHUTDOWN; + } + } handshaker->has_created_handshaker_client = true; } if (handshaker->channel == nullptr && @@ -324,18 +333,100 @@ static tsi_result handshaker_next( : grpc_slice_from_copied_buffer( reinterpret_cast(received_bytes), received_bytes_size); + tsi_result ok = TSI_OK; if (!handshaker->has_sent_start_message) { + handshaker->has_sent_start_message = true; ok = handshaker->is_client ? alts_handshaker_client_start_client(handshaker->client) : alts_handshaker_client_start_server(handshaker->client, &slice); - handshaker->has_sent_start_message = true; + // It's unsafe for the current thread to access any state in handshaker + // at this point, since alts_handshaker_client_start_client/server + // have potentially just started an op batch on the handshake call. + // The completion callback for that batch is unsynchronized and so + // can invoke the TSI next API callback from any thread, at which point + // there is nothing taking ownership of this handshaker to prevent it + // from being destroyed. } else { ok = alts_handshaker_client_next(handshaker->client, &slice); } grpc_slice_unref_internal(slice); - if (ok != TSI_OK) { - gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests"); - return ok; + return ok; +} + +struct alts_tsi_handshaker_continue_handshaker_next_args { + alts_tsi_handshaker* handshaker; + grpc_core::UniquePtr received_bytes; + size_t received_bytes_size; + tsi_handshaker_on_next_done_cb cb; + void* user_data; + grpc_closure closure; +}; + +static void alts_tsi_handshaker_create_channel(void* arg, + grpc_error* unused_error) { + alts_tsi_handshaker_continue_handshaker_next_args* next_args = + static_cast(arg); + alts_tsi_handshaker* handshaker = next_args->handshaker; + GPR_ASSERT(handshaker->channel == nullptr); + handshaker->channel = grpc_insecure_channel_create( + next_args->handshaker->handshaker_service_url, nullptr, nullptr); + tsi_result continue_next_result = + alts_tsi_handshaker_continue_handshaker_next( + handshaker, next_args->received_bytes.get(), + next_args->received_bytes_size, next_args->cb, next_args->user_data); + if (continue_next_result != TSI_OK) { + next_args->cb(continue_next_result, next_args->user_data, nullptr, 0, + nullptr); + } + grpc_core::Delete(next_args); +} + +static tsi_result handshaker_next( + tsi_handshaker* self, const unsigned char* received_bytes, + size_t received_bytes_size, const unsigned char** /*bytes_to_send*/, + size_t* /*bytes_to_send_size*/, tsi_handshaker_result** /*result*/, + tsi_handshaker_on_next_done_cb cb, void* user_data) { + if (self == nullptr || cb == nullptr) { + gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()"); + return TSI_INVALID_ARGUMENT; + } + alts_tsi_handshaker* handshaker = + reinterpret_cast(self); + { + grpc_core::MutexLock lock(&handshaker->mu); + if (handshaker->shutdown) { + gpr_log(GPR_ERROR, "TSI handshake shutdown"); + return TSI_HANDSHAKE_SHUTDOWN; + } + } + if (handshaker->channel == nullptr && !handshaker->use_dedicated_cq) { + alts_tsi_handshaker_continue_handshaker_next_args* args = + grpc_core::New(); + args->handshaker = handshaker; + args->received_bytes = nullptr; + args->received_bytes_size = received_bytes_size; + if (received_bytes_size > 0) { + args->received_bytes = grpc_core::UniquePtr( + static_cast(gpr_zalloc(received_bytes_size))); + memcpy(args->received_bytes.get(), received_bytes, received_bytes_size); + } + args->cb = cb; + args->user_data = user_data; + GRPC_CLOSURE_INIT(&args->closure, alts_tsi_handshaker_create_channel, args, + grpc_schedule_on_exec_ctx); + // We continue this handshaker_next call at the bottom of the ExecCtx just + // so that we can invoke grpc_channel_create at the bottom of the call + // stack. Doing so avoids potential lock cycles between g_init_mu and other + // mutexes within core that might be held on the current call stack + // (note that g_init_mu gets acquired during channel creation). + GRPC_CLOSURE_SCHED(&args->closure, GRPC_ERROR_NONE); + } else { + tsi_result ok = alts_tsi_handshaker_continue_handshaker_next( + handshaker, received_bytes, received_bytes_size, cb, user_data); + if (ok != TSI_OK) { + gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests"); + return ok; + } } return TSI_ASYNC; } @@ -358,12 +449,14 @@ static tsi_result handshaker_next_dedicated( static void handshaker_shutdown(tsi_handshaker* self) { GPR_ASSERT(self != nullptr); - if (self->handshake_shutdown) { - return; - } alts_tsi_handshaker* handshaker = reinterpret_cast(self); + grpc_core::MutexLock lock(&handshaker->mu); + if (handshaker->shutdown) { + return; + } alts_handshaker_client_shutdown(handshaker->client); + handshaker->shutdown = true; } static void handshaker_destroy(tsi_handshaker* self) { @@ -376,9 +469,10 @@ static void handshaker_destroy(tsi_handshaker* self) { grpc_slice_unref_internal(handshaker->target_name); grpc_alts_credentials_options_destroy(handshaker->options); if (handshaker->channel != nullptr) { - grpc_channel_destroy(handshaker->channel); + grpc_channel_destroy_internal(handshaker->channel); } gpr_free(handshaker->handshaker_service_url); + gpr_mu_destroy(&handshaker->mu); gpr_free(handshaker); } @@ -400,7 +494,8 @@ static const tsi_handshaker_vtable handshaker_vtable_dedicated = { bool alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker* handshaker) { GPR_ASSERT(handshaker != nullptr); - return handshaker->base.handshake_shutdown; + grpc_core::MutexLock lock(&handshaker->mu); + return handshaker->shutdown; } tsi_result alts_tsi_handshaker_create( @@ -414,7 +509,8 @@ tsi_result alts_tsi_handshaker_create( } alts_tsi_handshaker* handshaker = static_cast(gpr_zalloc(sizeof(*handshaker))); - bool use_dedicated_cq = interested_parties == nullptr; + gpr_mu_init(&handshaker->mu); + handshaker->use_dedicated_cq = interested_parties == nullptr; handshaker->client = nullptr; handshaker->is_client = is_client; handshaker->has_sent_start_message = false; @@ -425,13 +521,9 @@ tsi_result alts_tsi_handshaker_create( handshaker->has_created_handshaker_client = false; handshaker->handshaker_service_url = gpr_strdup(handshaker_service_url); handshaker->options = grpc_alts_credentials_options_copy(options); - handshaker->base.vtable = - use_dedicated_cq ? &handshaker_vtable_dedicated : &handshaker_vtable; - handshaker->channel = - use_dedicated_cq - ? nullptr - : grpc_insecure_channel_create(handshaker->handshaker_service_url, - nullptr, nullptr); + handshaker->base.vtable = handshaker->use_dedicated_cq + ? &handshaker_vtable_dedicated + : &handshaker_vtable; *self = &handshaker->base; return TSI_OK; } diff --git a/test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h b/test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h index eb4bfdffa12..c6d767b5799 100644 --- a/test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h +++ b/test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h @@ -15,6 +15,10 @@ * limitations under the License. * */ + +#ifndef TEST_CORE_TSI_ALTS_FAKE_HANDSHAKER_FAKE_HANDSHAKER_SERVER_H +#define TEST_CORE_TSI_ALTS_FAKE_HANDSHAKER_FAKE_HANDSHAKER_SERVER_H + #include #include @@ -27,3 +31,5 @@ std::unique_ptr CreateFakeHandshakerService(); } // namespace gcp } // namespace grpc + +#endif // TEST_CORE_TSI_ALTS_FAKE_HANDSHAKER_FAKE_HANDSHAKER_SERVER_H diff --git a/test/core/tsi/alts/handshaker/BUILD b/test/core/tsi/alts/handshaker/BUILD index ed97056c003..42cf3b985f1 100644 --- a/test/core/tsi/alts/handshaker/BUILD +++ b/test/core/tsi/alts/handshaker/BUILD @@ -77,3 +77,22 @@ grpc_cc_test( "//test/core/util:grpc_test_util", ], ) + +grpc_cc_test( + name = "alts_concurrent_connectivity_test", + srcs = [ + "alts_concurrent_connectivity_test.cc", + ], + language = "C++", + deps = [ + "//:alts_util", + "//:grpc", + "//test/core/util:grpc_test_util", + "//test/core/tsi/alts/fake_handshaker:fake_handshaker_lib", + "//test/core/end2end:cq_verifier", + ], + external_deps = ["gtest"], + # TODO(apolcyn): make the fake TCP server used in this + # test portable to Windows. + tags = ["no_windows"], +) diff --git a/test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc b/test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc new file mode 100644 index 00000000000..204129ad0ba --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc @@ -0,0 +1,476 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/security/credentials/alts/alts_credentials.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/alts/alts_security_connector.h" +#include "src/core/lib/slice/slice_string_helpers.h" + +#include "test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h" +#include "test/core/util/memory_counters.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#include "test/core/end2end/cq_verifier.h" + +namespace { + +void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next( + cq, grpc_timeout_milliseconds_to_deadline(5000), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +grpc_channel* create_secure_channel_for_test( + const char* server_addr, const char* fake_handshake_server_addr) { + grpc_alts_credentials_options* alts_options = + grpc_alts_credentials_client_options_create(); + grpc_channel_credentials* channel_creds = + grpc_alts_credentials_create_customized(alts_options, + fake_handshake_server_addr, + true /* enable_untrusted_alts */); + grpc_alts_credentials_options_destroy(alts_options); + // The main goal of these tests are to stress concurrent ALTS handshakes, + // so we prevent subchnannel sharing. + grpc_arg disable_subchannel_sharing_arg = + grpc_channel_arg_integer_create(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true); + grpc_channel_args channel_args = {1, &disable_subchannel_sharing_arg}; + grpc_channel* channel = grpc_secure_channel_create(channel_creds, server_addr, + &channel_args, nullptr); + grpc_channel_credentials_release(channel_creds); + return channel; +} + +class FakeHandshakeServer { + public: + FakeHandshakeServer() { + int port = grpc_pick_unused_port_or_die(); + grpc_core::JoinHostPort(&address_, "localhost", port); + service_ = grpc::gcp::CreateFakeHandshakerService(); + grpc::ServerBuilder builder; + builder.AddListeningPort(address_.get(), grpc::InsecureServerCredentials()); + builder.RegisterService(service_.get()); + server_ = builder.BuildAndStart(); + gpr_log(GPR_INFO, "Fake handshaker server listening on %s", address_.get()); + } + + ~FakeHandshakeServer() { + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + } + + const char* address() { return address_.get(); } + + private: + grpc_core::UniquePtr address_; + std::unique_ptr service_; + std::unique_ptr server_; +}; + +class TestServer { + public: + explicit TestServer(const char* fake_handshake_server_address) { + grpc_alts_credentials_options* alts_options = + grpc_alts_credentials_server_options_create(); + grpc_server_credentials* server_creds = + grpc_alts_server_credentials_create_customized( + alts_options, fake_handshake_server_address, + true /* enable_untrusted_alts */); + grpc_alts_credentials_options_destroy(alts_options); + server_ = grpc_server_create(nullptr, nullptr); + server_cq_ = grpc_completion_queue_create_for_next(nullptr); + grpc_server_register_completion_queue(server_, server_cq_, nullptr); + int port = grpc_pick_unused_port_or_die(); + GPR_ASSERT(grpc_core::JoinHostPort(&server_addr_, "localhost", port)); + GPR_ASSERT(grpc_server_add_secure_http2_port(server_, server_addr_.get(), + server_creds)); + grpc_server_credentials_release(server_creds); + grpc_server_start(server_); + gpr_log(GPR_DEBUG, "Start TestServer %p. listen on %s", this, + server_addr_.get()); + server_thd_ = + std::unique_ptr(new std::thread(PollUntilShutdown, this)); + } + + ~TestServer() { + gpr_log(GPR_DEBUG, "Begin dtor of TestServer %p", this); + grpc_server_shutdown_and_notify(server_, server_cq_, this); + server_thd_->join(); + grpc_server_destroy(server_); + grpc_completion_queue_shutdown(server_cq_); + drain_cq(server_cq_); + grpc_completion_queue_destroy(server_cq_); + } + + const char* address() { return server_addr_.get(); } + + static void PollUntilShutdown(const TestServer* self) { + grpc_event ev = grpc_completion_queue_next( + self->server_cq_, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == self); + gpr_log(GPR_DEBUG, "TestServer %p stop polling", self); + } + + private: + grpc_server* server_; + grpc_completion_queue* server_cq_; + std::unique_ptr server_thd_; + grpc_core::UniquePtr server_addr_; +}; + +class ConnectLoopRunner { + public: + explicit ConnectLoopRunner( + const char* server_address, const char* fake_handshake_server_addr, + int per_connect_deadline_seconds, size_t loops, + grpc_connectivity_state expected_connectivity_states) + : server_address_(std::unique_ptr(gpr_strdup(server_address))), + fake_handshake_server_addr_( + std::unique_ptr(gpr_strdup(fake_handshake_server_addr))), + per_connect_deadline_seconds_(per_connect_deadline_seconds), + loops_(loops), + expected_connectivity_states_(expected_connectivity_states) { + thd_ = std::unique_ptr(new std::thread(ConnectLoop, this)); + } + + ~ConnectLoopRunner() { thd_->join(); } + + static void ConnectLoop(const ConnectLoopRunner* self) { + for (size_t i = 0; i < self->loops_; i++) { + gpr_log(GPR_DEBUG, "runner:%p connect_loop begin loop %ld", self, i); + grpc_completion_queue* cq = + grpc_completion_queue_create_for_next(nullptr); + grpc_channel* channel = create_secure_channel_for_test( + self->server_address_.get(), self->fake_handshake_server_addr_.get()); + // Connect, forcing an ALTS handshake + gpr_timespec connect_deadline = + grpc_timeout_seconds_to_deadline(self->per_connect_deadline_seconds_); + grpc_connectivity_state state = + grpc_channel_check_connectivity_state(channel, 1); + ASSERT_EQ(state, GRPC_CHANNEL_IDLE); + while (state != self->expected_connectivity_states_) { + if (self->expected_connectivity_states_ == + GRPC_CHANNEL_TRANSIENT_FAILURE) { + ASSERT_NE(state, GRPC_CHANNEL_READY); // sanity check + } else { + ASSERT_EQ(self->expected_connectivity_states_, GRPC_CHANNEL_READY); + } + grpc_channel_watch_connectivity_state( + channel, state, gpr_inf_future(GPR_CLOCK_REALTIME), cq, nullptr); + grpc_event ev = + grpc_completion_queue_next(cq, connect_deadline, nullptr); + ASSERT_EQ(ev.type, GRPC_OP_COMPLETE) + << "connect_loop runner:" << std::hex << self + << " got ev.type:" << ev.type << " i:" << i; + ASSERT_TRUE(ev.success); + state = grpc_channel_check_connectivity_state(channel, 1); + } + grpc_channel_destroy(channel); + grpc_completion_queue_shutdown(cq); + drain_cq(cq); + grpc_completion_queue_destroy(cq); + gpr_log(GPR_DEBUG, "runner:%p connect_loop finished loop %ld", self, i); + } + } + + private: + std::unique_ptr server_address_; + std::unique_ptr fake_handshake_server_addr_; + int per_connect_deadline_seconds_; + size_t loops_; + grpc_connectivity_state expected_connectivity_states_; + std::unique_ptr thd_; +}; + +// Perform a few ALTS handshakes sequentially (using the fake, in-process ALTS +// handshake server). +TEST(AltsConcurrentConnectivityTest, TestBasicClientServerHandshakes) { + FakeHandshakeServer fake_handshake_server; + TestServer test_server(fake_handshake_server.address()); + { + ConnectLoopRunner runner( + test_server.address(), fake_handshake_server.address(), + 5 /* per connect deadline seconds */, 10 /* loops */, + GRPC_CHANNEL_READY /* expected connectivity states */); + } +} + +/* Run a bunch of concurrent ALTS handshakes on concurrent channels + * (using the fake, in-process handshake server). */ +TEST(AltsConcurrentConnectivityTest, TestConcurrentClientServerHandshakes) { + FakeHandshakeServer fake_handshake_server; + // Test + { + TestServer test_server(fake_handshake_server.address()); + gpr_timespec test_deadline = grpc_timeout_seconds_to_deadline(20); + size_t num_concurrent_connects = 50; + std::vector> connect_loop_runners; + gpr_log(GPR_DEBUG, + "start performing concurrent expected-to-succeed connects"); + for (size_t i = 0; i < num_concurrent_connects; i++) { + connect_loop_runners.push_back( + std::unique_ptr(new ConnectLoopRunner( + test_server.address(), fake_handshake_server.address(), + 15 /* per connect deadline seconds */, 5 /* loops */, + GRPC_CHANNEL_READY /* expected connectivity states */))); + } + connect_loop_runners.clear(); + gpr_log(GPR_DEBUG, + "done performing concurrent expected-to-succeed connects"); + if (gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), test_deadline) > 0) { + gpr_log(GPR_DEBUG, "Test took longer than expected."); + abort(); + } + } +} + +class FakeTcpServer { + public: + enum ProcessReadResult { + CONTINUE_READING, + CLOSE_SOCKET, + }; + + FakeTcpServer( + const std::function& process_read_cb) + : process_read_cb_(process_read_cb) { + port_ = grpc_pick_unused_port_or_die(); + accept_socket_ = socket(AF_INET6, SOCK_STREAM, 0); + char* addr_str; + GPR_ASSERT(gpr_asprintf(&addr_str, "[::]:%d", port_)); + address_ = std::unique_ptr(addr_str); + GPR_ASSERT(accept_socket_ != -1); + if (accept_socket_ == -1) { + gpr_log(GPR_ERROR, "Failed to create socket: %d", errno); + abort(); + } + int val = 1; + if (setsockopt(accept_socket_, SOL_SOCKET, SO_REUSEADDR, &val, + sizeof(val)) != 0) { + gpr_log(GPR_ERROR, + "Failed to set SO_REUSEADDR on socket bound to [::1]:%d : %d", + port_, errno); + abort(); + } + if (fcntl(accept_socket_, F_SETFL, O_NONBLOCK) != 0) { + gpr_log(GPR_ERROR, "Failed to set O_NONBLOCK on socket: %d", errno); + abort(); + } + sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = htons(port_); + ((char*)&addr.sin6_addr)[15] = 1; + if (bind(accept_socket_, (const sockaddr*)&addr, sizeof(addr)) != 0) { + gpr_log(GPR_ERROR, "Failed to bind socket to [::1]:%d : %d", port_, + errno); + abort(); + } + if (listen(accept_socket_, 100)) { + gpr_log(GPR_ERROR, "Failed to listen on socket bound to [::1]:%d : %d", + port_, errno); + abort(); + } + gpr_event_init(&stop_ev_); + run_server_loop_thd_ = + std::unique_ptr(new std::thread(RunServerLoop, this)); + } + + ~FakeTcpServer() { + gpr_log(GPR_DEBUG, + "FakeTcpServer stop and " + "join server thread"); + gpr_event_set(&stop_ev_, (void*)1); + run_server_loop_thd_->join(); + gpr_log(GPR_DEBUG, + "FakeTcpServer join server " + "thread complete"); + } + + const char* address() { return address_.get(); } + + static ProcessReadResult CloseSocketUponReceivingBytesFromPeer( + int bytes_received_size, int read_error, int s) { + if (bytes_received_size < 0 && read_error != EAGAIN && + read_error != EWOULDBLOCK) { + gpr_log(GPR_ERROR, "Failed to receive from peer socket: %d. errno: %d", s, + errno); + abort(); + } + if (bytes_received_size >= 0) { + gpr_log(GPR_DEBUG, + "Fake TCP server received %d bytes from peer socket: %d. Close " + "the " + "connection.", + bytes_received_size, s); + return CLOSE_SOCKET; + } + return CONTINUE_READING; + } + + static ProcessReadResult CloseSocketUponCloseFromPeer(int bytes_received_size, + int read_error, int s) { + if (bytes_received_size < 0 && read_error != EAGAIN && + read_error != EWOULDBLOCK) { + gpr_log(GPR_ERROR, "Failed to receive from peer socket: %d. errno: %d", s, + errno); + abort(); + } + if (bytes_received_size == 0) { + // The peer has shut down the connection. + gpr_log(GPR_DEBUG, + "Fake TCP server received 0 bytes from peer socket: %d. Close " + "the " + "connection.", + s); + return CLOSE_SOCKET; + } + return CONTINUE_READING; + } + + // Run a loop that periodically, every 10 ms: + // 1) Checks if there are any new TCP connections to accept. + // 2) Checks if any data has arrived yet on established connections, + // and reads from them if so, processing the sockets as configured. + static void RunServerLoop(FakeTcpServer* self) { + std::set peers; + while (!gpr_event_get(&self->stop_ev_)) { + int p = accept(self->accept_socket_, nullptr, nullptr); + if (p == -1 && errno != EAGAIN && errno != EWOULDBLOCK) { + gpr_log(GPR_ERROR, "Failed to accept connection: %d", errno); + abort(); + } + if (p != -1) { + gpr_log(GPR_DEBUG, "accepted peer socket: %d", p); + if (fcntl(p, F_SETFL, O_NONBLOCK) != 0) { + gpr_log(GPR_ERROR, + "Failed to set O_NONBLOCK on peer socket:%d errno:%d", p, + errno); + abort(); + } + peers.insert(p); + } + auto it = peers.begin(); + while (it != peers.end()) { + int p = *it; + char buf[100]; + int bytes_received_size = recv(p, buf, 100, 0); + ProcessReadResult r = + self->process_read_cb_(bytes_received_size, errno, p); + if (r == CLOSE_SOCKET) { + close(p); + it = peers.erase(it); + } else { + GPR_ASSERT(r == CONTINUE_READING); + it++; + } + } + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(10, GPR_TIMESPAN))); + } + for (auto it = peers.begin(); it != peers.end(); it++) { + close(*it); + } + close(self->accept_socket_); + } + + private: + int accept_socket_; + int port_; + gpr_event stop_ev_; + std::unique_ptr address_; + std::unique_ptr run_server_loop_thd_; + std::function process_read_cb_; +}; + +/* This test is intended to make sure that ALTS handshakes we correctly + * fail fast when the security handshaker gets an error while reading + * from the remote peer, after having earlier sent the first bytes of the + * ALTS handshake to the peer, i.e. after getting into the middle of a + * handshake. */ +TEST(AltsConcurrentConnectivityTest, + TestHandshakeFailsFastWhenPeerEndpointClosesConnectionAfterAccepting) { + FakeHandshakeServer fake_handshake_server; + FakeTcpServer fake_tcp_server( + FakeTcpServer::CloseSocketUponReceivingBytesFromPeer); + { + gpr_timespec test_deadline = grpc_timeout_seconds_to_deadline(20); + std::vector> connect_loop_runners; + size_t num_concurrent_connects = 100; + gpr_log(GPR_DEBUG, "start performing concurrent expected-to-fail connects"); + for (size_t i = 0; i < num_concurrent_connects; i++) { + connect_loop_runners.push_back(std::unique_ptr< + ConnectLoopRunner>(new ConnectLoopRunner( + fake_tcp_server.address(), fake_handshake_server.address(), + 10 /* per connect deadline seconds */, 3 /* loops */, + GRPC_CHANNEL_TRANSIENT_FAILURE /* expected connectivity states */))); + } + connect_loop_runners.clear(); + gpr_log(GPR_DEBUG, "done performing concurrent expected-to-fail connects"); + if (gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), test_deadline) > 0) { + gpr_log(GPR_ERROR, + "Exceeded test deadline. ALTS handshakes might not be failing " + "fast when the peer endpoint closes the connection abruptly"); + abort(); + } + } +} + +} // namespace + +int main(int argc, char** argv) { + grpc_init(); + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc index 21e9e2c4397..b85f7501027 100644 --- a/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc +++ b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc @@ -320,8 +320,12 @@ static tsi_result mock_client_start(alts_handshaker_client* client) { if (!should_handshaker_client_api_succeed) { return TSI_INTERNAL_ERROR; } + /* Note that the alts_tsi_handshaker needs to set its + * has_sent_start_message field field to true + * before the call to alts_handshaker_client_start is made because + * because it's unsafe to access it afterwards. */ alts_handshaker_client_check_fields_for_testing( - client, on_client_start_success_cb, nullptr, false, nullptr); + client, on_client_start_success_cb, nullptr, true, nullptr); /* Populate handshaker response for client_start request. */ grpc_byte_buffer** recv_buffer_ptr = alts_handshaker_client_get_recv_buffer_addr_for_testing(client); @@ -339,7 +343,7 @@ static tsi_result mock_server_start(alts_handshaker_client* client, return TSI_INTERNAL_ERROR; } alts_handshaker_client_check_fields_for_testing( - client, on_server_start_success_cb, nullptr, false, nullptr); + client, on_server_start_success_cb, nullptr, true, nullptr); grpc_slice slice = grpc_empty_slice(); GPR_ASSERT(grpc_slice_cmp(*bytes_received, slice) == 0); /* Populate handshaker response for server_start request. */ @@ -404,6 +408,12 @@ static tsi_handshaker* create_test_handshaker(bool is_client) { return handshaker; } +static void run_tsi_handshaker_destroy_with_exec_ctx( + tsi_handshaker* handshaker) { + grpc_core::ExecCtx exec_ctx; + tsi_handshaker_destroy(handshaker); +} + static void check_handshaker_next_invalid_input() { /* Initialization. */ tsi_handshaker* handshaker = create_test_handshaker(true); @@ -416,7 +426,7 @@ static void check_handshaker_next_invalid_input() { nullptr, nullptr, nullptr) == TSI_INVALID_ARGUMENT); /* Cleanup. */ - tsi_handshaker_destroy(handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); } static void check_handshaker_shutdown_invalid_input() { @@ -425,7 +435,7 @@ static void check_handshaker_shutdown_invalid_input() { /* Check nullptr handshaker. */ tsi_handshaker_shutdown(nullptr); /* Cleanup. */ - tsi_handshaker_destroy(handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); } static void check_handshaker_next_success() { @@ -462,8 +472,8 @@ static void check_handshaker_next_success() { nullptr, on_server_next_success_cb, nullptr) == TSI_ASYNC); wait(&tsi_to_caller_notification); /* Cleanup. */ - tsi_handshaker_destroy(server_handshaker); - tsi_handshaker_destroy(client_handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(server_handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(client_handshaker); } static void check_handshaker_next_with_shutdown() { @@ -481,7 +491,7 @@ static void check_handshaker_next_with_shutdown() { nullptr, on_client_next_success_cb, nullptr) == TSI_HANDSHAKE_SHUTDOWN); /* Cleanup. */ - tsi_handshaker_destroy(handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); } static void check_handle_response_with_shutdown(void* /*unused*/) { @@ -520,8 +530,8 @@ static void check_handshaker_next_failure() { nullptr, check_must_not_be_called, nullptr) == TSI_INTERNAL_ERROR); /* Cleanup. */ - tsi_handshaker_destroy(server_handshaker); - tsi_handshaker_destroy(client_handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(server_handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(client_handshaker); } static void on_invalid_input_cb(tsi_result status, void* user_data, @@ -584,7 +594,7 @@ static void check_handle_response_invalid_input() { alts_handshaker_client_handle_response(client, false); /* Cleanup. */ grpc_slice_unref(slice); - tsi_handshaker_destroy(handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); notification_destroy(&caller_to_tsi_notification); notification_destroy(&tsi_to_caller_notification); } @@ -622,7 +632,7 @@ static void check_handle_response_invalid_resp() { recv_buffer, GRPC_STATUS_OK); alts_handshaker_client_handle_response(client, true); /* Cleanup. */ - tsi_handshaker_destroy(handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); notification_destroy(&caller_to_tsi_notification); notification_destroy(&tsi_to_caller_notification); } @@ -675,7 +685,7 @@ static void check_handle_response_failure() { recv_buffer, GRPC_STATUS_OK); alts_handshaker_client_handle_response(client, true /* is_ok*/); /* Cleanup. */ - tsi_handshaker_destroy(handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); notification_destroy(&caller_to_tsi_notification); notification_destroy(&tsi_to_caller_notification); } @@ -714,7 +724,7 @@ static void check_handle_response_after_shutdown() { recv_buffer, GRPC_STATUS_OK); alts_handshaker_client_handle_response(client, true); /* Cleanup. */ - tsi_handshaker_destroy(handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); notification_destroy(&caller_to_tsi_notification); notification_destroy(&tsi_to_caller_notification); } diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index 3c1d60bf699..b504ebbd08e 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -3035,6 +3035,24 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": false, + "language": "c++", + "name": "alts_concurrent_connectivity_test", + "platforms": [ + "linux" + ], + "uses_polling": true + }, { "args": [], "benchmark": false,