From ede4e42c7d7cb18e81bc3c6c486a579a4e2ac25f Mon Sep 17 00:00:00 2001 From: "Mark D. Roth" Date: Mon, 10 Oct 2022 10:29:01 -0700 Subject: [PATCH] weighted_target LB: use uint64_t for aggregate weights to avoid overflow (#31244) * weighted_target LB: use uint64_t for aggregate weights to avoid overflow * iwyu * fix undefined behavior * iwyu * iwyu again * fix test weights to sum to uint32 max --- BUILD | 1 + .../weighted_target/weighted_target.cc | 22 +++++++---- .../end2end/xds/xds_cluster_end2end_test.cc | 39 +++++++++++++++++++ test/cpp/end2end/xds/xds_end2end_test_lib.h | 6 +-- 4 files changed, 58 insertions(+), 10 deletions(-) diff --git a/BUILD b/BUILD index 5fd5ff2905f..7742c8d0669 100644 --- a/BUILD +++ b/BUILD @@ -5426,6 +5426,7 @@ grpc_cc_library( "src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc", ], external_deps = [ + "absl/random", "absl/status", "absl/status:statusor", "absl/strings", diff --git a/src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc b/src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc index d3787eb73f0..a8cf390873c 100644 --- a/src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc +++ b/src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc @@ -16,7 +16,7 @@ #include -#include +#include #include #include @@ -26,6 +26,7 @@ #include #include +#include "absl/random/random.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -141,7 +142,7 @@ class WeightedTargetLb : public LoadBalancingPolicy { // range proportional to the child's weight. The start of the range // is the previous value in the vector and is 0 for the first element. using PickerList = - std::vector>>; + std::vector>>; explicit WeightedPicker(PickerList pickers) : pickers_(std::move(pickers)) {} @@ -150,6 +151,7 @@ class WeightedTargetLb : public LoadBalancingPolicy { private: PickerList pickers_; + absl::BitGen bit_gen_; }; // Each WeightedChild holds a ref to its parent WeightedTargetLb. @@ -226,7 +228,7 @@ class WeightedTargetLb : public LoadBalancingPolicy { const std::string name_; - uint32_t weight_; + uint32_t weight_ = 0; OrphanablePtr child_policy_; @@ -260,7 +262,8 @@ class WeightedTargetLb : public LoadBalancingPolicy { WeightedTargetLb::PickResult WeightedTargetLb::WeightedPicker::Pick( PickArgs args) { // Generate a random number in [0, total weight). - const uint32_t key = rand() % pickers_[pickers_.size() - 1].first; + const uint64_t key = + absl::Uniform(bit_gen_, 0, pickers_.back().first); // Find the index in pickers_ corresponding to key. size_t mid = 0; size_t start_index = 0; @@ -393,9 +396,9 @@ void WeightedTargetLb::UpdateStateLocked() { // the range proportional to its weight, such that the total range is the // sum of the weights of all children. WeightedPicker::PickerList ready_picker_list; - uint32_t ready_end = 0; + uint64_t ready_end = 0; WeightedPicker::PickerList tf_picker_list; - uint32_t tf_end = 0; + uint64_t tf_end = 0; // Also count the number of children in CONNECTING and IDLE, to determine // the aggregated state. size_t num_connecting = 0; @@ -409,7 +412,7 @@ void WeightedTargetLb::UpdateStateLocked() { } if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { gpr_log(GPR_INFO, - "[weighted_target_lb %p] child=%s state=%s weight=%d picker=%p", + "[weighted_target_lb %p] child=%s state=%s weight=%u picker=%p", this, child_name.c_str(), ConnectivityStateName(child->connectivity_state()), child->weight(), child->picker_wrapper().get()); @@ -586,6 +589,11 @@ absl::Status WeightedTargetLb::WeightedChild::UpdateLocked( const std::string& resolution_note, const ChannelArgs& args) { if (weighted_target_policy_->shutting_down_) return absl::OkStatus(); // Update child weight. + if (weight_ != config.weight && + GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, "[weighted_target_lb %p] WeightedChild %p %s: weight=%u", + weighted_target_policy_.get(), this, name_.c_str(), config.weight); + } weight_ = config.weight; // Reactivate if needed. if (delayed_removal_timer_ != nullptr) { diff --git a/test/cpp/end2end/xds/xds_cluster_end2end_test.cc b/test/cpp/end2end/xds/xds_cluster_end2end_test.cc index a1a9f2bf554..99d0cf13e29 100644 --- a/test/cpp/end2end/xds/xds_cluster_end2end_test.cc +++ b/test/cpp/end2end/xds/xds_cluster_end2end_test.cc @@ -649,6 +649,45 @@ TEST_P(EdsTest, WeightedRoundRobin) { ::testing::DoubleNear(kLocalityWeightRate1, kErrorTolerance)); } +// Tests that we don't suffer from integer overflow in locality weights. +TEST_P(EdsTest, NoIntegerOverflowInLocalityWeights) { + CreateAndStartBackends(2); + const uint32_t kLocalityWeight1 = std::numeric_limits::max() / 3; + const uint32_t kLocalityWeight0 = + std::numeric_limits::max() - kLocalityWeight1; + const uint64_t kTotalLocalityWeight = + static_cast(kLocalityWeight0) + + static_cast(kLocalityWeight1); + const double kLocalityWeightRate0 = + static_cast(kLocalityWeight0) / kTotalLocalityWeight; + const double kLocalityWeightRate1 = + static_cast(kLocalityWeight1) / kTotalLocalityWeight; + const double kErrorTolerance = 0.05; + const size_t kNumRpcs = + ComputeIdealNumRpcs(kLocalityWeightRate0, kErrorTolerance); + // ADS response contains 2 localities, each of which contains 1 backend. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1), kLocalityWeight0}, + {"locality1", CreateEndpointsForBackends(1, 2), kLocalityWeight1}, + }); + balancer_->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Wait for both backends to be ready. + WaitForAllBackends(DEBUG_LOCATION, 0, 2); + // Send kNumRpcs RPCs. + CheckRpcSendOk(DEBUG_LOCATION, kNumRpcs); + // The locality picking rates should be roughly equal to the expectation. + const double locality_picked_rate_0 = + static_cast(backends_[0]->backend_service()->request_count()) / + kNumRpcs; + const double locality_picked_rate_1 = + static_cast(backends_[1]->backend_service()->request_count()) / + kNumRpcs; + EXPECT_THAT(locality_picked_rate_0, + ::testing::DoubleNear(kLocalityWeightRate0, kErrorTolerance)); + EXPECT_THAT(locality_picked_rate_1, + ::testing::DoubleNear(kLocalityWeightRate1, kErrorTolerance)); +} + // Tests that we correctly handle a locality containing no endpoints. TEST_P(EdsTest, LocalityContainingNoEndpoints) { CreateAndStartBackends(2); diff --git a/test/cpp/end2end/xds/xds_end2end_test_lib.h b/test/cpp/end2end/xds/xds_end2end_test_lib.h index f003624aebd..8ee103f6b61 100644 --- a/test/cpp/end2end/xds/xds_end2end_test_lib.h +++ b/test/cpp/end2end/xds/xds_end2end_test_lib.h @@ -192,7 +192,7 @@ class XdsEnd2endTest : public ::testing::TestWithParam { // Default values for locality fields. static const char kDefaultLocalityRegion[]; static const char kDefaultLocalityZone[]; - static const int kDefaultLocalityWeight = 3; + static const uint32_t kDefaultLocalityWeight = 3; static const int kDefaultLocalityPriority = 0; // Default resource names. @@ -569,7 +569,7 @@ class XdsEnd2endTest : public ::testing::TestWithParam { // A locality. struct Locality { Locality(std::string sub_zone, std::vector endpoints, - int lb_weight = kDefaultLocalityWeight, + uint32_t lb_weight = kDefaultLocalityWeight, int priority = kDefaultLocalityPriority) : sub_zone(std::move(sub_zone)), endpoints(std::move(endpoints)), @@ -578,7 +578,7 @@ class XdsEnd2endTest : public ::testing::TestWithParam { const std::string sub_zone; std::vector endpoints; - int lb_weight; + uint32_t lb_weight; int priority; };