WRR: port StaticStrideScheduler to OSS (#31893)
* WRR: port StaticStrideScheduler to OSS * Automated change: Fix sanity tests * fix build * remove unused aliases * fix another type mismatch * remove unnecessary include * move benchmarks to their own file, and don't run it on windows * Automated change: Fix sanity tests Co-authored-by: markdroth <markdroth@users.noreply.github.com>pull/31905/head^2
parent
6a97f492ff
commit
f3419430df
9 changed files with 723 additions and 0 deletions
@ -0,0 +1,128 @@ |
|||||||
|
//
|
||||||
|
// Copyright 2022 gRPC authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <grpc/support/port_platform.h> |
||||||
|
|
||||||
|
#include "src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h" |
||||||
|
|
||||||
|
#include <algorithm> |
||||||
|
#include <cmath> |
||||||
|
#include <limits> |
||||||
|
#include <utility> |
||||||
|
#include <vector> |
||||||
|
|
||||||
|
#include "absl/functional/any_invocable.h" |
||||||
|
|
||||||
|
#include <grpc/support/log.h> |
||||||
|
|
||||||
|
namespace grpc_core { |
||||||
|
|
||||||
|
namespace { |
||||||
|
constexpr uint16_t kMaxWeight = std::numeric_limits<uint16_t>::max(); |
||||||
|
} // namespace
|
||||||
|
|
||||||
|
absl::optional<StaticStrideScheduler> StaticStrideScheduler::Make( |
||||||
|
absl::Span<const float> float_weights, |
||||||
|
absl::AnyInvocable<uint32_t()> next_sequence_func) { |
||||||
|
if (float_weights.empty()) return absl::nullopt; |
||||||
|
if (float_weights.size() == 1) return absl::nullopt; |
||||||
|
|
||||||
|
// TODO(b/190488683): should we normalize negative weights to 0?
|
||||||
|
|
||||||
|
const size_t n = float_weights.size(); |
||||||
|
size_t num_zero_weight_channels = 0; |
||||||
|
double sum = 0; |
||||||
|
float max = 0; |
||||||
|
for (const float weight : float_weights) { |
||||||
|
sum += weight; |
||||||
|
max = std::max(max, weight); |
||||||
|
if (weight == 0) { |
||||||
|
++num_zero_weight_channels; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if (num_zero_weight_channels == n) return absl::nullopt; |
||||||
|
|
||||||
|
// Mean of non-zero weights before scaling to `kMaxWeight`.
|
||||||
|
const double unscaled_mean = |
||||||
|
sum / static_cast<double>(n - num_zero_weight_channels); |
||||||
|
|
||||||
|
// Scale weights such that the largest is equal to `kMaxWeight`. This should
|
||||||
|
// be accurate enough once we convert to an integer. Quantisation errors won't
|
||||||
|
// be measurable on borg.
|
||||||
|
// TODO(b/190488683): it may be more stable over updates if we try to keep
|
||||||
|
// `scaling_factor` consistent, and only change it when we can't accurately
|
||||||
|
// represent the new weights.
|
||||||
|
const double scaling_factor = kMaxWeight / max; |
||||||
|
const uint16_t mean = std::lround(scaling_factor * unscaled_mean); |
||||||
|
|
||||||
|
std::vector<uint16_t> weights; |
||||||
|
weights.reserve(n); |
||||||
|
for (size_t i = 0; i < n; ++i) { |
||||||
|
weights.push_back(float_weights[i] == 0 |
||||||
|
? mean |
||||||
|
: std::lround(float_weights[i] * scaling_factor)); |
||||||
|
} |
||||||
|
|
||||||
|
GPR_ASSERT(weights.size() == float_weights.size()); |
||||||
|
return StaticStrideScheduler{std::move(weights), |
||||||
|
std::move(next_sequence_func)}; |
||||||
|
} |
||||||
|
|
||||||
|
StaticStrideScheduler::StaticStrideScheduler( |
||||||
|
std::vector<uint16_t> weights, |
||||||
|
absl::AnyInvocable<uint32_t()> next_sequence_func) |
||||||
|
: next_sequence_func_(std::move(next_sequence_func)), |
||||||
|
weights_(std::move(weights)) { |
||||||
|
GPR_ASSERT(next_sequence_func_ != nullptr); |
||||||
|
} |
||||||
|
|
||||||
|
size_t StaticStrideScheduler::Pick() const { |
||||||
|
while (true) { |
||||||
|
const uint32_t sequence = next_sequence_func_(); |
||||||
|
|
||||||
|
// The sequence number is split in two: the lower %n gives the index of the
|
||||||
|
// backend, and the rest gives the number of times we've iterated through
|
||||||
|
// all backends. `generation` is used to deterministically decide whether
|
||||||
|
// we pick or skip the backend on this iteration, in proportion to the
|
||||||
|
// backend's weight.
|
||||||
|
const uint64_t backend_index = sequence % weights_.size(); |
||||||
|
const uint64_t generation = sequence / weights_.size(); |
||||||
|
const uint64_t weight = weights_[backend_index]; |
||||||
|
|
||||||
|
// We pick a backend `weight` times per `kMaxWeight` generations. The
|
||||||
|
// multiply and modulus ~evenly spread out the picks for a given backend
|
||||||
|
// between different generations. The offset by `backend_index` helps to
|
||||||
|
// reduce the chance of multiple consecutive non-picks: if we have two
|
||||||
|
// consecutive backends with an equal, say, 80% weight of the max, with no
|
||||||
|
// offset we would see 1/5 generations that skipped both.
|
||||||
|
// TODO(b/190488683): add test for offset efficacy.
|
||||||
|
const uint16_t kOffset = kMaxWeight / 2; |
||||||
|
const uint16_t mod = |
||||||
|
(weight * generation + backend_index * kOffset) % kMaxWeight; |
||||||
|
|
||||||
|
if (mod < kMaxWeight - weight) { |
||||||
|
// Probability of skipping = 1 - mean(weights) / max(weights).
|
||||||
|
// For a typical large-scale service using RR, max task utilization will
|
||||||
|
// be ~100% when mean utilization is ~80%. So ~20% of picks will be
|
||||||
|
// skipped.
|
||||||
|
continue; |
||||||
|
} |
||||||
|
return backend_index; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
} // namespace grpc_core
|
@ -0,0 +1,69 @@ |
|||||||
|
//
|
||||||
|
// Copyright 2022 gRPC authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_LB_POLICY_WEIGHTED_ROUND_ROBIN_STATIC_STRIDE_SCHEDULER_H |
||||||
|
#define GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_LB_POLICY_WEIGHTED_ROUND_ROBIN_STATIC_STRIDE_SCHEDULER_H |
||||||
|
|
||||||
|
#include <grpc/support/port_platform.h> |
||||||
|
|
||||||
|
#include <atomic> |
||||||
|
#include <vector> |
||||||
|
|
||||||
|
#include "absl/functional/any_invocable.h" |
||||||
|
#include "absl/types/optional.h" |
||||||
|
#include "absl/types/span.h" |
||||||
|
|
||||||
|
namespace grpc_core { |
||||||
|
|
||||||
|
// StaticStrideScheduler implements a stride scheduler without the ability to
|
||||||
|
// add, remove, or modify elements after construction. In exchange, not only is
|
||||||
|
// it cheaper to construct and batch-update weights than a traditional dynamic
|
||||||
|
// stride scheduler, it can also be used to make concurrent picks without any
|
||||||
|
// locking.
|
||||||
|
//
|
||||||
|
// Construction is O(|weights|). Picking is O(1) if weights are similar, or
|
||||||
|
// O(|weights|) if the mean of the non-zero weights is a small fraction of the
|
||||||
|
// max. Stores two bytes per weight.
|
||||||
|
class StaticStrideScheduler { |
||||||
|
public: |
||||||
|
// Constructs and returns a new StaticStrideScheduler, or nullopt if all
|
||||||
|
// wieghts are zero or |weights| <= 1. All weights must be >=0.
|
||||||
|
// `next_sequence_func` should return a rate monotonically increasing sequence
|
||||||
|
// number, which may wrap. `float_weights` does not need to live beyond the
|
||||||
|
// function. Caller is responsible for ensuring `next_sequence_func` remains
|
||||||
|
// valid for all calls to `Pick()`.
|
||||||
|
static absl::optional<StaticStrideScheduler> Make( |
||||||
|
absl::Span<const float> float_weights, |
||||||
|
absl::AnyInvocable<uint32_t()> next_sequence_func); |
||||||
|
|
||||||
|
// Returns the index of the next pick. May invoke `next_sequence_func`
|
||||||
|
// multiple times. The returned value is guaranteed to be in [0, |weights|).
|
||||||
|
// Can be called concurrently iff `next_sequence_func` can.
|
||||||
|
size_t Pick() const; |
||||||
|
|
||||||
|
private: |
||||||
|
StaticStrideScheduler(std::vector<uint16_t> weights, |
||||||
|
absl::AnyInvocable<uint32_t()> next_sequence_func); |
||||||
|
|
||||||
|
mutable absl::AnyInvocable<uint32_t()> next_sequence_func_; |
||||||
|
|
||||||
|
// List of backend weights scaled such that the max(weights_) == kMaxWeight.
|
||||||
|
std::vector<uint16_t> weights_; |
||||||
|
}; |
||||||
|
|
||||||
|
} // namespace grpc_core
|
||||||
|
|
||||||
|
#endif // GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_LB_POLICY_WEIGHTED_ROUND_ROBIN_STATIC_STRIDE_SCHEDULER_H
|
@ -0,0 +1,116 @@ |
|||||||
|
//
|
||||||
|
// Copyright 2022 gRPC authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <algorithm> |
||||||
|
#include <atomic> |
||||||
|
#include <cstdint> |
||||||
|
#include <limits> |
||||||
|
#include <vector> |
||||||
|
|
||||||
|
#include <benchmark/benchmark.h> |
||||||
|
|
||||||
|
#include "absl/algorithm/container.h" |
||||||
|
#include "absl/random/random.h" |
||||||
|
#include "absl/types/optional.h" |
||||||
|
|
||||||
|
#include <grpc/support/log.h> |
||||||
|
|
||||||
|
#include "src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h" |
||||||
|
#include "src/core/lib/gprpp/no_destruct.h" |
||||||
|
|
||||||
|
namespace grpc_core { |
||||||
|
namespace { |
||||||
|
|
||||||
|
const int kNumWeightsLow = 10; |
||||||
|
const int kNumWeightsHigh = 10000; |
||||||
|
const int kRangeMultiplier = 10; |
||||||
|
|
||||||
|
// Returns a randomly ordered list of weights equally distributed between 0.6
|
||||||
|
// and 1.0.
|
||||||
|
const std::vector<float>& Weights() { |
||||||
|
static const NoDestruct<std::vector<float>> kWeights([] { |
||||||
|
static NoDestruct<absl::BitGen> bit_gen; |
||||||
|
std::vector<float> weights; |
||||||
|
weights.reserve(kNumWeightsHigh); |
||||||
|
for (int i = 0; i < 40; ++i) { |
||||||
|
for (int j = 0; j < kNumWeightsHigh / 40; ++j) { |
||||||
|
weights.push_back(0.6 + (0.01 * i)); |
||||||
|
} |
||||||
|
} |
||||||
|
absl::c_shuffle(weights, *bit_gen); |
||||||
|
return weights; |
||||||
|
}()); |
||||||
|
return *kWeights; |
||||||
|
} |
||||||
|
|
||||||
|
void BM_StaticStrideSchedulerPickNonAtomic(benchmark::State& state) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
const absl::optional<StaticStrideScheduler> scheduler = |
||||||
|
StaticStrideScheduler::Make( |
||||||
|
absl::MakeSpan(Weights()).subspan(0, state.range(0)), |
||||||
|
[&] { return sequence++; }); |
||||||
|
GPR_ASSERT(scheduler.has_value()); |
||||||
|
for (auto s : state) { |
||||||
|
benchmark::DoNotOptimize(scheduler->Pick()); |
||||||
|
} |
||||||
|
} |
||||||
|
BENCHMARK(BM_StaticStrideSchedulerPickNonAtomic) |
||||||
|
->RangeMultiplier(kRangeMultiplier) |
||||||
|
->Range(kNumWeightsLow, kNumWeightsHigh); |
||||||
|
|
||||||
|
void BM_StaticStrideSchedulerPickAtomic(benchmark::State& state) { |
||||||
|
std::atomic<uint32_t> sequence{0}; |
||||||
|
const absl::optional<StaticStrideScheduler> scheduler = |
||||||
|
StaticStrideScheduler::Make( |
||||||
|
absl::MakeSpan(Weights()).subspan(0, state.range(0)), |
||||||
|
[&] { return sequence.fetch_add(1, std::memory_order_relaxed); }); |
||||||
|
GPR_ASSERT(scheduler.has_value()); |
||||||
|
for (auto s : state) { |
||||||
|
benchmark::DoNotOptimize(scheduler->Pick()); |
||||||
|
} |
||||||
|
} |
||||||
|
BENCHMARK(BM_StaticStrideSchedulerPickAtomic) |
||||||
|
->RangeMultiplier(kRangeMultiplier) |
||||||
|
->Range(kNumWeightsLow, kNumWeightsHigh); |
||||||
|
|
||||||
|
void BM_StaticStrideSchedulerMake(benchmark::State& state) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
for (auto s : state) { |
||||||
|
const absl::optional<StaticStrideScheduler> scheduler = |
||||||
|
StaticStrideScheduler::Make( |
||||||
|
absl::MakeSpan(Weights()).subspan(0, state.range(0)), |
||||||
|
[&] { return sequence++; }); |
||||||
|
GPR_ASSERT(scheduler.has_value()); |
||||||
|
} |
||||||
|
} |
||||||
|
BENCHMARK(BM_StaticStrideSchedulerMake) |
||||||
|
->RangeMultiplier(kRangeMultiplier) |
||||||
|
->Range(kNumWeightsLow, kNumWeightsHigh); |
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace grpc_core
|
||||||
|
|
||||||
|
// Some distros have RunSpecifiedBenchmarks under the benchmark namespace,
|
||||||
|
// and others do not. This allows us to support both modes.
|
||||||
|
namespace benchmark { |
||||||
|
void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } |
||||||
|
} // namespace benchmark
|
||||||
|
|
||||||
|
int main(int argc, char** argv) { |
||||||
|
benchmark::Initialize(&argc, argv); |
||||||
|
benchmark::RunTheBenchmarksNamespaced(); |
||||||
|
return 0; |
||||||
|
} |
@ -0,0 +1,201 @@ |
|||||||
|
//
|
||||||
|
// Copyright 2022 gRPC authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.h" |
||||||
|
|
||||||
|
#include <algorithm> |
||||||
|
#include <atomic> |
||||||
|
#include <cstdint> |
||||||
|
#include <limits> |
||||||
|
#include <vector> |
||||||
|
|
||||||
|
#include "absl/types/optional.h" |
||||||
|
#include "gmock/gmock.h" |
||||||
|
#include "gtest/gtest.h" |
||||||
|
|
||||||
|
#include <grpc/support/log.h> |
||||||
|
|
||||||
|
namespace grpc_core { |
||||||
|
namespace { |
||||||
|
|
||||||
|
using ::testing::ElementsAre; |
||||||
|
using ::testing::UnorderedElementsAre; |
||||||
|
|
||||||
|
TEST(StaticStrideSchedulerTest, EmptyWeightsIsNullopt) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
const std::vector<float> weights = {}; |
||||||
|
ASSERT_FALSE(StaticStrideScheduler::Make(absl::MakeSpan(weights), [&] { |
||||||
|
return sequence++; |
||||||
|
}).has_value()); |
||||||
|
} |
||||||
|
|
||||||
|
TEST(StaticStrideSchedulerTest, OneZeroWeightIsNullopt) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
const std::vector<float> weights = {0}; |
||||||
|
ASSERT_FALSE(StaticStrideScheduler::Make(absl::MakeSpan(weights), [&] { |
||||||
|
return sequence++; |
||||||
|
}).has_value()); |
||||||
|
} |
||||||
|
|
||||||
|
TEST(StaticStrideSchedulerTest, AllZeroWeightsIsNullopt) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
const std::vector<float> weights = {0, 0, 0, 0}; |
||||||
|
ASSERT_FALSE(StaticStrideScheduler::Make(absl::MakeSpan(weights), [&] { |
||||||
|
return sequence++; |
||||||
|
}).has_value()); |
||||||
|
} |
||||||
|
|
||||||
|
TEST(StaticStrideSchedulerTest, OneWeightsIsNullopt) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
const std::vector<float> weights = {1}; |
||||||
|
ASSERT_FALSE(StaticStrideScheduler::Make(absl::MakeSpan(weights), [&] { |
||||||
|
return sequence++; |
||||||
|
}).has_value()); |
||||||
|
} |
||||||
|
|
||||||
|
TEST(StaticStrideSchedulerTest, PicksAreWeightedExactly) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
const std::vector<float> weights = {1, 2, 3}; |
||||||
|
const absl::optional<StaticStrideScheduler> scheduler = |
||||||
|
StaticStrideScheduler::Make(absl::MakeSpan(weights), |
||||||
|
[&] { return sequence++; }); |
||||||
|
ASSERT_TRUE(scheduler.has_value()); |
||||||
|
|
||||||
|
std::vector<int> picks(weights.size()); |
||||||
|
for (int i = 0; i < 6; ++i) { |
||||||
|
++picks[scheduler->Pick()]; |
||||||
|
} |
||||||
|
EXPECT_THAT(picks, ElementsAre(1, 2, 3)); |
||||||
|
} |
||||||
|
|
||||||
|
TEST(StaticStrideSchedulerTest, ZeroWeightUsesMean) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
const std::vector<float> weights = {3, 0, 1}; |
||||||
|
const absl::optional<StaticStrideScheduler> scheduler = |
||||||
|
StaticStrideScheduler::Make(absl::MakeSpan(weights), |
||||||
|
[&] { return sequence++; }); |
||||||
|
ASSERT_TRUE(scheduler.has_value()); |
||||||
|
|
||||||
|
std::vector<int> picks(weights.size()); |
||||||
|
for (int i = 0; i < 6; ++i) { |
||||||
|
++picks[scheduler->Pick()]; |
||||||
|
} |
||||||
|
EXPECT_THAT(picks, ElementsAre(3, 2, 1)); |
||||||
|
} |
||||||
|
|
||||||
|
TEST(StaticStrideSchedulerTest, AllWeightsEqualIsRoundRobin) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
const std::vector<float> weights = {300, 300, 0}; |
||||||
|
const absl::optional<StaticStrideScheduler> scheduler = |
||||||
|
StaticStrideScheduler::Make(absl::MakeSpan(weights), |
||||||
|
[&] { return sequence++; }); |
||||||
|
ASSERT_TRUE(scheduler.has_value()); |
||||||
|
|
||||||
|
std::vector<size_t> picks(weights.size()); |
||||||
|
for (int i = 0; i < 3; ++i) { |
||||||
|
picks[i] = scheduler->Pick(); |
||||||
|
} |
||||||
|
|
||||||
|
// Each backend is selected exactly once.
|
||||||
|
EXPECT_THAT(picks, UnorderedElementsAre(0, 1, 2)); |
||||||
|
|
||||||
|
// And continues to be picked in the original order, whatever it may be.
|
||||||
|
for (int i = 0; i < 1000; ++i) { |
||||||
|
EXPECT_EQ(scheduler->Pick(), picks[i % 3]); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
TEST(StaticStrideSchedulerTest, PicksAreDeterministic) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
const std::vector<float> weights = {1, 2, 3}; |
||||||
|
const absl::optional<StaticStrideScheduler> scheduler = |
||||||
|
StaticStrideScheduler::Make(absl::MakeSpan(weights), |
||||||
|
[&] { return sequence++; }); |
||||||
|
ASSERT_TRUE(scheduler.has_value()); |
||||||
|
|
||||||
|
const int n = 100; |
||||||
|
std::vector<size_t> picks; |
||||||
|
picks.reserve(n); |
||||||
|
for (int i = 0; i < n; ++i) { |
||||||
|
picks.push_back(scheduler->Pick()); |
||||||
|
} |
||||||
|
for (int i = 0; i < 5; ++i) { |
||||||
|
sequence = 0; |
||||||
|
for (int j = 0; j < n; ++j) { |
||||||
|
EXPECT_EQ(scheduler->Pick(), picks[j]); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
TEST(StaticStrideSchedulerTest, RebuildGiveSamePicks) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
const std::vector<float> weights = {1, 2, 3}; |
||||||
|
const absl::optional<StaticStrideScheduler> scheduler = |
||||||
|
StaticStrideScheduler::Make(absl::MakeSpan(weights), |
||||||
|
[&] { return sequence++; }); |
||||||
|
ASSERT_TRUE(scheduler.has_value()); |
||||||
|
|
||||||
|
const int n = 100; |
||||||
|
std::vector<size_t> picks; |
||||||
|
picks.reserve(n); |
||||||
|
for (int i = 0; i < n; ++i) { |
||||||
|
picks.push_back(scheduler->Pick()); |
||||||
|
} |
||||||
|
|
||||||
|
// Rewind and make each pick with a new scheduler instance. This should give
|
||||||
|
// identical picks.
|
||||||
|
sequence = 0; |
||||||
|
for (int i = 0; i < n; ++i) { |
||||||
|
const absl::optional<StaticStrideScheduler> rebuild = |
||||||
|
StaticStrideScheduler::Make(absl::MakeSpan(weights), |
||||||
|
[&] { return sequence++; }); |
||||||
|
ASSERT_TRUE(rebuild.has_value()); |
||||||
|
|
||||||
|
EXPECT_EQ(rebuild->Pick(), picks[i]); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// This tests an internal implementation detail of StaticStrideScheduler --
|
||||||
|
// the highest weighted element will be picked on all `kMaxWeight` generations.
|
||||||
|
// The number of picks required to run through all values of the sequence is
|
||||||
|
// mean(weights) * kMaxWeight. It is worth testing this property because it can
|
||||||
|
// catch rounding and off-by-one errors.
|
||||||
|
TEST(StaticStrideSchedulerTest, LargestIsPickedEveryGeneration) { |
||||||
|
uint32_t sequence = 0; |
||||||
|
const std::vector<float> weights = {1, 2, 3}; |
||||||
|
const int mean = 2; |
||||||
|
const absl::optional<StaticStrideScheduler> scheduler = |
||||||
|
StaticStrideScheduler::Make(absl::MakeSpan(weights), |
||||||
|
[&] { return sequence++; }); |
||||||
|
ASSERT_TRUE(scheduler.has_value()); |
||||||
|
|
||||||
|
const int kMaxWeight = std::numeric_limits<uint16_t>::max(); |
||||||
|
int largest_weight_pick_count = 0; |
||||||
|
for (int i = 0; i < kMaxWeight * mean; ++i) { |
||||||
|
if (scheduler->Pick() == 2) { |
||||||
|
++largest_weight_pick_count; |
||||||
|
} |
||||||
|
} |
||||||
|
EXPECT_EQ(largest_weight_pick_count, kMaxWeight); |
||||||
|
} |
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace grpc_core
|
||||||
|
|
||||||
|
int main(int argc, char** argv) { |
||||||
|
::testing::InitGoogleTest(&argc, argv); |
||||||
|
return RUN_ALL_TESTS(); |
||||||
|
} |
Loading…
Reference in new issue