weighted_target LB: use uint64_t for aggregate weights to avoid overflow (#31244)

* weighted_target LB: use uint64_t for aggregate weights to avoid overflow

* iwyu

* fix undefined behavior

* iwyu

* iwyu again

* fix test weights to sum to uint32 max
pull/31307/head
Mark D. Roth 3 years ago committed by GitHub
parent 42482060fc
commit ede4e42c7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      BUILD
  2. 22
      src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc
  3. 39
      test/cpp/end2end/xds/xds_cluster_end2end_test.cc
  4. 6
      test/cpp/end2end/xds/xds_end2end_test_lib.h

@ -5426,6 +5426,7 @@ grpc_cc_library(
"src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc",
],
external_deps = [
"absl/random",
"absl/status",
"absl/status:statusor",
"absl/strings",

@ -16,7 +16,7 @@
#include <grpc/support/port_platform.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <cstdint>
@ -26,6 +26,7 @@
#include <utility>
#include <vector>
#include "absl/random/random.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
@ -141,7 +142,7 @@ class WeightedTargetLb : public LoadBalancingPolicy {
// range proportional to the child's weight. The start of the range
// is the previous value in the vector and is 0 for the first element.
using PickerList =
std::vector<std::pair<uint32_t, RefCountedPtr<ChildPickerWrapper>>>;
std::vector<std::pair<uint64_t, RefCountedPtr<ChildPickerWrapper>>>;
explicit WeightedPicker(PickerList pickers)
: pickers_(std::move(pickers)) {}
@ -150,6 +151,7 @@ class WeightedTargetLb : public LoadBalancingPolicy {
private:
PickerList pickers_;
absl::BitGen bit_gen_;
};
// Each WeightedChild holds a ref to its parent WeightedTargetLb.
@ -226,7 +228,7 @@ class WeightedTargetLb : public LoadBalancingPolicy {
const std::string name_;
uint32_t weight_;
uint32_t weight_ = 0;
OrphanablePtr<LoadBalancingPolicy> child_policy_;
@ -260,7 +262,8 @@ class WeightedTargetLb : public LoadBalancingPolicy {
WeightedTargetLb::PickResult WeightedTargetLb::WeightedPicker::Pick(
PickArgs args) {
// Generate a random number in [0, total weight).
const uint32_t key = rand() % pickers_[pickers_.size() - 1].first;
const uint64_t key =
absl::Uniform<uint64_t>(bit_gen_, 0, pickers_.back().first);
// Find the index in pickers_ corresponding to key.
size_t mid = 0;
size_t start_index = 0;
@ -393,9 +396,9 @@ void WeightedTargetLb::UpdateStateLocked() {
// the range proportional to its weight, such that the total range is the
// sum of the weights of all children.
WeightedPicker::PickerList ready_picker_list;
uint32_t ready_end = 0;
uint64_t ready_end = 0;
WeightedPicker::PickerList tf_picker_list;
uint32_t tf_end = 0;
uint64_t tf_end = 0;
// Also count the number of children in CONNECTING and IDLE, to determine
// the aggregated state.
size_t num_connecting = 0;
@ -409,7 +412,7 @@ void WeightedTargetLb::UpdateStateLocked() {
}
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) {
gpr_log(GPR_INFO,
"[weighted_target_lb %p] child=%s state=%s weight=%d picker=%p",
"[weighted_target_lb %p] child=%s state=%s weight=%u picker=%p",
this, child_name.c_str(),
ConnectivityStateName(child->connectivity_state()),
child->weight(), child->picker_wrapper().get());
@ -586,6 +589,11 @@ absl::Status WeightedTargetLb::WeightedChild::UpdateLocked(
const std::string& resolution_note, const ChannelArgs& args) {
if (weighted_target_policy_->shutting_down_) return absl::OkStatus();
// Update child weight.
if (weight_ != config.weight &&
GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) {
gpr_log(GPR_INFO, "[weighted_target_lb %p] WeightedChild %p %s: weight=%u",
weighted_target_policy_.get(), this, name_.c_str(), config.weight);
}
weight_ = config.weight;
// Reactivate if needed.
if (delayed_removal_timer_ != nullptr) {

@ -649,6 +649,45 @@ TEST_P(EdsTest, WeightedRoundRobin) {
::testing::DoubleNear(kLocalityWeightRate1, kErrorTolerance));
}
// Tests that we don't suffer from integer overflow in locality weights.
TEST_P(EdsTest, NoIntegerOverflowInLocalityWeights) {
CreateAndStartBackends(2);
const uint32_t kLocalityWeight1 = std::numeric_limits<uint32_t>::max() / 3;
const uint32_t kLocalityWeight0 =
std::numeric_limits<uint32_t>::max() - kLocalityWeight1;
const uint64_t kTotalLocalityWeight =
static_cast<uint64_t>(kLocalityWeight0) +
static_cast<uint64_t>(kLocalityWeight1);
const double kLocalityWeightRate0 =
static_cast<double>(kLocalityWeight0) / kTotalLocalityWeight;
const double kLocalityWeightRate1 =
static_cast<double>(kLocalityWeight1) / kTotalLocalityWeight;
const double kErrorTolerance = 0.05;
const size_t kNumRpcs =
ComputeIdealNumRpcs(kLocalityWeightRate0, kErrorTolerance);
// ADS response contains 2 localities, each of which contains 1 backend.
EdsResourceArgs args({
{"locality0", CreateEndpointsForBackends(0, 1), kLocalityWeight0},
{"locality1", CreateEndpointsForBackends(1, 2), kLocalityWeight1},
});
balancer_->ads_service()->SetEdsResource(BuildEdsResource(args));
// Wait for both backends to be ready.
WaitForAllBackends(DEBUG_LOCATION, 0, 2);
// Send kNumRpcs RPCs.
CheckRpcSendOk(DEBUG_LOCATION, kNumRpcs);
// The locality picking rates should be roughly equal to the expectation.
const double locality_picked_rate_0 =
static_cast<double>(backends_[0]->backend_service()->request_count()) /
kNumRpcs;
const double locality_picked_rate_1 =
static_cast<double>(backends_[1]->backend_service()->request_count()) /
kNumRpcs;
EXPECT_THAT(locality_picked_rate_0,
::testing::DoubleNear(kLocalityWeightRate0, kErrorTolerance));
EXPECT_THAT(locality_picked_rate_1,
::testing::DoubleNear(kLocalityWeightRate1, kErrorTolerance));
}
// Tests that we correctly handle a locality containing no endpoints.
TEST_P(EdsTest, LocalityContainingNoEndpoints) {
CreateAndStartBackends(2);

@ -192,7 +192,7 @@ class XdsEnd2endTest : public ::testing::TestWithParam<XdsTestType> {
// Default values for locality fields.
static const char kDefaultLocalityRegion[];
static const char kDefaultLocalityZone[];
static const int kDefaultLocalityWeight = 3;
static const uint32_t kDefaultLocalityWeight = 3;
static const int kDefaultLocalityPriority = 0;
// Default resource names.
@ -569,7 +569,7 @@ class XdsEnd2endTest : public ::testing::TestWithParam<XdsTestType> {
// A locality.
struct Locality {
Locality(std::string sub_zone, std::vector<Endpoint> endpoints,
int lb_weight = kDefaultLocalityWeight,
uint32_t lb_weight = kDefaultLocalityWeight,
int priority = kDefaultLocalityPriority)
: sub_zone(std::move(sub_zone)),
endpoints(std::move(endpoints)),
@ -578,7 +578,7 @@ class XdsEnd2endTest : public ::testing::TestWithParam<XdsTestType> {
const std::string sub_zone;
std::vector<Endpoint> endpoints;
int lb_weight;
uint32_t lb_weight;
int priority;
};

Loading…
Cancel
Save