From 76c82265b40e3c15a728747ba7eb7e032614e7b8 Mon Sep 17 00:00:00 2001 From: "Mark D. Roth" Date: Fri, 20 Jan 2023 12:14:37 -0800 Subject: [PATCH] WRR: implement WRR LB policy (#31904) * WRR: port StaticStrideScheduler to OSS * WIP * Automated change: Fix sanity tests * fix build * remove unused aliases * fix another type mismatch * remove unnecessary include * move benchmarks to their own file, and don't run it on windows * Automated change: Fix sanity tests * add OOB reporting * generate_projects * clang-format * add config parser test * clang-tidy and minimize lock contention * add config defaults * add oob_reporting_period config field and add basic test * Automated change: Fix sanity tests * fix test * change test to use basic RR * WIP: started exposing peer address to LB policy API * first WRR test passing! * small cleanup * port RR fix to WRR * test helper refactoring * more test helper refactoring * WIP: trying to fix test to have the right weights * more WIP -- need to make pickers DualRefCounted * fix timer ref handling and get tests working * clang-format * iwyu and generate_projects * fix build * add test for OOB reporting * keep only READY subchannels in the picker * add file missed in a previous commit * fix sanity * iwyu * add weight expiration period * add tests for weight update period and OOB reporting period * Automated change: Fix sanity tests * lower bound for timer interval * consistently apply grpc_test_slowdown_factor() * cache time in test * add blackout_period tests * avoid some unnecessary copies * clang-format * add field to config test * simplify orca watcher tracking * attempt to fix build * iwyu * generate_projects * add "_experimental" suffix to policy name * WRR: update tests to cover qps plumbing * WIP * more WIP * basic WRR e2e test working * add OOB test * fix sanity * ignore duplicate addresses * Automated change: Fix sanity tests * add new tracer to doc/environment_variables.md * retain scheduler state across pickers * 
Automated change: Fix sanity tests * use separate mutexes for scheduler and timer * sort addresses to avoid index churn * remove fetch_sub for wrap around in RR case Co-authored-by: markdroth --- BUILD | 2 + CMakeLists.txt | 80 ++ Makefile | 4 + build_autogenerated.yaml | 30 + config.m4 | 3 + config.w32 | 3 + doc/environment_variables.md | 2 + gRPC-C++.podspec | 4 + gRPC-Core.podspec | 6 + grpc.gemspec | 4 + grpc.gyp | 4 + package.xml | 4 + src/core/BUILD | 44 + .../filters/client_channel/client_channel.cc | 6 +- .../lb_policy/oob_backend_metric.cc | 79 +- .../lb_policy/oob_backend_metric_internal.h | 117 +++ .../weighted_round_robin.cc | 972 ++++++++++++++++++ src/core/lib/load_balancing/lb_policy.cc | 9 + src/core/lib/load_balancing/lb_policy.h | 8 +- .../plugin_registry/grpc_plugin_registry.cc | 3 + src/python/grpcio/grpc_core_dependencies.py | 2 + test/core/client_channel/lb_policy/BUILD | 30 + .../lb_policy/lb_policy_test_lib.h | 156 ++- .../weighted_round_robin_config_test.cc | 85 ++ .../lb_policy/weighted_round_robin_test.cc | 626 +++++++++++ test/cpp/end2end/client_lb_end2end_test.cc | 169 ++- tools/doxygen/Doxyfile.c++.internal | 4 + tools/doxygen/Doxyfile.core.internal | 4 + tools/run_tests/generated/tests.json | 48 + 29 files changed, 2396 insertions(+), 112 deletions(-) create mode 100644 src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h create mode 100644 src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc create mode 100644 test/core/client_channel/lb_policy/weighted_round_robin_config_test.cc create mode 100644 test/core/client_channel/lb_policy/weighted_round_robin_test.cc diff --git a/BUILD b/BUILD index 252a928f264..fa7730856ac 100644 --- a/BUILD +++ b/BUILD @@ -751,6 +751,7 @@ grpc_cc_library( "//src/core:grpc_lb_policy_priority", "//src/core:grpc_lb_policy_ring_hash", "//src/core:grpc_lb_policy_round_robin", + "//src/core:grpc_lb_policy_weighted_round_robin", 
"//src/core:grpc_lb_policy_weighted_target", "//src/core:grpc_channel_idle_filter", "//src/core:grpc_message_size_filter", @@ -2676,6 +2677,7 @@ grpc_cc_library( "//src/core:ext/filters/client_channel/lb_call_state_internal.h", "//src/core:ext/filters/client_channel/lb_policy/child_policy_handler.h", "//src/core:ext/filters/client_channel/lb_policy/oob_backend_metric.h", + "//src/core:ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h", "//src/core:ext/filters/client_channel/local_subchannel_pool.h", "//src/core:ext/filters/client_channel/retry_filter.h", "//src/core:ext/filters/client_channel/retry_service_config.h", diff --git a/CMakeLists.txt b/CMakeLists.txt index c9573d2bf89..0b80f699032 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1268,6 +1268,8 @@ if(gRPC_BUILD_TESTS) if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) add_dependencies(buildtests_cxx wakeup_fd_posix_test) endif() + add_dependencies(buildtests_cxx weighted_round_robin_config_test) + add_dependencies(buildtests_cxx weighted_round_robin_test) if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_POSIX OR _gRPC_PLATFORM_WINDOWS) add_dependencies(buildtests_cxx win_socket_test) endif() @@ -1723,6 +1725,8 @@ add_library(grpc src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc src/core/ext/filters/client_channel/lb_policy/rls/rls.cc src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc + src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc + src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc src/core/ext/filters/client_channel/lb_policy/xds/cds.cc src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.cc @@ -2719,6 +2723,8 @@ add_library(grpc_unsecure src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc 
src/core/ext/filters/client_channel/lb_policy/rls/rls.cc src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc + src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc + src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc src/core/ext/filters/client_channel/local_subchannel_pool.cc src/core/ext/filters/client_channel/resolver/binder/binder_resolver.cc @@ -21882,6 +21888,80 @@ if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) endif() +endif() +if(gRPC_BUILD_TESTS) + +add_executable(weighted_round_robin_config_test + test/core/client_channel/lb_policy/weighted_round_robin_config_test.cc + third_party/googletest/googletest/src/gtest-all.cc + third_party/googletest/googlemock/src/gmock-all.cc +) + +target_include_directories(weighted_round_robin_config_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_XXHASH_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(weighted_round_robin_config_test + ${_gRPC_BASELIB_LIBRARIES} + ${_gRPC_PROTOBUF_LIBRARIES} + ${_gRPC_ZLIB_LIBRARIES} + ${_gRPC_ALLTARGETS_LIBRARIES} + grpc_test_util +) + + +endif() +if(gRPC_BUILD_TESTS) + +add_executable(weighted_round_robin_test + test/core/client_channel/lb_policy/weighted_round_robin_test.cc + third_party/googletest/googletest/src/gtest-all.cc + third_party/googletest/googlemock/src/gmock-all.cc +) + +target_include_directories(weighted_round_robin_test + PRIVATE + 
${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_XXHASH_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(weighted_round_robin_test + ${_gRPC_BASELIB_LIBRARIES} + ${_gRPC_PROTOBUF_LIBRARIES} + ${_gRPC_ZLIB_LIBRARIES} + ${_gRPC_ALLTARGETS_LIBRARIES} + grpc_test_util +) + + endif() if(gRPC_BUILD_TESTS) if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_POSIX OR _gRPC_PLATFORM_WINDOWS) diff --git a/Makefile b/Makefile index 9da3b0afaab..171d8077987 100644 --- a/Makefile +++ b/Makefile @@ -988,6 +988,8 @@ LIBGRPC_SRC = \ src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc \ src/core/ext/filters/client_channel/lb_policy/rls/rls.cc \ src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc \ + src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc \ + src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc \ src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc \ src/core/ext/filters/client_channel/lb_policy/xds/cds.cc \ src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.cc \ @@ -1843,6 +1845,8 @@ LIBGRPC_UNSECURE_SRC = \ src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc \ src/core/ext/filters/client_channel/lb_policy/rls/rls.cc \ src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc \ + src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc \ + src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc \ 
src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc \ src/core/ext/filters/client_channel/local_subchannel_pool.cc \ src/core/ext/filters/client_channel/resolver/binder/binder_resolver.cc \ diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 96c925938f0..363229e26c5 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -348,9 +348,11 @@ libs: - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.h - src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h - src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h + - src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h - src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.h - src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h - src/core/ext/filters/client_channel/lb_policy/subchannel_list.h + - src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h - src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.h - src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_args.h - src/core/ext/filters/client_channel/lb_policy/xds/xds_override_host.h @@ -1110,6 +1112,8 @@ libs: - src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc - src/core/ext/filters/client_channel/lb_policy/rls/rls.cc - src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc + - src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc + - src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc - src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc - src/core/ext/filters/client_channel/lb_policy/xds/cds.cc - src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.cc @@ -1984,9 +1988,11 @@ libs: - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.h - 
src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h - src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h + - src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h - src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.h - src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h - src/core/ext/filters/client_channel/lb_policy/subchannel_list.h + - src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h - src/core/ext/filters/client_channel/local_subchannel_pool.h - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h @@ -2366,6 +2372,8 @@ libs: - src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc - src/core/ext/filters/client_channel/lb_policy/rls/rls.cc - src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc + - src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc + - src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc - src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc - src/core/ext/filters/client_channel/local_subchannel_pool.cc - src/core/ext/filters/client_channel/resolver/binder/binder_resolver.cc @@ -12292,6 +12300,28 @@ targets: - linux - posix - mac +- name: weighted_round_robin_config_test + gtest: true + build: test + language: c++ + headers: [] + src: + - test/core/client_channel/lb_policy/weighted_round_robin_config_test.cc + deps: + - grpc_test_util + uses_polling: false +- name: weighted_round_robin_test + gtest: true + build: test + language: c++ + headers: + - test/core/client_channel/lb_policy/lb_policy_test_lib.h + - test/core/event_engine/mock_event_engine.h + src: + - test/core/client_channel/lb_policy/weighted_round_robin_test.cc + deps: + - grpc_test_util + 
uses_polling: false - name: win_socket_test gtest: true build: test diff --git a/config.m4 b/config.m4 index ba527b00a2e..b592af708f0 100644 --- a/config.m4 +++ b/config.m4 @@ -70,6 +70,8 @@ if test "$PHP_GRPC" != "no"; then src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc \ src/core/ext/filters/client_channel/lb_policy/rls/rls.cc \ src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc \ + src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc \ + src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc \ src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc \ src/core/ext/filters/client_channel/lb_policy/xds/cds.cc \ src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.cc \ @@ -1255,6 +1257,7 @@ if test "$PHP_GRPC" != "no"; then PHP_ADD_BUILD_DIR($ext_builddir/src/core/ext/filters/client_channel/lb_policy/ring_hash) PHP_ADD_BUILD_DIR($ext_builddir/src/core/ext/filters/client_channel/lb_policy/rls) PHP_ADD_BUILD_DIR($ext_builddir/src/core/ext/filters/client_channel/lb_policy/round_robin) + PHP_ADD_BUILD_DIR($ext_builddir/src/core/ext/filters/client_channel/lb_policy/weighted_round_robin) PHP_ADD_BUILD_DIR($ext_builddir/src/core/ext/filters/client_channel/lb_policy/weighted_target) PHP_ADD_BUILD_DIR($ext_builddir/src/core/ext/filters/client_channel/lb_policy/xds) PHP_ADD_BUILD_DIR($ext_builddir/src/core/ext/filters/client_channel/resolver) diff --git a/config.w32 b/config.w32 index f6ccff8b75f..04f916eaa46 100644 --- a/config.w32 +++ b/config.w32 @@ -36,6 +36,8 @@ if (PHP_GRPC != "no") { "src\\core\\ext\\filters\\client_channel\\lb_policy\\ring_hash\\ring_hash.cc " + "src\\core\\ext\\filters\\client_channel\\lb_policy\\rls\\rls.cc " + "src\\core\\ext\\filters\\client_channel\\lb_policy\\round_robin\\round_robin.cc " + + 
"src\\core\\ext\\filters\\client_channel\\lb_policy\\weighted_round_robin\\static_stride_scheduler.cc " + + "src\\core\\ext\\filters\\client_channel\\lb_policy\\weighted_round_robin\\weighted_round_robin.cc " + "src\\core\\ext\\filters\\client_channel\\lb_policy\\weighted_target\\weighted_target.cc " + "src\\core\\ext\\filters\\client_channel\\lb_policy\\xds\\cds.cc " + "src\\core\\ext\\filters\\client_channel\\lb_policy\\xds\\xds_attributes.cc " + @@ -1253,6 +1255,7 @@ if (PHP_GRPC != "no") { FSO.CreateFolder(base_dir+"\\ext\\grpc\\src\\core\\ext\\filters\\client_channel\\lb_policy\\ring_hash"); FSO.CreateFolder(base_dir+"\\ext\\grpc\\src\\core\\ext\\filters\\client_channel\\lb_policy\\rls"); FSO.CreateFolder(base_dir+"\\ext\\grpc\\src\\core\\ext\\filters\\client_channel\\lb_policy\\round_robin"); + FSO.CreateFolder(base_dir+"\\ext\\grpc\\src\\core\\ext\\filters\\client_channel\\lb_policy\\weighted_round_robin"); FSO.CreateFolder(base_dir+"\\ext\\grpc\\src\\core\\ext\\filters\\client_channel\\lb_policy\\weighted_target"); FSO.CreateFolder(base_dir+"\\ext\\grpc\\src\\core\\ext\\filters\\client_channel\\lb_policy\\xds"); FSO.CreateFolder(base_dir+"\\ext\\grpc\\src\\core\\ext\\filters\\client_channel\\resolver"); diff --git a/doc/environment_variables.md b/doc/environment_variables.md index 67f3a810018..d2d0ba9e798 100644 --- a/doc/environment_variables.md +++ b/doc/environment_variables.md @@ -84,6 +84,8 @@ some configuration as environment variables that can be set. 
- ring_hash_lb - traces the ring hash load balancing policy - rls_lb - traces the RLS load balancing policy - round_robin - traces the round_robin load balancing policy + - weighted_round_robin_lb - traces the weighted_round_robin load balancing + policy - queue_pluck - grpc_authz_api - traces gRPC authorization - server_channel - lightweight trace of significant server channel events diff --git a/gRPC-C++.podspec b/gRPC-C++.podspec index 2d23a003774..26804333e9f 100644 --- a/gRPC-C++.podspec +++ b/gRPC-C++.podspec @@ -260,9 +260,11 @@ Pod::Spec.new do |s| 'src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.h', 'src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h', 'src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h', + 'src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h', 'src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.h', 'src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h', 'src/core/ext/filters/client_channel/lb_policy/subchannel_list.h', + 'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.h', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_args.h', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_override_host.h', @@ -1192,9 +1194,11 @@ Pod::Spec.new do |s| 'src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.h', 'src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h', 'src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h', + 'src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h', 'src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.h', 'src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h', 'src/core/ext/filters/client_channel/lb_policy/subchannel_list.h', + 
'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.h', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_args.h', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_override_host.h', diff --git a/gRPC-Core.podspec b/gRPC-Core.podspec index b1dfc6f2b4f..54fc1a2cf78 100644 --- a/gRPC-Core.podspec +++ b/gRPC-Core.podspec @@ -252,6 +252,7 @@ Pod::Spec.new do |s| 'src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h', 'src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.cc', 'src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h', + 'src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h', 'src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.cc', 'src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.h', 'src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc', @@ -261,6 +262,9 @@ Pod::Spec.new do |s| 'src/core/ext/filters/client_channel/lb_policy/rls/rls.cc', 'src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc', 'src/core/ext/filters/client_channel/lb_policy/subchannel_list.h', + 'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc', + 'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h', + 'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc', 'src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc', 'src/core/ext/filters/client_channel/lb_policy/xds/cds.cc', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.cc', @@ -1882,9 +1886,11 @@ Pod::Spec.new do |s| 'src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.h', 'src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h', 
'src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h', + 'src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h', 'src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.h', 'src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h', 'src/core/ext/filters/client_channel/lb_policy/subchannel_list.h', + 'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.h', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_args.h', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_override_host.h', diff --git a/grpc.gemspec b/grpc.gemspec index 363a0966a7a..d39e0aa79b8 100644 --- a/grpc.gemspec +++ b/grpc.gemspec @@ -163,6 +163,7 @@ Gem::Specification.new do |s| s.files += %w( src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h ) s.files += %w( src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.cc ) s.files += %w( src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h ) + s.files += %w( src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h ) s.files += %w( src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.cc ) s.files += %w( src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.h ) s.files += %w( src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc ) @@ -172,6 +173,9 @@ Gem::Specification.new do |s| s.files += %w( src/core/ext/filters/client_channel/lb_policy/rls/rls.cc ) s.files += %w( src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc ) s.files += %w( src/core/ext/filters/client_channel/lb_policy/subchannel_list.h ) + s.files += %w( src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc ) + s.files += %w( 
src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h ) + s.files += %w( src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc ) s.files += %w( src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc ) s.files += %w( src/core/ext/filters/client_channel/lb_policy/xds/cds.cc ) s.files += %w( src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.cc ) diff --git a/grpc.gyp b/grpc.gyp index 7340cd25979..c9c892d2d5c 100644 --- a/grpc.gyp +++ b/grpc.gyp @@ -401,6 +401,8 @@ 'src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc', 'src/core/ext/filters/client_channel/lb_policy/rls/rls.cc', 'src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc', + 'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc', + 'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc', 'src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc', 'src/core/ext/filters/client_channel/lb_policy/xds/cds.cc', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.cc', @@ -1198,6 +1200,8 @@ 'src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc', 'src/core/ext/filters/client_channel/lb_policy/rls/rls.cc', 'src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc', + 'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc', + 'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc', 'src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc', 'src/core/ext/filters/client_channel/local_subchannel_pool.cc', 'src/core/ext/filters/client_channel/resolver/binder/binder_resolver.cc', diff --git a/package.xml b/package.xml index fccd53a75a9..85b1f7d00e3 100644 --- a/package.xml +++ b/package.xml @@ -145,6 +145,7 @@ + @@ -154,6 +155,9 
@@ + + + diff --git a/src/core/BUILD b/src/core/BUILD index 12004978ad2..6d19833917e 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -2410,6 +2410,7 @@ grpc_cc_library( deps = [ "channel_args", "closure", + "dual_ref_counted", "error", "grpc_backend_metric_data", "iomgr_fwd", @@ -4365,6 +4366,49 @@ grpc_cc_library( deps = ["//:gpr"], ) +grpc_cc_library( + name = "grpc_lb_policy_weighted_round_robin", + srcs = [ + "ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc", + ], + external_deps = [ + "absl/base:core_headers", + "absl/random", + "absl/status", + "absl/status:statusor", + "absl/strings", + "absl/types:optional", + ], + language = "c++", + deps = [ + "channel_args", + "grpc_backend_metric_data", + "grpc_lb_subchannel_list", + "json", + "json_args", + "json_object_loader", + "lb_policy", + "lb_policy_factory", + "ref_counted", + "resolved_address", + "static_stride_scheduler", + "subchannel_interface", + "time", + "validation_errors", + "//:config", + "//:debug_location", + "//:exec_ctx", + "//:gpr", + "//:grpc_base", + "//:grpc_client_channel", + "//:grpc_trace", + "//:orphanable", + "//:ref_counted_ptr", + "//:server_address", + "//:sockaddr_utils", + ], +) + grpc_cc_library( name = "grpc_outlier_detection_header", hdrs = [ diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index 2ed844b2d8d..b3ea5b5df47 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -2948,8 +2948,12 @@ void ClientChannel::LoadBalancedCall::RecordCallCompletion( if (lb_subchannel_call_tracker_ != nullptr) { Metadata trailing_metadata(recv_trailing_metadata_); BackendMetricAccessor backend_metric_accessor(this); + const char* peer_string = + peer_string_ != nullptr + ? 
reinterpret_cast(gpr_atm_acq_load(peer_string_)) + : ""; LoadBalancingPolicy::SubchannelCallTrackerInterface::FinishArgs args = { - status, &trailing_metadata, &backend_metric_accessor}; + peer_string, status, &trailing_metadata, &backend_metric_accessor}; lb_subchannel_call_tracker_->Finish(args); lb_subchannel_call_tracker_.reset(); } diff --git a/src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.cc b/src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.cc index 7d7d29a48c5..610010549e7 100644 --- a/src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.cc +++ b/src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.cc @@ -26,7 +26,6 @@ #include #include -#include "absl/base/thread_annotations.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -43,8 +42,8 @@ #include "src/core/ext/filters/client_channel/backend_metric.h" #include "src/core/ext/filters/client_channel/client_channel_channelz.h" +#include "src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h" #include "src/core/ext/filters/client_channel/subchannel.h" -#include "src/core/ext/filters/client_channel/subchannel_interface_internal.h" #include "src/core/ext/filters/client_channel/subchannel_stream_client.h" #include "src/core/lib/channel/channel_trace.h" #include "src/core/lib/debug/trace.h" @@ -54,7 +53,6 @@ #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/sync.h" #include "src/core/lib/gprpp/time.h" -#include "src/core/lib/gprpp/unique_type_name.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/error.h" #include "src/core/lib/iomgr/exec_ctx.h" @@ -64,81 +62,8 @@ namespace grpc_core { -namespace { - TraceFlag grpc_orca_client_trace(false, "orca_client"); -class OrcaWatcher; - -// This producer is registered with a subchannel. 
It creates a -// streaming ORCA call and reports the resulting backend metrics to all -// registered watchers. -class OrcaProducer : public Subchannel::DataProducerInterface { - public: - void Start(RefCountedPtr subchannel); - - void Orphan() override; - - static UniqueTypeName Type() { - static UniqueTypeName::Factory kFactory("orca"); - return kFactory.Create(); - } - - UniqueTypeName type() const override { return Type(); } - - // Adds and removes watchers. - void AddWatcher(OrcaWatcher* watcher); - void RemoveWatcher(OrcaWatcher* watcher); - - private: - class ConnectivityWatcher; - class OrcaStreamEventHandler; - - // Returns the minimum requested reporting interval across all watchers. - Duration GetMinIntervalLocked() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(&mu_); - - // Starts a new stream if we have a connected subchannel. - // Called whenever the reporting interval changes or the subchannel - // transitions to state READY. - void MaybeStartStreamLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&mu_); - - // Handles a connectivity state change on the subchannel. - void OnConnectivityStateChange(grpc_connectivity_state state); - - // Called to notify watchers of a new backend metric report. - void NotifyWatchers(const BackendMetricData& backend_metric_data); - - RefCountedPtr subchannel_; - RefCountedPtr connected_subchannel_; - ConnectivityWatcher* connectivity_watcher_; - Mutex mu_; - std::set watchers_ ABSL_GUARDED_BY(mu_); - Duration report_interval_ ABSL_GUARDED_BY(mu_) = Duration::Infinity(); - OrphanablePtr stream_client_ ABSL_GUARDED_BY(mu_); -}; - -// This watcher is returned to the LB policy and added to the -// client channel SubchannelWrapper. 
-class OrcaWatcher : public InternalSubchannelDataWatcherInterface { - public: - OrcaWatcher(Duration report_interval, - std::unique_ptr watcher) - : report_interval_(report_interval), watcher_(std::move(watcher)) {} - ~OrcaWatcher() override; - - Duration report_interval() const { return report_interval_; } - OobBackendMetricWatcher* watcher() const { return watcher_.get(); } - - // When the client channel sees this wrapper, it will pass it the real - // subchannel to use. - void SetSubchannel(Subchannel* subchannel) override; - - private: - const Duration report_interval_; - std::unique_ptr watcher_; - RefCountedPtr producer_; -}; - // // OrcaProducer::ConnectivityWatcher // @@ -404,8 +329,6 @@ void OrcaWatcher::SetSubchannel(Subchannel* subchannel) { producer_->AddWatcher(this); } -} // namespace - std::unique_ptr MakeOobBackendMetricWatcher(Duration report_interval, std::unique_ptr watcher) { diff --git a/src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h b/src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h new file mode 100644 index 00000000000..0b3237f9149 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h @@ -0,0 +1,117 @@ +// +// Copyright 2022 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#ifndef GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_LB_POLICY_OOB_BACKEND_METRIC_INTERNAL_H +#define GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_LB_POLICY_OOB_BACKEND_METRIC_INTERNAL_H + +#include + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/strings/string_view.h" + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/backend_metric_data.h" +#include "src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h" +#include "src/core/ext/filters/client_channel/subchannel.h" +#include "src/core/ext/filters/client_channel/subchannel_interface_internal.h" +#include "src/core/ext/filters/client_channel/subchannel_stream_client.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/gprpp/time.h" +#include "src/core/lib/gprpp/unique_type_name.h" + +namespace grpc_core { + +class OrcaWatcher; + +// This producer is registered with a subchannel. It creates a +// streaming ORCA call and reports the resulting backend metrics to all +// registered watchers. +class OrcaProducer : public Subchannel::DataProducerInterface { + public: + void Start(RefCountedPtr subchannel); + + void Orphan() override; + + static UniqueTypeName Type() { + static UniqueTypeName::Factory kFactory("orca"); + return kFactory.Create(); + } + + UniqueTypeName type() const override { return Type(); } + + // Adds and removes watchers. + void AddWatcher(OrcaWatcher* watcher); + void RemoveWatcher(OrcaWatcher* watcher); + + private: + class ConnectivityWatcher; + class OrcaStreamEventHandler; + + // Returns the minimum requested reporting interval across all watchers. + Duration GetMinIntervalLocked() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(&mu_); + + // Starts a new stream if we have a connected subchannel. + // Called whenever the reporting interval changes or the subchannel + // transitions to state READY. 
+ void MaybeStartStreamLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&mu_); + + // Handles a connectivity state change on the subchannel. + void OnConnectivityStateChange(grpc_connectivity_state state); + + // Called to notify watchers of a new backend metric report. + void NotifyWatchers(const BackendMetricData& backend_metric_data); + + RefCountedPtr subchannel_; + RefCountedPtr connected_subchannel_; + ConnectivityWatcher* connectivity_watcher_; + Mutex mu_; + std::set watchers_ ABSL_GUARDED_BY(mu_); + Duration report_interval_ ABSL_GUARDED_BY(mu_) = Duration::Infinity(); + OrphanablePtr stream_client_ ABSL_GUARDED_BY(mu_); +}; + +// This watcher is returned to the LB policy and added to the +// client channel SubchannelWrapper. +class OrcaWatcher : public InternalSubchannelDataWatcherInterface { + public: + OrcaWatcher(Duration report_interval, + std::unique_ptr watcher) + : report_interval_(report_interval), watcher_(std::move(watcher)) {} + ~OrcaWatcher() override; + + Duration report_interval() const { return report_interval_; } + OobBackendMetricWatcher* watcher() const { return watcher_.get(); } + + // When the client channel sees this wrapper, it will pass it the real + // subchannel to use. + void SetSubchannel(Subchannel* subchannel) override; + + private: + const Duration report_interval_; + std::unique_ptr watcher_; + RefCountedPtr producer_; +}; + +} // namespace grpc_core + +#endif // GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_LB_POLICY_OOB_BACKEND_METRIC_INTERNAL_H diff --git a/src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc b/src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc new file mode 100644 index 00000000000..bcb76542b90 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc @@ -0,0 +1,972 @@ +// +// Copyright 2022 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/random/random.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy/backend_metric_data.h" +#include "src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h" +#include "src/core/ext/filters/client_channel/lb_policy/subchannel_list.h" +#include "src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/gprpp/time.h" +#include "src/core/lib/gprpp/validation_errors.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/resolved_address.h" +#include 
"src/core/lib/json/json.h" +#include "src/core/lib/json/json_args.h" +#include "src/core/lib/json/json_object_loader.h" +#include "src/core/lib/load_balancing/lb_policy.h" +#include "src/core/lib/load_balancing/lb_policy_factory.h" +#include "src/core/lib/load_balancing/subchannel_interface.h" +#include "src/core/lib/resolver/server_address.h" +#include "src/core/lib/transport/connectivity_state.h" + +namespace grpc_core { + +TraceFlag grpc_lb_wrr_trace(false, "weighted_round_robin_lb"); + +namespace { + +constexpr absl::string_view kWeightedRoundRobin = + "weighted_round_robin_experimental"; + +// Config for WRR policy. +class WeightedRoundRobinConfig : public LoadBalancingPolicy::Config { + public: + WeightedRoundRobinConfig() = default; + + WeightedRoundRobinConfig(const WeightedRoundRobinConfig&) = delete; + WeightedRoundRobinConfig& operator=(const WeightedRoundRobinConfig&) = delete; + + WeightedRoundRobinConfig(WeightedRoundRobinConfig&&) = delete; + WeightedRoundRobinConfig& operator=(WeightedRoundRobinConfig&&) = delete; + + absl::string_view name() const override { return kWeightedRoundRobin; } + + bool enable_oob_load_report() const { return enable_oob_load_report_; } + Duration oob_reporting_period() const { return oob_reporting_period_; } + Duration blackout_period() const { return blackout_period_; } + Duration weight_update_period() const { return weight_update_period_; } + Duration weight_expiration_period() const { + return weight_expiration_period_; + } + + static const JsonLoaderInterface* JsonLoader(const JsonArgs&) { + static const auto* loader = + JsonObjectLoader() + .OptionalField("enableOobLoadReport", + &WeightedRoundRobinConfig::enable_oob_load_report_) + .OptionalField("oobReportingPeriod", + &WeightedRoundRobinConfig::oob_reporting_period_) + .OptionalField("blackoutPeriod", + &WeightedRoundRobinConfig::blackout_period_) + .OptionalField("weightUpdatePeriod", + &WeightedRoundRobinConfig::weight_update_period_) + 
.OptionalField("weightExpirationPeriod", + &WeightedRoundRobinConfig::weight_expiration_period_) + .Finish(); + return loader; + } + + void JsonPostLoad(const Json&, const JsonArgs&, ValidationErrors*) { + // Impose lower bound of 100ms on weightUpdatePeriod. + weight_update_period_ = + std::max(weight_update_period_, Duration::Milliseconds(100)); + } + + private: + bool enable_oob_load_report_ = false; + Duration oob_reporting_period_ = Duration::Seconds(10); + Duration blackout_period_ = Duration::Seconds(10); + Duration weight_update_period_ = Duration::Seconds(1); + Duration weight_expiration_period_ = Duration::Minutes(3); +}; + +// WRR LB policy. +class WeightedRoundRobin : public LoadBalancingPolicy { + public: + explicit WeightedRoundRobin(Args args); + + absl::string_view name() const override { return kWeightedRoundRobin; } + + absl::Status UpdateLocked(UpdateArgs args) override; + void ResetBackoffLocked() override; + + private: + // Represents the weight for a given address. + class AddressWeight : public RefCounted { + public: + AddressWeight(RefCountedPtr wrr, std::string key) + : wrr_(std::move(wrr)), key_(std::move(key)) {} + ~AddressWeight() override; + + void MaybeUpdateWeight(double qps, double cpu_utilization); + + float GetWeight(Timestamp now, Duration weight_expiration_period, + Duration blackout_period); + + void ResetNonEmptySince(); + + private: + RefCountedPtr wrr_; + const std::string key_; + + Mutex mu_; + float weight_ ABSL_GUARDED_BY(&mu_) = 0; + Timestamp non_empty_since_ ABSL_GUARDED_BY(&mu_) = Timestamp::InfFuture(); + Timestamp last_update_time_ ABSL_GUARDED_BY(&mu_) = Timestamp::InfPast(); + }; + + // Forward declaration. + class WeightedRoundRobinSubchannelList; + + // Data for a particular subchannel in a subchannel list. + // This subclass adds the following functionality: + // - Tracks the previous connectivity state of the subchannel, so that + // we know how many subchannels are in each state. 
+ class WeightedRoundRobinSubchannelData + : public SubchannelData { + public: + WeightedRoundRobinSubchannelData( + SubchannelList* subchannel_list, + const ServerAddress& address, RefCountedPtr sc); + + absl::optional connectivity_state() const { + return logical_connectivity_state_; + } + + RefCountedPtr weight() const { return weight_; } + + private: + class OobWatcher : public OobBackendMetricWatcher { + public: + explicit OobWatcher(RefCountedPtr weight) + : weight_(std::move(weight)) {} + + void OnBackendMetricReport( + const BackendMetricData& backend_metric_data) override; + + private: + RefCountedPtr weight_; + }; + + // Performs connectivity state updates that need to be done only + // after we have started watching. + void ProcessConnectivityChangeLocked( + absl::optional old_state, + grpc_connectivity_state new_state) override; + + // Updates the logical connectivity state. + void UpdateLogicalConnectivityStateLocked( + grpc_connectivity_state connectivity_state); + + // The logical connectivity state of the subchannel. + // Note that the logical connectivity state may differ from the + // actual reported state in some cases (e.g., after we see + // TRANSIENT_FAILURE, we ignore any subsequent state changes until + // we see READY). + absl::optional logical_connectivity_state_; + + RefCountedPtr weight_; + }; + + // A list of subchannels. + class WeightedRoundRobinSubchannelList + : public SubchannelList { + public: + WeightedRoundRobinSubchannelList(WeightedRoundRobin* policy, + ServerAddressList addresses, + const ChannelArgs& args) + : SubchannelList(policy, + (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace) + ? "WeightedRoundRobinSubchannelList" + : nullptr), + std::move(addresses), policy->channel_control_helper(), + args) { + // Need to maintain a ref to the LB policy as long as we maintain + // any references to subchannels, since the subchannels' + // pollset_sets will include the LB policy's pollset_set. 
+ policy->Ref(DEBUG_LOCATION, "subchannel_list").release(); + } + + ~WeightedRoundRobinSubchannelList() override { + WeightedRoundRobin* p = static_cast(policy()); + p->Unref(DEBUG_LOCATION, "subchannel_list"); + } + + // Updates the counters of subchannels in each state when a + // subchannel transitions from old_state to new_state. + void UpdateStateCountersLocked( + absl::optional old_state, + grpc_connectivity_state new_state); + + // Ensures that the right subchannel list is used and then updates + // the aggregated connectivity state based on the subchannel list's + // state counters. + void MaybeUpdateAggregatedConnectivityStateLocked( + absl::Status status_for_tf); + + private: + std::string CountersString() const { + return absl::StrCat("num_subchannels=", num_subchannels(), + " num_ready=", num_ready_, + " num_connecting=", num_connecting_, + " num_transient_failure=", num_transient_failure_); + } + + size_t num_ready_ = 0; + size_t num_connecting_ = 0; + size_t num_transient_failure_ = 0; + + absl::Status last_failure_; + }; + + // A picker that performs WRR picks with weights based on + // endpoint-reported utilization and QPS. + class Picker : public SubchannelPicker { + public: + Picker(RefCountedPtr wrr, + WeightedRoundRobinSubchannelList* subchannel_list); + + ~Picker() override; + + PickResult Pick(PickArgs args) override; + + void Orphan() override; + + private: + // A call tracker that collects per-call endpoint utilization reports. + class SubchannelCallTracker : public SubchannelCallTrackerInterface { + public: + explicit SubchannelCallTracker(RefCountedPtr weight) + : weight_(std::move(weight)) {} + + void Start() override {} + + void Finish(FinishArgs args) override; + + private: + RefCountedPtr weight_; + }; + + // Info stored about each subchannel. 
+ struct SubchannelInfo { + SubchannelInfo(RefCountedPtr subchannel, + RefCountedPtr weight) + : subchannel(std::move(subchannel)), weight(std::move(weight)) {} + + RefCountedPtr subchannel; + RefCountedPtr weight; + }; + + // Returns the index into subchannels_ to be picked. + size_t PickIndex(); + + // Builds a new scheduler and swaps it into place, then starts a + // timer for the next update. + void BuildSchedulerAndStartTimerLocked() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&timer_mu_); + + RefCountedPtr wrr_; + const bool use_per_rpc_utilization_; + const Duration weight_update_period_; + const Duration weight_expiration_period_; + const Duration blackout_period_; + std::vector subchannels_; + + Mutex scheduler_mu_; + std::shared_ptr scheduler_ + ABSL_GUARDED_BY(&scheduler_mu_); + + Mutex timer_mu_ ABSL_ACQUIRED_BEFORE(&scheduler_mu_); + absl::optional + timer_handle_ ABSL_GUARDED_BY(&timer_mu_); + + // Used when falling back to RR. + std::atomic last_picked_index_; + }; + + ~WeightedRoundRobin() override; + + void ShutdownLocked() override; + + RefCountedPtr GetOrCreateWeight( + const grpc_resolved_address& address); + + RefCountedPtr config_; + + // List of subchannels. + RefCountedPtr subchannel_list_; + // Latest pending subchannel list. + // When we get an updated address list, we create a new subchannel list + // for it here, and we wait to swap it into subchannel_list_ until the new + // list becomes READY. + RefCountedPtr + latest_pending_subchannel_list_; + + Mutex address_weight_map_mu_; + std::map> address_weight_map_ + ABSL_GUARDED_BY(&address_weight_map_mu_); + + bool shutdown_ = false; + + absl::BitGen bit_gen_; + + // Accessed by picker. 
+ std::atomic scheduler_state_{absl::Uniform(bit_gen_)}; +}; + +// +// WeightedRoundRobin::AddressWeight +// + +WeightedRoundRobin::AddressWeight::~AddressWeight() { + MutexLock lock(&wrr_->address_weight_map_mu_); + auto it = wrr_->address_weight_map_.find(key_); + if (it != wrr_->address_weight_map_.end() && it->second == this) { + wrr_->address_weight_map_.erase(it); + } +} + +void WeightedRoundRobin::AddressWeight::MaybeUpdateWeight( + double qps, double cpu_utilization) { + // Compute weight. + float weight = 0; + if (qps > 0 && cpu_utilization > 0) weight = qps / cpu_utilization; + if (weight == 0) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, + "[WRR %p] subchannel %s: qps=%f, cpu_utilization=%f: weight=%f " + "(not updating)", + wrr_.get(), key_.c_str(), qps, cpu_utilization, weight); + } + return; + } + Timestamp now = Timestamp::Now(); + // Grab the lock and update the data. + MutexLock lock(&mu_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, + "[WRR %p] subchannel %s: qps=%f, cpu_utilization=%f: setting " + "weight=%f weight_=%f now=%s last_update_time_=%s " + "non_empty_since_=%s", + wrr_.get(), key_.c_str(), qps, cpu_utilization, weight, weight_, + now.ToString().c_str(), last_update_time_.ToString().c_str(), + non_empty_since_.ToString().c_str()); + } + if (non_empty_since_ == Timestamp::InfFuture()) non_empty_since_ = now; + weight_ = weight; + last_update_time_ = now; +} + +float WeightedRoundRobin::AddressWeight::GetWeight( + Timestamp now, Duration weight_expiration_period, + Duration blackout_period) { + MutexLock lock(&mu_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, + "[WRR %p] subchannel %s: getting weight: now=%s " + "weight_expiration_period=%s blackout_period=%s " + "last_update_time_=%s non_empty_since_=%s weight_=%f", + wrr_.get(), key_.c_str(), now.ToString().c_str(), + weight_expiration_period.ToString().c_str(), + blackout_period.ToString().c_str(), + 
last_update_time_.ToString().c_str(), + non_empty_since_.ToString().c_str(), weight_); + } + // If the most recent update was longer ago than the expiration + // period, reset non_empty_since_ so that we apply the blackout period + // again if we start getting data again in the future, and return 0. + if (now - last_update_time_ >= weight_expiration_period) { + non_empty_since_ = Timestamp::InfFuture(); + return 0; + } + // If we don't have at least blackout_period worth of data, return 0. + if (blackout_period > Duration::Zero() && + now - non_empty_since_ < blackout_period) { + return 0; + } + // Otherwise, return the weight. + return weight_; +} + +void WeightedRoundRobin::AddressWeight::ResetNonEmptySince() { + MutexLock lock(&mu_); + non_empty_since_ = Timestamp::InfFuture(); +} + +// +// WeightedRoundRobin::Picker::SubchannelCallTracker +// + +void WeightedRoundRobin::Picker::SubchannelCallTracker::Finish( + FinishArgs args) { + auto* backend_metric_data = + args.backend_metric_accessor->GetBackendMetricData(); + double qps = 0; + double cpu_utilization = 0; + if (backend_metric_data != nullptr) { + qps = backend_metric_data->qps; + cpu_utilization = backend_metric_data->cpu_utilization; + } + weight_->MaybeUpdateWeight(qps, cpu_utilization); +} + +// +// WeightedRoundRobin::Picker +// + +WeightedRoundRobin::Picker::Picker( + RefCountedPtr wrr, + WeightedRoundRobinSubchannelList* subchannel_list) + : wrr_(std::move(wrr)), + use_per_rpc_utilization_(!wrr_->config_->enable_oob_load_report()), + weight_update_period_(wrr_->config_->weight_update_period()), + weight_expiration_period_(wrr_->config_->weight_expiration_period()), + blackout_period_(wrr_->config_->blackout_period()), + last_picked_index_(absl::Uniform(wrr_->bit_gen_)) { + for (size_t i = 0; i < subchannel_list->num_subchannels(); ++i) { + WeightedRoundRobinSubchannelData* sd = subchannel_list->subchannel(i); + if (sd->connectivity_state() == GRPC_CHANNEL_READY) { + 
subchannels_.emplace_back(sd->subchannel()->Ref(), sd->weight()); + } + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, + "[WRR %p picker %p] created picker from subchannel_list=%p " + "with %" PRIuPTR " subchannels", + wrr_.get(), this, subchannel_list, subchannels_.size()); + } + BuildSchedulerAndStartTimerLocked(); +} + +WeightedRoundRobin::Picker::~Picker() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, "[WRR %p picker %p] destroying picker", wrr_.get(), this); + } +} + +void WeightedRoundRobin::Picker::Orphan() { + MutexLock lock(&timer_mu_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, "[WRR %p picker %p] cancelling timer", wrr_.get(), this); + } + wrr_->channel_control_helper()->GetEventEngine()->Cancel(*timer_handle_); + timer_handle_.reset(); +} + +WeightedRoundRobin::PickResult WeightedRoundRobin::Picker::Pick( + PickArgs /*args*/) { + size_t index = PickIndex(); + GPR_ASSERT(index < subchannels_.size()); + auto& subchannel_info = subchannels_[index]; + // Collect per-call utilization data if needed. + std::unique_ptr subchannel_call_tracker; + if (use_per_rpc_utilization_) { + subchannel_call_tracker = + std::make_unique(subchannel_info.weight); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, + "[WRR %p picker %p] returning index %" PRIuPTR ", subchannel=%p", + wrr_.get(), this, index, subchannel_info.subchannel.get()); + } + return PickResult::Complete(subchannel_info.subchannel, + std::move(subchannel_call_tracker)); +} + +size_t WeightedRoundRobin::Picker::PickIndex() { + // Grab a ref to the scheduler. + std::shared_ptr scheduler; + { + MutexLock lock(&scheduler_mu_); + scheduler = scheduler_; + } + // If we have a scheduler, use it to do a WRR pick. + if (scheduler != nullptr) return scheduler->Pick(); + // We don't have a scheduler (i.e., either all of the weights are 0 or + // there is only one subchannel), so fall back to RR. 
+ return last_picked_index_.fetch_add(1) % subchannels_.size(); +} + +void WeightedRoundRobin::Picker::BuildSchedulerAndStartTimerLocked() { + // Build scheduler. + const Timestamp now = Timestamp::Now(); + std::vector weights; + weights.reserve(subchannels_.size()); + for (const auto& subchannel : subchannels_) { + weights.push_back(subchannel.weight->GetWeight( + now, weight_expiration_period_, blackout_period_)); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, "[WRR %p picker %p] new weights: %s", wrr_.get(), this, + absl::StrJoin(weights, " ").c_str()); + } + auto scheduler_or = StaticStrideScheduler::Make( + weights, [this]() { return wrr_->scheduler_state_.fetch_add(1); }); + std::shared_ptr scheduler; + if (scheduler_or.has_value()) { + scheduler = + std::make_shared(std::move(*scheduler_or)); + } + { + MutexLock lock(&scheduler_mu_); + scheduler_ = std::move(scheduler); + } + // Start timer. + WeakRefCountedPtr self = WeakRef(); + timer_handle_ = wrr_->channel_control_helper()->GetEventEngine()->RunAfter( + weight_update_period_, [self = std::move(self)]() mutable { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + { + MutexLock lock(&self->timer_mu_); + if (self->timer_handle_.has_value()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, "[WRR %p picker %p] timer fired", + self->wrr_.get(), self.get()); + } + self->BuildSchedulerAndStartTimerLocked(); + } + } + // Release ref before ExecCtx goes out of scope. 
+ self.reset(); + }); +} + +// +// WeightedRoundRobin +// + +WeightedRoundRobin::WeightedRoundRobin(Args args) + : LoadBalancingPolicy(std::move(args)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, "[WRR %p] Created", this); + } +} + +WeightedRoundRobin::~WeightedRoundRobin() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, "[WRR %p] Destroying Round Robin policy", this); + } + GPR_ASSERT(subchannel_list_ == nullptr); + GPR_ASSERT(latest_pending_subchannel_list_ == nullptr); +} + +void WeightedRoundRobin::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, "[WRR %p] Shutting down", this); + } + shutdown_ = true; + subchannel_list_.reset(); + latest_pending_subchannel_list_.reset(); +} + +void WeightedRoundRobin::ResetBackoffLocked() { + subchannel_list_->ResetBackoffLocked(); + if (latest_pending_subchannel_list_ != nullptr) { + latest_pending_subchannel_list_->ResetBackoffLocked(); + } +} + +absl::Status WeightedRoundRobin::UpdateLocked(UpdateArgs args) { + config_ = std::move(args.config); + ServerAddressList addresses; + if (args.addresses.ok()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, "[WRR %p] received update with %" PRIuPTR " addresses", + this, args.addresses->size()); + } + // Weed out duplicate addresses. Also sort the addresses so that if + // the set of the addresses don't change, their indexes in the + // subchannel list don't change, since this avoids unnecessary churn + // in the picker. Note that this does not ensure that if a given + // address remains present that it will have the same index; if, + // for example, an address at the end of the list is replaced with one + // that sorts much earlier in the list, then all of the addresses in + // between those two positions will have changed indexes. 
+ struct AddressLessThan { + bool operator()(const ServerAddress& address1, + const ServerAddress& address2) const { + const grpc_resolved_address& addr1 = address1.address(); + const grpc_resolved_address& addr2 = address2.address(); + if (addr1.len != addr2.len) return addr1.len < addr2.len; + return memcmp(addr1.addr, addr2.addr, addr1.len) < 0; + } + }; + std::set ordered_addresses( + args.addresses->begin(), args.addresses->end()); + addresses = + ServerAddressList(ordered_addresses.begin(), ordered_addresses.end()); + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, "[WRR %p] received update with address error: %s", this, + args.addresses.status().ToString().c_str()); + } + // If we already have a subchannel list, then keep using the existing + // list, but still report back that the update was not accepted. + if (subchannel_list_ != nullptr) return args.addresses.status(); + } + // Create new subchannel list, replacing the previous pending list, if any. + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace) && + latest_pending_subchannel_list_ != nullptr) { + gpr_log(GPR_INFO, "[WRR %p] replacing previous pending subchannel list %p", + this, latest_pending_subchannel_list_.get()); + } + latest_pending_subchannel_list_ = + MakeRefCounted( + this, std::move(addresses), args.args); + latest_pending_subchannel_list_->StartWatchingLocked(); + // If the new list is empty, immediately promote it to + // subchannel_list_ and report TRANSIENT_FAILURE. + if (latest_pending_subchannel_list_->num_subchannels() == 0) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace) && + subchannel_list_ != nullptr) { + gpr_log(GPR_INFO, "[WRR %p] replacing previous subchannel list %p", this, + subchannel_list_.get()); + } + subchannel_list_ = std::move(latest_pending_subchannel_list_); + absl::Status status = + args.addresses.ok() ? 
absl::UnavailableError(absl::StrCat( + "empty address list: ", args.resolution_note)) + : args.addresses.status(); + channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + MakeRefCounted(status)); + return status; + } + // Otherwise, if this is the initial update, immediately promote it to + // subchannel_list_ and report CONNECTING. + if (subchannel_list_.get() == nullptr) { + subchannel_list_ = std::move(latest_pending_subchannel_list_); + channel_control_helper()->UpdateState( + GRPC_CHANNEL_CONNECTING, absl::Status(), + MakeRefCounted(Ref(DEBUG_LOCATION, "QueuePicker"))); + } + return absl::OkStatus(); +} + +RefCountedPtr +WeightedRoundRobin::GetOrCreateWeight(const grpc_resolved_address& address) { + auto key = grpc_sockaddr_to_uri(&address); + if (!key.ok()) return nullptr; + MutexLock lock(&address_weight_map_mu_); + auto it = address_weight_map_.find(*key); + if (it != address_weight_map_.end()) { + auto weight = it->second->RefIfNonZero(); + if (weight != nullptr) return weight; + } + auto weight = + MakeRefCounted(Ref(DEBUG_LOCATION, "AddressWeight"), *key); + address_weight_map_.emplace(*key, weight.get()); + return weight; +} + +// +// WeightedRoundRobin::WeightedRoundRobinSubchannelList +// + +void WeightedRoundRobin::WeightedRoundRobinSubchannelList:: + UpdateStateCountersLocked(absl::optional old_state, + grpc_connectivity_state new_state) { + if (old_state.has_value()) { + GPR_ASSERT(*old_state != GRPC_CHANNEL_SHUTDOWN); + if (*old_state == GRPC_CHANNEL_READY) { + GPR_ASSERT(num_ready_ > 0); + --num_ready_; + } else if (*old_state == GRPC_CHANNEL_CONNECTING) { + GPR_ASSERT(num_connecting_ > 0); + --num_connecting_; + } else if (*old_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + GPR_ASSERT(num_transient_failure_ > 0); + --num_transient_failure_; + } + } + GPR_ASSERT(new_state != GRPC_CHANNEL_SHUTDOWN); + if (new_state == GRPC_CHANNEL_READY) { + ++num_ready_; + } else if (new_state == GRPC_CHANNEL_CONNECTING) { + 
++num_connecting_; + } else if (new_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + ++num_transient_failure_; + } +} + +void WeightedRoundRobin::WeightedRoundRobinSubchannelList:: + MaybeUpdateAggregatedConnectivityStateLocked(absl::Status status_for_tf) { + WeightedRoundRobin* p = static_cast(policy()); + // If this is latest_pending_subchannel_list_, then swap it into + // subchannel_list_ in the following cases: + // - subchannel_list_ has no READY subchannels. + // - This list has at least one READY subchannel and we have seen the + // initial connectivity state notification for all subchannels. + // - All of the subchannels in this list are in TRANSIENT_FAILURE. + // (This may cause the channel to go from READY to TRANSIENT_FAILURE, + // but we're doing what the control plane told us to do.) + if (p->latest_pending_subchannel_list_.get() == this && + (p->subchannel_list_->num_ready_ == 0 || + (num_ready_ > 0 && AllSubchannelsSeenInitialState()) || + num_transient_failure_ == num_subchannels())) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + const std::string old_counters_string = + p->subchannel_list_ != nullptr ? p->subchannel_list_->CountersString() + : ""; + gpr_log( + GPR_INFO, + "[WRR %p] swapping out subchannel list %p (%s) in favor of %p (%s)", + p, p->subchannel_list_.get(), old_counters_string.c_str(), this, + CountersString().c_str()); + } + p->subchannel_list_ = std::move(p->latest_pending_subchannel_list_); + } + // Only set connectivity state if this is the current subchannel list. + if (p->subchannel_list_.get() != this) return; + // First matching rule wins: + // 1) ANY subchannel is READY => policy is READY. + // 2) ANY subchannel is CONNECTING => policy is CONNECTING. + // 3) ALL subchannels are TRANSIENT_FAILURE => policy is TRANSIENT_FAILURE. 
+ if (num_ready_ > 0) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, "[WRR %p] reporting READY with subchannel list %p", p, + this); + } + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_READY, absl::Status(), + MakeRefCounted(p->Ref(), this)); + } else if (num_connecting_ > 0) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, "[WRR %p] reporting CONNECTING with subchannel list %p", + p, this); + } + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_CONNECTING, absl::Status(), + MakeRefCounted(p->Ref(DEBUG_LOCATION, "QueuePicker"))); + } else if (num_transient_failure_ == num_subchannels()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log( + GPR_INFO, + "[WRR %p] reporting TRANSIENT_FAILURE with subchannel list %p: %s", p, + this, status_for_tf.ToString().c_str()); + } + if (!status_for_tf.ok()) { + last_failure_ = absl::UnavailableError( + absl::StrCat("connections to all backends failing; last error: ", + status_for_tf.ToString())); + } + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, last_failure_, + MakeRefCounted(last_failure_)); + } +} + +// +// WeightedRoundRobin::WeightedRoundRobinSubchannelData::OobWatcher +// + +void WeightedRoundRobin::WeightedRoundRobinSubchannelData::OobWatcher:: + OnBackendMetricReport(const BackendMetricData& backend_metric_data) { + weight_->MaybeUpdateWeight(backend_metric_data.qps, + backend_metric_data.cpu_utilization); +} + +// +// WeightedRoundRobin::WeightedRoundRobinSubchannelData +// + +WeightedRoundRobin::WeightedRoundRobinSubchannelData:: + WeightedRoundRobinSubchannelData( + SubchannelList* subchannel_list, + const ServerAddress& address, RefCountedPtr sc) + : SubchannelData(subchannel_list, address, std::move(sc)), + weight_(static_cast(subchannel_list->policy()) + ->GetOrCreateWeight(address.address())) { + // Start OOB watch if configured. 
+ WeightedRoundRobin* p = + static_cast(subchannel_list->policy()); + if (p->config_->enable_oob_load_report()) { + subchannel()->AddDataWatcher( + MakeOobBackendMetricWatcher(p->config_->oob_reporting_period(), + std::make_unique(weight_))); + } +} + +void WeightedRoundRobin::WeightedRoundRobinSubchannelData:: + ProcessConnectivityChangeLocked( + absl::optional old_state, + grpc_connectivity_state new_state) { + WeightedRoundRobin* p = + static_cast(subchannel_list()->policy()); + GPR_ASSERT(subchannel() != nullptr); + // If this is not the initial state notification and the new state is + // TRANSIENT_FAILURE or IDLE, re-resolve. + // Note that we don't want to do this on the initial state notification, + // because that would result in an endless loop of re-resolution. + if (old_state.has_value() && (new_state == GRPC_CHANNEL_TRANSIENT_FAILURE || + new_state == GRPC_CHANNEL_IDLE)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, + "[WRR %p] Subchannel %p reported %s; requesting re-resolution", p, + subchannel(), ConnectivityStateName(new_state)); + } + p->channel_control_helper()->RequestReresolution(); + } + if (new_state == GRPC_CHANNEL_IDLE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, + "[WRR %p] Subchannel %p reported IDLE; requesting connection", p, + subchannel()); + } + subchannel()->RequestConnection(); + } else if (new_state == GRPC_CHANNEL_READY) { + // If we transition back to READY state, restart the blackout period. + // Note that we cannot guarantee that we will never receive + // lingering callbacks for backend metric reports from the previous + // connection after the new connection has been established, but they + // should be masked by new backend metric reports from the new + // connection by the time the blackout period ends. + weight_->ResetNonEmptySince(); + } + // Update logical connectivity state. + UpdateLogicalConnectivityStateLocked(new_state); + // Update the policy state. 
+ subchannel_list()->MaybeUpdateAggregatedConnectivityStateLocked( + connectivity_status()); +} + +void WeightedRoundRobin::WeightedRoundRobinSubchannelData:: + UpdateLogicalConnectivityStateLocked( + grpc_connectivity_state connectivity_state) { + WeightedRoundRobin* p = + static_cast(subchannel_list()->policy()); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log( + GPR_INFO, + "[WRR %p] connectivity changed for subchannel %p, subchannel_list %p " + "(index %" PRIuPTR " of %" PRIuPTR "): prev_state=%s new_state=%s", + p, subchannel(), subchannel_list(), Index(), + subchannel_list()->num_subchannels(), + (logical_connectivity_state_.has_value() + ? ConnectivityStateName(*logical_connectivity_state_) + : "N/A"), + ConnectivityStateName(connectivity_state)); + } + // Decide what state to report for aggregation purposes. + // If the last logical state was TRANSIENT_FAILURE, then ignore the + // state change unless the new state is READY. + if (logical_connectivity_state_.has_value() && + *logical_connectivity_state_ == GRPC_CHANNEL_TRANSIENT_FAILURE && + connectivity_state != GRPC_CHANNEL_READY) { + return; + } + // If the new state is IDLE, treat it as CONNECTING, since it will + // immediately transition into CONNECTING anyway. + if (connectivity_state == GRPC_CHANNEL_IDLE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) { + gpr_log(GPR_INFO, + "[WRR %p] subchannel %p, subchannel_list %p (index %" PRIuPTR + " of %" PRIuPTR "): treating IDLE as CONNECTING", + p, subchannel(), subchannel_list(), Index(), + subchannel_list()->num_subchannels()); + } + connectivity_state = GRPC_CHANNEL_CONNECTING; + } + // If no change, there is nothing to do. + if (logical_connectivity_state_.has_value() && + *logical_connectivity_state_ == connectivity_state) { + return; + } + // Otherwise, update counters and logical state. 
+ subchannel_list()->UpdateStateCountersLocked(logical_connectivity_state_, + connectivity_state); + logical_connectivity_state_ = connectivity_state; +} + +// +// factory +// + +class WeightedRoundRobinFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + absl::string_view name() const override { return kWeightedRoundRobin; } + + absl::StatusOr> + ParseLoadBalancingConfig(const Json& json) const override { + if (json.type() == Json::Type::JSON_NULL) { + // priority was mentioned as a policy in the deprecated + // loadBalancingPolicy field or in the client API. + return absl::InvalidArgumentError( + "field:loadBalancingPolicy error:priority policy requires " + "configuration. Please use loadBalancingConfig field of service " + "config instead."); + } + return LoadRefCountedFromJson( + json, JsonArgs(), "errors validating priority LB policy config"); + } +}; + +} // namespace + +void RegisterWeightedRoundRobinLbPolicy(CoreConfiguration::Builder* builder) { + builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory( + std::make_unique()); +} + +} // namespace grpc_core diff --git a/src/core/lib/load_balancing/lb_policy.cc b/src/core/lib/load_balancing/lb_policy.cc index 0916fa0fa04..229a201544d 100644 --- a/src/core/lib/load_balancing/lb_policy.cc +++ b/src/core/lib/load_balancing/lb_policy.cc @@ -53,6 +53,15 @@ void LoadBalancingPolicy::Orphan() { Unref(DEBUG_LOCATION, "Orphan"); } +// +// LoadBalancingPolicy::SubchannelPicker +// + +LoadBalancingPolicy::SubchannelPicker::SubchannelPicker() + : DualRefCounted(GRPC_TRACE_FLAG_ENABLED(grpc_trace_lb_policy_refcount) + ? 
"SubchannelPicker" + : nullptr) {} + // // LoadBalancingPolicy::QueuePicker // diff --git a/src/core/lib/load_balancing/lb_policy.h b/src/core/lib/load_balancing/lb_policy.h index 6ed839bd225..c504c42da32 100644 --- a/src/core/lib/load_balancing/lb_policy.h +++ b/src/core/lib/load_balancing/lb_policy.h @@ -40,6 +40,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/debug/trace.h" #include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/dual_ref_counted.h" #include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" @@ -176,6 +177,7 @@ class LoadBalancingPolicy : public InternallyRefCounted { /// implementation does not take ownership, so any data that needs to be /// used after returning must be copied. struct FinishArgs { + absl::string_view peer_address; absl::Status status; MetadataInterface* trailing_metadata; BackendMetricAccessor* backend_metric_accessor; @@ -256,11 +258,13 @@ class LoadBalancingPolicy : public InternallyRefCounted { /// Currently, pickers are always accessed from within the /// client_channel data plane mutex, so they do not have to be /// thread-safe. 
- class SubchannelPicker : public RefCounted { + class SubchannelPicker : public DualRefCounted { public: - SubchannelPicker() = default; + SubchannelPicker(); virtual PickResult Pick(PickArgs args) = 0; + + void Orphan() override {} }; /// A proxy object implemented by the client channel and used by the diff --git a/src/core/plugin_registry/grpc_plugin_registry.cc b/src/core/plugin_registry/grpc_plugin_registry.cc index f8cd6de28c9..84aa9c09372 100644 --- a/src/core/plugin_registry/grpc_plugin_registry.cc +++ b/src/core/plugin_registry/grpc_plugin_registry.cc @@ -60,6 +60,8 @@ extern void RegisterOutlierDetectionLbPolicy( extern void RegisterWeightedTargetLbPolicy(CoreConfiguration::Builder* builder); extern void RegisterPickFirstLbPolicy(CoreConfiguration::Builder* builder); extern void RegisterRoundRobinLbPolicy(CoreConfiguration::Builder* builder); +extern void RegisterWeightedRoundRobinLbPolicy( + CoreConfiguration::Builder* builder); extern void RegisterRingHashLbPolicy(CoreConfiguration::Builder* builder); extern void RegisterHttpProxyMapper(CoreConfiguration::Builder* builder); #ifndef GRPC_NO_RLS @@ -82,6 +84,7 @@ void BuildCoreConfiguration(CoreConfiguration::Builder* builder) { RegisterWeightedTargetLbPolicy(builder); RegisterPickFirstLbPolicy(builder); RegisterRoundRobinLbPolicy(builder); + RegisterWeightedRoundRobinLbPolicy(builder); RegisterRingHashLbPolicy(builder); BuildClientChannelConfiguration(builder); SecurityRegisterHandshakerFactories(builder); diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py index 8c725b1931e..652e4c2d118 100644 --- a/src/python/grpcio/grpc_core_dependencies.py +++ b/src/python/grpcio/grpc_core_dependencies.py @@ -45,6 +45,8 @@ CORE_SOURCE_FILES = [ 'src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc', 'src/core/ext/filters/client_channel/lb_policy/rls/rls.cc', 'src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc', + 
'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc', + 'src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc', 'src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc', 'src/core/ext/filters/client_channel/lb_policy/xds/cds.cc', 'src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.cc', diff --git a/test/core/client_channel/lb_policy/BUILD b/test/core/client_channel/lb_policy/BUILD index a6dbc918529..50e31e5de3c 100644 --- a/test/core/client_channel/lb_policy/BUILD +++ b/test/core/client_channel/lb_policy/BUILD @@ -179,3 +179,33 @@ grpc_cc_test( "//src/core:static_stride_scheduler", ], ) + +grpc_cc_test( + name = "weighted_round_robin_config_test", + srcs = ["weighted_round_robin_config_test.cc"], + external_deps = [ + "gtest", + ], + language = "C++", + tags = ["no_test_ios"], + uses_event_engine = False, + uses_polling = False, + deps = [ + "//:grpc", + "//test/core/util:grpc_test_util", + ], +) + +grpc_cc_test( + name = "weighted_round_robin_test", + srcs = ["weighted_round_robin_test.cc"], + external_deps = ["gtest"], + language = "C++", + uses_polling = False, + deps = [ + ":lb_policy_test_lib", + "//src/core:grpc_lb_policy_weighted_round_robin", + "//test/core/event_engine:mock_event_engine", + "//test/core/util:grpc_test_util", + ], +) diff --git a/test/core/client_channel/lb_policy/lb_policy_test_lib.h b/test/core/client_channel/lb_policy/lb_policy_test_lib.h index 1d239cd2bf7..e8611a4690d 100644 --- a/test/core/client_channel/lb_policy/lb_policy_test_lib.h +++ b/test/core/client_channel/lb_policy/lb_policy_test_lib.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -50,6 +51,9 @@ #include #include "src/core/ext/filters/client_channel/lb_call_state_internal.h" +#include "src/core/ext/filters/client_channel/lb_policy/backend_metric_data.h" +#include 
"src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h" +#include "src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h" #include "src/core/ext/filters/client_channel/subchannel_pool_interface.h" #include "src/core/lib/address_utils/parse_address.h" #include "src/core/lib/address_utils/sockaddr_utils.h" @@ -61,6 +65,7 @@ #include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/gprpp/time.h" #include "src/core/lib/gprpp/unique_type_name.h" #include "src/core/lib/gprpp/work_serializer.h" #include "src/core/lib/iomgr/exec_ctx.h" @@ -92,6 +97,13 @@ class LoadBalancingPolicyTest : public ::testing::Test { std::shared_ptr work_serializer) : state_(state), work_serializer_(std::move(work_serializer)) {} + ~FakeSubchannel() override { + if (orca_watcher_ != nullptr) { + MutexLock lock(&state_->backend_metric_watcher_mu_); + state_->watchers_.erase(orca_watcher_.get()); + } + } + SubchannelState* state() const { return state_; } private: @@ -145,15 +157,23 @@ class LoadBalancingPolicyTest : public ::testing::Test { state_->requested_connection_ = true; } - // Don't need these methods here, so they're no-ops. + void AddDataWatcher( + std::unique_ptr watcher) override { + MutexLock lock(&state_->backend_metric_watcher_mu_); + GPR_ASSERT(orca_watcher_ == nullptr); + orca_watcher_.reset(static_cast(watcher.release())); + state_->watchers_.insert(orca_watcher_.get()); + } + + // Don't need this method, so it's a no-op. 
void ResetBackoff() override {} - void AddDataWatcher(std::unique_ptr) override {} SubchannelState* state_; std::shared_ptr work_serializer_; std::map watcher_map_; + std::unique_ptr orca_watcher_; }; explicit SubchannelState(absl::string_view address) @@ -232,13 +252,37 @@ class LoadBalancingPolicyTest : public ::testing::Test { return MakeRefCounted(this, std::move(work_serializer)); } + // Sends an OOB backend metric report to all watchers. + void SendOobBackendMetricReport(const BackendMetricData& backend_metrics) { + MutexLock lock(&backend_metric_watcher_mu_); + for (const auto* watcher : watchers_) { + watcher->watcher()->OnBackendMetricReport(backend_metrics); + } + } + + // Checks that all OOB watchers have the expected reporting period. + void CheckOobReportingPeriod(Duration expected, + SourceLocation location = SourceLocation()) { + MutexLock lock(&backend_metric_watcher_mu_); + for (const auto* watcher : watchers_) { + EXPECT_EQ(watcher->report_interval(), expected) + << location.file() << ":" << location.line(); + } + } + private: const std::string address_; + Mutex mu_; ConnectivityStateTracker state_tracker_ ABSL_GUARDED_BY(&mu_); + Mutex requested_connection_mu_; bool requested_connection_ ABSL_GUARDED_BY(&requested_connection_mu_) = false; + + Mutex backend_metric_watcher_mu_; + std::set watchers_ + ABSL_GUARDED_BY(&backend_metric_watcher_mu_); }; // A fake helper to be passed to the LB policy. 
@@ -263,8 +307,17 @@ class LoadBalancingPolicyTest : public ::testing::Test { }; FakeHelper(LoadBalancingPolicyTest* test, - std::shared_ptr work_serializer) - : test_(test), work_serializer_(std::move(work_serializer)) {} + std::shared_ptr work_serializer, + std::shared_ptr + event_engine) + : test_(test), + work_serializer_(std::move(work_serializer)), + event_engine_(std::move(event_engine)) {} + + bool QueueEmpty() { + MutexLock lock(&mu_); + return queue_.empty(); + } // Called at test tear-down time to ensure that we have not left any // unexpected events in the queue. @@ -293,6 +346,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { << location.file() << ":" << location.line(); if (update == nullptr) return absl::nullopt; StateUpdate result = std::move(*update); + gpr_log(GPR_INFO, "got next state update: %s", result.ToString().c_str()); queue_.pop_front(); return std::move(result); } @@ -349,7 +403,10 @@ class LoadBalancingPolicyTest : public ::testing::Test { grpc_connectivity_state state, const absl::Status& status, RefCountedPtr picker) override { MutexLock lock(&mu_); - queue_.push_back(StateUpdate{state, status, std::move(picker)}); + StateUpdate update{state, status, std::move(picker)}; + gpr_log(GPR_INFO, "state update from LB policy: %s", + update.ToString().c_str()); + queue_.push_back(std::move(update)); } void RequestReresolution() override { @@ -360,13 +417,15 @@ class LoadBalancingPolicyTest : public ::testing::Test { absl::string_view GetAuthority() override { return "server.example.com"; } grpc_event_engine::experimental::EventEngine* GetEventEngine() override { - return grpc_event_engine::experimental::GetDefaultEventEngine().get(); + return event_engine_.get(); } void AddTraceEvent(TraceSeverity, absl::string_view) override {} LoadBalancingPolicyTest* test_; std::shared_ptr work_serializer_; + std::shared_ptr event_engine_; + Mutex mu_; std::deque queue_ ABSL_GUARDED_BY(&mu_); }; @@ -428,6 +487,24 @@ class LoadBalancingPolicyTest 
: public ::testing::Test { std::map attributes_; }; + // A fake BackendMetricAccessor implementation, for passing to + // SubchannelCallTrackerInterface::Finish(). + class FakeBackendMetricAccessor + : public LoadBalancingPolicy::BackendMetricAccessor { + public: + explicit FakeBackendMetricAccessor( + absl::optional backend_metric_data) + : backend_metric_data_(std::move(backend_metric_data)) {} + + const BackendMetricData* GetBackendMetricData() override { + if (backend_metric_data_.has_value()) return &*backend_metric_data_; + return nullptr; + } + + private: + const absl::optional backend_metric_data_; + }; + LoadBalancingPolicyTest() : work_serializer_(std::make_shared()) {} @@ -445,7 +522,8 @@ class LoadBalancingPolicyTest : public ::testing::Test { // Creates a new FakeHelper for the new LB policy, and sets helper_ to // point to the FakeHelper. OrphanablePtr MakeLbPolicy(absl::string_view name) { - auto helper = std::make_unique(this, work_serializer_); + auto helper = + std::make_unique(this, work_serializer_, event_engine_); helper_ = helper.get(); LoadBalancingPolicy::Args args = {work_serializer_, std::move(helper), ChannelArgs()}; @@ -514,10 +592,17 @@ class LoadBalancingPolicyTest : public ::testing::Test { bool WaitForStateUpdate( std::function continue_predicate, SourceLocation location = SourceLocation()) { + gpr_log(GPR_INFO, "==> WaitForStateUpdate()"); while (true) { auto update = helper_->GetNextStateUpdate(location); - if (!update.has_value()) return false; - if (!continue_predicate(std::move(*update))) return true; + if (!update.has_value()) { + gpr_log(GPR_INFO, "WaitForStateUpdate() returning false"); + return false; + } + if (!continue_predicate(std::move(*update))) { + gpr_log(GPR_INFO, "WaitForStateUpdate() returning true"); + return true; + } } } @@ -552,6 +637,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { // update for state READY, whose picker is returned. 
RefCountedPtr WaitForConnected( SourceLocation location = SourceLocation()) { + gpr_log(GPR_INFO, "==> WaitForConnected()"); RefCountedPtr final_picker; WaitForStateUpdate( [&](FakeHelper::StateUpdate update) { @@ -611,7 +697,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { WaitForRoundRobinListChange( absl::Span old_addresses, absl::Span new_addresses, - const std::map call_attributes = {}, + const std::map& call_attributes = {}, size_t num_iterations = 3, SourceLocation location = SourceLocation()) { gpr_log(GPR_INFO, "Waiting for expected RR addresses..."); RefCountedPtr retval; @@ -625,7 +711,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { // Get enough picks to round-robin num_iterations times across all // expected addresses. auto picks = GetCompletePicks(update.picker.get(), num_picks, - call_attributes, location); + call_attributes, nullptr, location); EXPECT_TRUE(picks.has_value()) << location.file() << ":" << location.line(); if (!picks.has_value()) return false; @@ -681,7 +767,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { // Requests a pick on picker and expects a Queue result. void ExpectPickQueued( LoadBalancingPolicy::SubchannelPicker* picker, - const std::map call_attributes = {}, + const std::map& call_attributes = {}, SourceLocation location = SourceLocation()) { ASSERT_NE(picker, nullptr); auto pick_result = DoPick(picker, call_attributes); @@ -694,9 +780,15 @@ class LoadBalancingPolicyTest : public ::testing::Test { // Requests a pick on picker and expects a Complete result. // The address of the resulting subchannel is returned, or nullopt if // the result was something other than Complete. + // If the complete pick includes a SubchannelCallTrackerInterface, then if + // subchannel_call_tracker is non-null, it will be set to point to the + // call tracker; otherwise, the call tracker will be invoked + // automatically to represent a complete call with no backend metric data. 
absl::optional ExpectPickComplete( LoadBalancingPolicy::SubchannelPicker* picker, - const std::map call_attributes = {}, + const std::map& call_attributes = {}, + std::unique_ptr* + subchannel_call_tracker = nullptr, SourceLocation location = SourceLocation()) { auto pick_result = DoPick(picker, call_attributes); auto* complete = absl::get_if( @@ -706,14 +798,30 @@ class LoadBalancingPolicyTest : public ::testing::Test { if (complete == nullptr) return absl::nullopt; auto* subchannel = static_cast( complete->subchannel.get()); - return subchannel->state()->address(); + std::string address = subchannel->state()->address(); + if (complete->subchannel_call_tracker != nullptr) { + if (subchannel_call_tracker != nullptr) { + *subchannel_call_tracker = std::move(complete->subchannel_call_tracker); + } else { + complete->subchannel_call_tracker->Start(); + FakeMetadata metadata({}); + FakeBackendMetricAccessor backend_metric_accessor({}); + LoadBalancingPolicy::SubchannelCallTrackerInterface::FinishArgs args = { + address, absl::OkStatus(), &metadata, &backend_metric_accessor}; + complete->subchannel_call_tracker->Finish(args); + } + } + return address; } // Gets num_picks complete picks from picker and returns the resulting // list of addresses, or nullopt if a non-complete pick was returned. 
absl::optional> GetCompletePicks( LoadBalancingPolicy::SubchannelPicker* picker, size_t num_picks, - const std::map call_attributes = {}, + const std::map& call_attributes = {}, + std::vector< + std::unique_ptr>* + subchannel_call_trackers = nullptr, SourceLocation location = SourceLocation()) { EXPECT_NE(picker, nullptr); if (picker == nullptr) { @@ -721,9 +829,19 @@ class LoadBalancingPolicyTest : public ::testing::Test { } std::vector results; for (size_t i = 0; i < num_picks; ++i) { - auto address = ExpectPickComplete(picker, call_attributes, location); + std::unique_ptr + subchannel_call_tracker; + auto address = ExpectPickComplete(picker, call_attributes, + subchannel_call_trackers == nullptr + ? nullptr + : &subchannel_call_tracker, + location); if (!address.has_value()) return absl::nullopt; results.emplace_back(std::move(*address)); + if (subchannel_call_trackers != nullptr) { + subchannel_call_trackers->emplace_back( + std::move(subchannel_call_tracker)); + } } return results; } @@ -750,10 +868,10 @@ class LoadBalancingPolicyTest : public ::testing::Test { void ExpectRoundRobinPicks( LoadBalancingPolicy::SubchannelPicker* picker, absl::Span addresses, - const std::map call_attributes = {}, + const std::map& call_attributes = {}, size_t num_iterations = 3, SourceLocation location = SourceLocation()) { auto picks = GetCompletePicks(picker, num_iterations * addresses.size(), - call_attributes, location); + call_attributes, nullptr, location); ASSERT_TRUE(picks.has_value()) << location.file() << ":" << location.line(); EXPECT_TRUE(PicksAreRoundRobin(addresses, *picks)) << " Actual: " << absl::StrJoin(*picks, ", ") @@ -823,6 +941,8 @@ class LoadBalancingPolicyTest : public ::testing::Test { } std::shared_ptr work_serializer_; + std::shared_ptr event_engine_ = + grpc_event_engine::experimental::GetDefaultEventEngine(); FakeHelper* helper_ = nullptr; std::map subchannel_pool_; }; diff --git a/test/core/client_channel/lb_policy/weighted_round_robin_config_test.cc 
b/test/core/client_channel/lb_policy/weighted_round_robin_config_test.cc new file mode 100644 index 00000000000..5e9ec642eb9 --- /dev/null +++ b/test/core/client_channel/lb_policy/weighted_round_robin_config_test.cc @@ -0,0 +1,85 @@ +// +// Copyright 2022 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "gtest/gtest.h" + +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/service_config/service_config.h" +#include "src/core/lib/service_config/service_config_impl.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +TEST(WeightedRoundRobinConfigTest, EmptyConfig) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"weighted_round_robin_experimental\":{\n" + " }\n" + " }]\n" + "}\n"; + auto service_config = + ServiceConfigImpl::Create(ChannelArgs(), service_config_json); + ASSERT_TRUE(service_config.ok()); + EXPECT_NE(*service_config, nullptr); +} + +TEST(WeightedRoundRobinConfigTest, InvalidTypes) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"weighted_round_robin_experimental\":{\n" + " \"enableOobLoadReport\": 5,\n" + " \"oobReportingPeriod\": true,\n" + " \"blackoutPeriod\": [],\n" + " \"weightUpdatePeriod\": {},\n" + " \"weightExpirationPeriod\": {}\n" + " 
}\n" + " }]\n" + "}\n"; + auto service_config = + ServiceConfigImpl::Create(ChannelArgs(), service_config_json); + ASSERT_FALSE(service_config.ok()); + EXPECT_EQ(service_config.status(), + absl::InvalidArgumentError( + "errors validating service config: [field:loadBalancingConfig " + "error:errors validating priority LB policy config: [" + "field:blackoutPeriod error:is not a string; " + "field:enableOobLoadReport error:is not a boolean; " + "field:oobReportingPeriod error:is not a string; " + "field:weightExpirationPeriod error:is not a string; " + "field:weightUpdatePeriod error:is not a string]]")); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(&argc, argv); + grpc_init(); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/client_channel/lb_policy/weighted_round_robin_test.cc b/test/core/client_channel/lb_policy/weighted_round_robin_test.cc new file mode 100644 index 00000000000..2b631872b22 --- /dev/null +++ b/test/core/client_channel/lb_policy/weighted_round_robin_test.cc @@ -0,0 +1,626 @@ +// +// Copyright 2022 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy/backend_metric_data.h" +#include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/time.h" +#include "src/core/lib/gprpp/unique_type_name.h" +#include "src/core/lib/json/json.h" +#include "src/core/lib/load_balancing/lb_policy.h" +#include "test/core/client_channel/lb_policy/lb_policy_test_lib.h" +#include "test/core/event_engine/mock_event_engine.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +using ::grpc_event_engine::experimental::EventEngine; +using ::grpc_event_engine::experimental::MockEventEngine; + +class WeightedRoundRobinTest : public LoadBalancingPolicyTest { + protected: + class ConfigBuilder { + public: + ConfigBuilder() { + // Set blackout period to 1s to make tests fast and deterministic. 
+ SetBlackoutPeriod(Duration::Seconds(1)); + } + + ConfigBuilder& SetEnableOobLoadReport(bool value) { + json_["enableOobLoadReport"] = value; + return *this; + } + ConfigBuilder& SetOobReportingPeriod(Duration duration) { + json_["oobReportingPeriod"] = duration.ToJsonString(); + return *this; + } + ConfigBuilder& SetBlackoutPeriod(Duration duration) { + json_["blackoutPeriod"] = duration.ToJsonString(); + return *this; + } + ConfigBuilder& SetWeightUpdatePeriod(Duration duration) { + json_["weightUpdatePeriod"] = duration.ToJsonString(); + return *this; + } + ConfigBuilder& SetWeightExpirationPeriod(Duration duration) { + json_["weightExpirationPeriod"] = duration.ToJsonString(); + return *this; + } + + RefCountedPtr Build() { + Json config = Json::Array{ + Json::Object{{"weighted_round_robin_experimental", json_}}}; + gpr_log(GPR_INFO, "CONFIG: %s", config.Dump().c_str()); + return MakeConfig(config); + } + + private: + Json::Object json_; + }; + + // A custom time cache for which InvalidateCache() is a no-op. This + // ensures that when the timer callback instantiates its own ExecCtx + // and therefore its own ScopedTimeCache, it continues to see the time + // that we are injecting in the test. 
+ class TestTimeCache final : public Timestamp::ScopedSource { + public: + TestTimeCache() : cached_time_(previous()->Now()) {} + + Timestamp Now() override { return cached_time_; } + void InvalidateCache() override {} + + void IncrementBy(Duration duration) { cached_time_ += duration; } + + private: + Timestamp cached_time_; + }; + + WeightedRoundRobinTest() { + mock_ee_ = std::make_shared(); + event_engine_ = mock_ee_; + auto capture = [this](std::chrono::duration duration, + absl::AnyInvocable callback) { + EXPECT_EQ(duration, expected_weight_update_interval_) + << "Expected: " << expected_weight_update_interval_.count() << "ns" + << "\n Actual: " << duration.count() << "ns"; + intptr_t key = next_key_++; + timer_callbacks_[key] = std::move(callback); + return EventEngine::TaskHandle{key, 0}; + }; + ON_CALL(*mock_ee_, + RunAfter(::testing::_, ::testing::A>())) + .WillByDefault(capture); + auto cancel = [this](EventEngine::TaskHandle handle) { + auto it = timer_callbacks_.find(handle.keys[0]); + if (it == timer_callbacks_.end()) return false; + timer_callbacks_.erase(it); + return true; + }; + ON_CALL(*mock_ee_, Cancel(::testing::_)).WillByDefault(cancel); + lb_policy_ = MakeLbPolicy("weighted_round_robin_experimental"); + } + + ~WeightedRoundRobinTest() override { + EXPECT_TRUE(timer_callbacks_.empty()) + << "WARNING: Test did not run all timer callbacks"; + } + + void RunTimerCallback() { + ASSERT_EQ(timer_callbacks_.size(), 1UL); + auto it = timer_callbacks_.begin(); + ASSERT_NE(it->second, nullptr); + std::move(it->second)(); + timer_callbacks_.erase(it); + } + + RefCountedPtr + SendInitialUpdateAndWaitForConnected( + absl::Span addresses, + ConfigBuilder config_builder = ConfigBuilder(), + absl::Span update_addresses = {}, + SourceLocation location = SourceLocation()) { + if (update_addresses.empty()) update_addresses = addresses; + EXPECT_EQ(ApplyUpdate(BuildUpdate(update_addresses, config_builder.Build()), + lb_policy_.get()), + absl::OkStatus()); + // 
Expect the initial CONNECTING update with a picker that queues. + ExpectConnectingUpdate(location); + // RR should have created a subchannel for each address. + for (size_t i = 0; i < addresses.size(); ++i) { + auto* subchannel = FindSubchannel(addresses[i]); + EXPECT_NE(subchannel, nullptr) + << addresses[i] << " at " << location.file() << ":" + << location.line(); + if (subchannel == nullptr) return nullptr; + // RR should ask each subchannel to connect. + EXPECT_TRUE(subchannel->ConnectionRequested()) + << addresses[i] << " at " << location.file() << ":" + << location.line(); + // The subchannel will connect successfully. + subchannel->SetConnectivityState(GRPC_CHANNEL_CONNECTING); + subchannel->SetConnectivityState(GRPC_CHANNEL_READY); + } + return WaitForConnected(location); + } + + // Returns a map indicating the number of picks for each address. + static std::map MakePickMap( + absl::Span picks) { + std::map actual; + for (const auto& address : picks) { + ++actual.emplace(address, 0).first->second; + } + return actual; + } + + // Returns a human-readable string representing the number of picks + // for each address. + static std::string PickMapString( + std::map pick_map) { + return absl::StrJoin(pick_map, ",", absl::PairFormatter("=")); + } + + // Returns the number of picks we need to do to check the specified + // expectations. + static size_t NumPicksNeeded(const std::map& expected) { + size_t num_picks = 0; + for (const auto& p : expected) { + num_picks += p.second; + } + return num_picks; + } + + // For each pick in picks, reports the backend metrics to the LB policy. 
+ static void ReportBackendMetrics( + absl::Span picks, + const std::vector< + std::unique_ptr>& + subchannel_call_trackers, + const std::map>& + backend_metrics) { + for (size_t i = 0; i < picks.size(); ++i) { + const auto& address = picks[i]; + auto& subchannel_call_tracker = subchannel_call_trackers[i]; + if (subchannel_call_tracker != nullptr) { + subchannel_call_tracker->Start(); + absl::optional backend_metric_data; + auto it = backend_metrics.find(address); + if (it != backend_metrics.end()) { + backend_metric_data.emplace(); + backend_metric_data->qps = it->second.first; + backend_metric_data->cpu_utilization = it->second.second; + } + FakeMetadata metadata({}); + FakeBackendMetricAccessor backend_metric_accessor( + std::move(backend_metric_data)); + LoadBalancingPolicy::SubchannelCallTrackerInterface::FinishArgs args = { + address, absl::OkStatus(), &metadata, &backend_metric_accessor}; + subchannel_call_tracker->Finish(args); + } + } + } + + void ReportOobBackendMetrics( + std::map> + backend_metrics) { + for (const auto& p : backend_metrics) { + auto* subchannel = FindSubchannel(p.first); + BackendMetricData backend_metric_data; + backend_metric_data.qps = p.second.first; + backend_metric_data.cpu_utilization = p.second.second; + subchannel->SendOobBackendMetricReport(backend_metric_data); + } + } + + void ExpectWeightedRoundRobinPicks( + LoadBalancingPolicy::SubchannelPicker* picker, + std::map> + backend_metrics, + std::map expected, + SourceLocation location = SourceLocation()) { + std::vector< + std::unique_ptr> + subchannel_call_trackers; + auto picks = GetCompletePicks(picker, NumPicksNeeded(expected), {}, + &subchannel_call_trackers, location); + ASSERT_TRUE(picks.has_value()) << location.file() << ":" << location.line(); + gpr_log(GPR_INFO, "PICKS: %s", absl::StrJoin(*picks, " ").c_str()); + ReportBackendMetrics(*picks, subchannel_call_trackers, backend_metrics); + auto actual = MakePickMap(*picks); + gpr_log(GPR_INFO, "Pick map: %s", 
PickMapString(actual).c_str()); + EXPECT_EQ(expected, actual) + << "Expected: " << PickMapString(expected) + << "\nActual: " << PickMapString(actual) << "\nat " << location.file() + << ":" << location.line(); + } + + bool WaitForWeightedRoundRobinPicks( + RefCountedPtr* picker, + std::map> + backend_metrics, + std::map expected, + absl::Duration timeout = absl::Seconds(5), + SourceLocation location = SourceLocation()) { + gpr_log(GPR_INFO, "==> WaitForWeightedRoundRobinPicks(): Expecting %s", + PickMapString(expected).c_str()); + size_t num_picks = NumPicksNeeded(expected); + absl::Time deadline = absl::Now() + timeout; + while (true) { + gpr_log(GPR_INFO, "TOP OF LOOP"); + // We need to see the expected weights for 3 consecutive passes, just + // to make sure we're consistently returning the right weights. + size_t num_passes = 0; + for (; num_passes < 3; ++num_passes) { + gpr_log(GPR_INFO, "PASS %" PRIuPTR ": DOING PICKS", num_passes); + std::vector> + subchannel_call_trackers; + auto picks = GetCompletePicks(picker->get(), num_picks, {}, + &subchannel_call_trackers, location); + EXPECT_TRUE(picks.has_value()) + << location.file() << ":" << location.line(); + if (!picks.has_value()) return false; + gpr_log(GPR_INFO, "PICKS: %s", absl::StrJoin(*picks, " ").c_str()); + // Report backend metrics to the LB policy. + ReportBackendMetrics(*picks, subchannel_call_trackers, backend_metrics); + // Check the observed weights. + auto actual = MakePickMap(*picks); + gpr_log(GPR_INFO, "Pick map:\nExpected: %s\n Actual: %s", + PickMapString(expected).c_str(), PickMapString(actual).c_str()); + if (expected != actual) { + // Make sure each address is one of the expected addresses, + // even if the weights aren't as expected. 
+ for (const auto& address : *picks) { + bool found = expected.find(address) != expected.end(); + EXPECT_TRUE(found) + << "unexpected pick address " << address << " at " + << location.file() << ":" << location.line(); + if (!found) return false; + } + break; + } + // If there's another picker update in the queue, don't bother + // doing another pass, since we want to make sure we're using + // the latest picker. + if (!helper_->QueueEmpty()) break; + } + if (num_passes == 3) return true; + // If we're out of time, give up. + absl::Time now = absl::Now(); + EXPECT_LT(now, deadline) << location.file() << ":" << location.line(); + if (now >= deadline) return false; + // Get a new picker if there is an update; otherwise, wait for the + // weights to be recalculated. + if (!helper_->QueueEmpty()) { + *picker = ExpectState(GRPC_CHANNEL_READY, absl::OkStatus(), location); + EXPECT_NE(*picker, nullptr) + << location.file() << ":" << location.line(); + if (*picker == nullptr) return false; + } else { + gpr_log(GPR_INFO, "running timer callback..."); + RunTimerCallback(); + } + // Increment time. + time_cache_.IncrementBy(Duration::Seconds(1)); + } + } + + OrphanablePtr lb_policy_; + std::shared_ptr mock_ee_; + std::map> timer_callbacks_; + intptr_t next_key_ = 1; + EventEngine::Duration expected_weight_update_interval_ = + std::chrono::seconds(1); + TestTimeCache time_cache_; +}; + +TEST_F(WeightedRoundRobinTest, Basic) { + // Send address list to LB policy. + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + auto picker = SendInitialUpdateAndWaitForConnected(kAddresses); + ASSERT_NE(picker, nullptr); + // Address 0 gets weight 1, address 1 gets weight 3. + // No utilization report from backend 2, so it gets the average weight 2. 
+ WaitForWeightedRoundRobinPicks( + &picker, {{kAddresses[0], {100, 0.9}}, {kAddresses[1], {100, 0.3}}}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 2}}); + // Now have backend 2 report utilization the same as backend 1, so its + // weight will be the same. + WaitForWeightedRoundRobinPicks( + &picker, + {{kAddresses[0], {100, 0.9}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.3}}}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 3}}); +} + +TEST_F(WeightedRoundRobinTest, IgnoresDuplicateAddresses) { + // Send address list to LB policy. + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + const std::array kUpdateAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443", + "ipv4:127.0.0.1:441"}; + auto picker = SendInitialUpdateAndWaitForConnected( + kAddresses, ConfigBuilder(), kUpdateAddresses); + ASSERT_NE(picker, nullptr); + // Address 0 gets weight 1, address 1 gets weight 3. + // No utilization report from backend 2, so it gets the average weight 2. + WaitForWeightedRoundRobinPicks( + &picker, {{kAddresses[0], {100, 0.9}}, {kAddresses[1], {100, 0.3}}}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 2}}); + // Now have backend 2 report utilization the same as backend 1, so its + // weight will be the same. + WaitForWeightedRoundRobinPicks( + &picker, + {{kAddresses[0], {100, 0.9}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.3}}}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 3}}); +} + +TEST_F(WeightedRoundRobinTest, FallsBackToRoundRobinWithoutWeights) { + // Send address list to LB policy. + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + auto picker = SendInitialUpdateAndWaitForConnected(kAddresses); + ASSERT_NE(picker, nullptr); + // Backends do not report utilization, so all are weighted the same. 
+ WaitForWeightedRoundRobinPicks( + &picker, {}, + {{kAddresses[0], 1}, {kAddresses[1], 1}, {kAddresses[2], 1}}); +} + +TEST_F(WeightedRoundRobinTest, OobReporting) { + // Send address list to LB policy. + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + auto picker = SendInitialUpdateAndWaitForConnected( + kAddresses, ConfigBuilder().SetEnableOobLoadReport(true)); + ASSERT_NE(picker, nullptr); + // Address 0 gets weight 1, address 1 gets weight 3. + // No utilization report from backend 2, so it gets the average weight 2. + ReportOobBackendMetrics( + {{kAddresses[0], {100, 0.9}}, {kAddresses[1], {100, 0.3}}}); + WaitForWeightedRoundRobinPicks( + &picker, {}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 2}}); + // Now have backend 2 report utilization the same as backend 1, so its + // weight will be the same. + ReportOobBackendMetrics({{kAddresses[0], {100, 0.9}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.3}}}); + WaitForWeightedRoundRobinPicks( + &picker, {}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 3}}); + // Verify that OOB reporting interval is the default. 
+ for (const auto& address : kAddresses) { + auto* subchannel = FindSubchannel(address); + ASSERT_NE(subchannel, nullptr); + subchannel->CheckOobReportingPeriod(Duration::Seconds(10)); + } +} + +TEST_F(WeightedRoundRobinTest, HonorsOobReportingPeriod) { + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + auto picker = SendInitialUpdateAndWaitForConnected( + kAddresses, + ConfigBuilder().SetEnableOobLoadReport(true).SetOobReportingPeriod( + Duration::Seconds(5))); + ASSERT_NE(picker, nullptr); + ReportOobBackendMetrics({{kAddresses[0], {100, 0.9}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.3}}}); + WaitForWeightedRoundRobinPicks( + &picker, {}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 3}}); + for (const auto& address : kAddresses) { + auto* subchannel = FindSubchannel(address); + ASSERT_NE(subchannel, nullptr); + subchannel->CheckOobReportingPeriod(Duration::Seconds(5)); + } +} + +TEST_F(WeightedRoundRobinTest, HonorsWeightUpdatePeriod) { + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + expected_weight_update_interval_ = std::chrono::seconds(2); + auto picker = SendInitialUpdateAndWaitForConnected( + kAddresses, ConfigBuilder().SetWeightUpdatePeriod(Duration::Seconds(2))); + ASSERT_NE(picker, nullptr); + WaitForWeightedRoundRobinPicks( + &picker, + {{kAddresses[0], {100, 0.9}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.3}}}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 3}}); +} + +TEST_F(WeightedRoundRobinTest, WeightUpdatePeriodLowerBound) { + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + expected_weight_update_interval_ = std::chrono::milliseconds(100); + auto picker = SendInitialUpdateAndWaitForConnected( + kAddresses, + ConfigBuilder().SetWeightUpdatePeriod(Duration::Milliseconds(10))); + ASSERT_NE(picker, nullptr); + 
WaitForWeightedRoundRobinPicks( + &picker, + {{kAddresses[0], {100, 0.9}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.3}}}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 3}}); +} + +TEST_F(WeightedRoundRobinTest, WeightExpirationPeriod) { + // Send address list to LB policy. + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + auto picker = SendInitialUpdateAndWaitForConnected( + kAddresses, + ConfigBuilder().SetWeightExpirationPeriod(Duration::Seconds(2))); + ASSERT_NE(picker, nullptr); + // All backends report weights. + WaitForWeightedRoundRobinPicks( + &picker, + {{kAddresses[0], {100, 0.9}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.3}}}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 3}}); + // Advance time to make weights stale and trigger the timer callback + // to recompute weights. + time_cache_.IncrementBy(Duration::Seconds(2)); + RunTimerCallback(); + // Picker should now be falling back to round-robin. + ExpectWeightedRoundRobinPicks( + picker.get(), {}, + {{kAddresses[0], 3}, {kAddresses[1], 3}, {kAddresses[2], 3}}); +} + +TEST_F(WeightedRoundRobinTest, BlackoutPeriodAfterWeightExpiration) { + // Send address list to LB policy. + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + auto picker = SendInitialUpdateAndWaitForConnected( + kAddresses, + ConfigBuilder().SetWeightExpirationPeriod(Duration::Seconds(2))); + ASSERT_NE(picker, nullptr); + // All backends report weights. + WaitForWeightedRoundRobinPicks( + &picker, + {{kAddresses[0], {100, 0.9}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.3}}}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 3}}); + // Advance time to make weights stale and trigger the timer callback + // to recompute weights. + time_cache_.IncrementBy(Duration::Seconds(2)); + RunTimerCallback(); + // Picker should now be falling back to round-robin. 
+ ExpectWeightedRoundRobinPicks( + picker.get(), {}, + {{kAddresses[0], 3}, {kAddresses[1], 3}, {kAddresses[2], 3}}); + // Now start sending weights again. They should not be used yet, + // because we're still in the blackout period. + ExpectWeightedRoundRobinPicks( + picker.get(), + {{kAddresses[0], {100, 0.3}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.9}}}, + {{kAddresses[0], 3}, {kAddresses[1], 3}, {kAddresses[2], 3}}); + // Advance time past the blackout period. This should cause the + // weights to be used. + time_cache_.IncrementBy(Duration::Seconds(1)); + RunTimerCallback(); + ExpectWeightedRoundRobinPicks( + picker.get(), {}, + {{kAddresses[0], 3}, {kAddresses[1], 3}, {kAddresses[2], 1}}); +} + +TEST_F(WeightedRoundRobinTest, BlackoutPeriodAfterDisconnect) { + // Send address list to LB policy. + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + auto picker = SendInitialUpdateAndWaitForConnected( + kAddresses, + ConfigBuilder().SetWeightExpirationPeriod(Duration::Seconds(2))); + ASSERT_NE(picker, nullptr); + // All backends report weights. + WaitForWeightedRoundRobinPicks( + &picker, + {{kAddresses[0], {100, 0.9}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.3}}}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 3}}); + // Trigger disconnection and reconnection on address 2. + auto* subchannel = FindSubchannel(kAddresses[2]); + subchannel->SetConnectivityState(GRPC_CHANNEL_IDLE); + ExpectReresolutionRequest(); + EXPECT_TRUE(subchannel->ConnectionRequested()); + subchannel->SetConnectivityState(GRPC_CHANNEL_CONNECTING); + subchannel->SetConnectivityState(GRPC_CHANNEL_READY); + // Wait for the address to come back. Note that we have not advanced + // time, so the address will still be in the blackout period, + // resulting in it being assigned the average weight. 
+ picker = ExpectState(GRPC_CHANNEL_READY, absl::OkStatus()); + WaitForWeightedRoundRobinPicks( + &picker, + {{kAddresses[0], {100, 0.9}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.3}}}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 2}}); + // Advance time to exceed the blackout period and trigger the timer + // callback to recompute weights. + time_cache_.IncrementBy(Duration::Seconds(1)); + RunTimerCallback(); + ExpectWeightedRoundRobinPicks( + picker.get(), + {{kAddresses[0], {100, 0.3}}, + {kAddresses[1], {100, 0.3}}, + {kAddresses[2], {100, 0.9}}}, + {{kAddresses[0], 1}, {kAddresses[1], 3}, {kAddresses[2], 3}}); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(&argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/cpp/end2end/client_lb_end2end_test.cc b/test/cpp/end2end/client_lb_end2end_test.cc index 73b5255485d..b853676c6d3 100644 --- a/test/cpp/end2end/client_lb_end2end_test.cc +++ b/test/cpp/end2end/client_lb_end2end_test.cc @@ -94,18 +94,49 @@ class MyTestServiceImpl : public TestServiceImpl { ++request_count_; } AddClient(context->peer()); + absl::optional load_report; + { + grpc_core::MutexLock lock(&load_report_mu_); + load_report = load_report_; + } if (request->has_param() && request->param().has_backend_metrics()) { - load_report_ = request->param().backend_metrics(); + if (!load_report.has_value()) load_report.emplace(); + const auto& request_metrics = request->param().backend_metrics(); + if (request_metrics.cpu_utilization() > 0) { + load_report->set_cpu_utilization(request_metrics.cpu_utilization()); + } + if (request_metrics.mem_utilization() > 0) { + load_report->set_mem_utilization(request_metrics.mem_utilization()); + } + if (request_metrics.rps_fractional() > 0) { + 
load_report->set_rps_fractional(request_metrics.rps_fractional()); + } + for (const auto& p : request_metrics.request_cost()) { + (*load_report->mutable_request_cost())[p.first] = p.second; + } + for (const auto& p : request_metrics.utilization()) { + (*load_report->mutable_utilization())[p.first] = p.second; + } + } + if (load_report.has_value()) { auto* recorder = context->ExperimentalGetCallMetricRecorder(); EXPECT_NE(recorder, nullptr); - recorder->RecordCpuUtilizationMetric(load_report_.cpu_utilization()) - .RecordMemoryUtilizationMetric(load_report_.mem_utilization()) - .RecordQpsMetric(load_report_.rps_fractional()); - for (const auto& p : load_report_.request_cost()) { - recorder->RecordRequestCostMetric(p.first, p.second); + recorder->RecordCpuUtilizationMetric(load_report->cpu_utilization()) + .RecordMemoryUtilizationMetric(load_report->mem_utilization()) + .RecordQpsMetric(load_report->rps_fractional()); + for (const auto& p : load_report->request_cost()) { + char* key = static_cast( + grpc_call_arena_alloc(context->c_call(), p.first.size() + 1)); + strncpy(key, p.first.data(), p.first.size()); + key[p.first.size()] = '\0'; + recorder->RecordRequestCostMetric(key, p.second); } - for (const auto& p : load_report_.utilization()) { - recorder->RecordUtilizationMetric(p.first, p.second); + for (const auto& p : load_report->utilization()) { + char* key = static_cast( + grpc_call_arena_alloc(context->c_call(), p.first.size() + 1)); + strncpy(key, p.first.data(), p.first.size()); + key[p.first.size()] = '\0'; + recorder->RecordUtilizationMetric(key, p.second); } } return TestServiceImpl::Echo(context, request, response); @@ -126,6 +157,15 @@ class MyTestServiceImpl : public TestServiceImpl { return clients_; } + // TODO(roth): Once the backend utilization APIs are updated, change + // this to use those instead of manually constructing the data for + // each call. 
+ void SetLoadReport( + absl::optional load_report) { + grpc_core::MutexLock lock(&load_report_mu_); + load_report_ = std::move(load_report); + } + private: void AddClient(const std::string& client) { grpc_core::MutexLock lock(&clients_mu_); @@ -133,11 +173,14 @@ class MyTestServiceImpl : public TestServiceImpl { } grpc_core::Mutex mu_; - int request_count_ = 0; + int request_count_ ABSL_GUARDED_BY(&mu_) = 0; + grpc_core::Mutex clients_mu_; - std::set clients_; - // For strings storage. - xds::data::orca::v3::OrcaLoadReport load_report_; + std::set clients_ ABSL_GUARDED_BY(&clients_mu_); + + grpc_core::Mutex load_report_mu_; + absl::optional load_report_ + ABSL_GUARDED_BY(&load_report_mu_); }; class FakeResolverResponseGeneratorWrapper { @@ -218,7 +261,7 @@ class FakeResolverResponseGeneratorWrapper { if (service_config_json != nullptr) { result.service_config = grpc_core::ServiceConfigImpl::Create( grpc_core::ChannelArgs(), service_config_json); - GPR_ASSERT(result.service_config.ok()); + EXPECT_TRUE(result.service_config.ok()) << result.service_config.status(); } return result; } @@ -2835,6 +2878,106 @@ TEST_F(ControlPlaneStatusRewritingTest, RewritesFromConfigSelector) { "ABORTED: nope"); } +// +// WeightedRoundRobinTest +// + +using WeightedRoundRobinTest = ClientLbEnd2endTest; + +TEST_F(WeightedRoundRobinTest, Basic) { + const int kNumServers = 3; + StartServers(kNumServers); + // Tell each server to report the appropriate CPU utilization. + xds::data::orca::v3::OrcaLoadReport load_report; + load_report.set_rps_fractional(100); + load_report.set_cpu_utilization(0.9); + servers_[0]->service_.SetLoadReport(load_report); + load_report.set_cpu_utilization(0.3); + servers_[1]->service_.SetLoadReport(load_report); + servers_[2]->service_.SetLoadReport(load_report); + // Create channel. 
+ auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("", response_generator); + auto stub = BuildStub(channel); + const char kServiceConfig[] = + "{\n" + " \"loadBalancingConfig\": [\n" + " {\"weighted_round_robin_experimental\": {\n" + " \"blackoutPeriod\": \"0s\"\n" + " }}\n" + " ]\n" + "}"; + response_generator.SetNextResolution(GetServersPorts(), kServiceConfig); + // Wait for the right set of WRR picks. + size_t num_picks = 0; + SendRpcsUntil(DEBUG_LOCATION, stub, [&](const Status&) { + if (++num_picks == 7) { + gpr_log(GPR_INFO, "request counts: %d %d %d", + servers_[0]->service_.request_count(), + servers_[1]->service_.request_count(), + servers_[2]->service_.request_count()); + if (servers_[0]->service_.request_count() == 1 && + servers_[1]->service_.request_count() == 3 && + servers_[2]->service_.request_count() == 3) { + return false; + } + num_picks = 0; + ResetCounters(); + } + return true; + }); + // Check LB policy name for the channel. + EXPECT_EQ("weighted_round_robin_experimental", + channel->GetLoadBalancingPolicyName()); +} + +TEST_F(WeightedRoundRobinTest, OobReporting) { + const int kNumServers = 3; + StartServers(kNumServers); + // Tell each server to report the appropriate CPU utilization. + servers_[0]->orca_service_.SetCpuUtilization(0.9); + servers_[0]->orca_service_.SetQps(100); + servers_[1]->orca_service_.SetCpuUtilization(0.3); + servers_[1]->orca_service_.SetQps(100); + servers_[2]->orca_service_.SetCpuUtilization(0.3); + servers_[2]->orca_service_.SetQps(100); + // Create channel. 
+ auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("", response_generator); + auto stub = BuildStub(channel); + const char kServiceConfig[] = + "{\n" + " \"loadBalancingConfig\": [\n" + " {\"weighted_round_robin_experimental\": {\n" + " \"blackoutPeriod\": \"0s\",\n" + " \"enableOobLoadReport\": true\n" + " }}\n" + " ]\n" + "}"; + response_generator.SetNextResolution(GetServersPorts(), kServiceConfig); + // Wait for the right set of WRR picks. + size_t num_picks = 0; + SendRpcsUntil(DEBUG_LOCATION, stub, [&](const Status&) { + if (++num_picks == 7) { + gpr_log(GPR_INFO, "request counts: %d %d %d", + servers_[0]->service_.request_count(), + servers_[1]->service_.request_count(), + servers_[2]->service_.request_count()); + if (servers_[0]->service_.request_count() == 1 && + servers_[1]->service_.request_count() == 3 && + servers_[2]->service_.request_count() == 3) { + return false; + } + num_picks = 0; + ResetCounters(); + } + return true; + }); + // Check LB policy name for the channel. 
+ EXPECT_EQ("weighted_round_robin_experimental", + channel->GetLoadBalancingPolicyName()); +} + } // namespace } // namespace testing } // namespace grpc diff --git a/tools/doxygen/Doxyfile.c++.internal b/tools/doxygen/Doxyfile.c++.internal index 4a44b107db3..5a03968ae3a 100644 --- a/tools/doxygen/Doxyfile.c++.internal +++ b/tools/doxygen/Doxyfile.c++.internal @@ -1119,6 +1119,7 @@ src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.cc \ src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h \ src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.cc \ src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h \ +src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h \ src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.cc \ src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.h \ src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc \ @@ -1128,6 +1129,9 @@ src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h \ src/core/ext/filters/client_channel/lb_policy/rls/rls.cc \ src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc \ src/core/ext/filters/client_channel/lb_policy/subchannel_list.h \ +src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc \ +src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h \ +src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc \ src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc \ src/core/ext/filters/client_channel/lb_policy/xds/cds.cc \ src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.cc \ diff --git a/tools/doxygen/Doxyfile.core.internal b/tools/doxygen/Doxyfile.core.internal index a98ab91f165..9a94bf96246 100644 --- a/tools/doxygen/Doxyfile.core.internal +++ 
b/tools/doxygen/Doxyfile.core.internal @@ -926,6 +926,7 @@ src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.cc \ src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h \ src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.cc \ src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.h \ +src/core/ext/filters/client_channel/lb_policy/oob_backend_metric_internal.h \ src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.cc \ src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.h \ src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc \ @@ -935,6 +936,9 @@ src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h \ src/core/ext/filters/client_channel/lb_policy/rls/rls.cc \ src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc \ src/core/ext/filters/client_channel/lb_policy/subchannel_list.h \ +src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc \ +src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h \ +src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc \ src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc \ src/core/ext/filters/client_channel/lb_policy/xds/cds.cc \ src/core/ext/filters/client_channel/lb_policy/xds/xds_attributes.cc \ diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index 76d3fff75f6..b5d1b1fa77a 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -8287,6 +8287,54 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": 
"weighted_round_robin_config_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": false + }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "weighted_round_robin_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": false + }, { "args": [], "benchmark": false,