From 594a71f48515715d37a797252e0dd2d97e635826 Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Thu, 20 Jun 2024 12:38:21 -0700 Subject: [PATCH] Add an optimization that prevents an exponential number of comparisons on deeply nested repeated fields when using AS_SET or AS_SMART_SET. PiperOrigin-RevId: 645123131 --- .../protobuf/util/message_differencer.cc | 23 +++++ .../util/message_differencer_unittest.cc | 85 +++++++++++++++++++ .../util/message_differencer_unittest.proto | 1 + 3 files changed, 109 insertions(+) diff --git a/src/google/protobuf/util/message_differencer.cc b/src/google/protobuf/util/message_differencer.cc index 2f11e5f889..258c0e9bec 100644 --- a/src/google/protobuf/util/message_differencer.cc +++ b/src/google/protobuf/util/message_differencer.cc @@ -1913,6 +1913,29 @@ bool MessageDifferencer::MatchRepeatedFieldIndices( match_list1->assign(count1, -1); match_list2->assign(count2, -1); + + // In the special case where both repeated fields have exactly one element, + // return without calling the comparator. This optimization prevents the + // pathological case of deeply nested repeated fields of size 1 from taking + // exponential-time to compare. + // + // In the case where reporter_ is set, we need to do the compare here to + // properly distinguish a modify from an add+delete. The code below will not + // pass the reporter along in recursive calls to nested repeated fields, so + // the inner call will have the opportunity to perform this optimization and + // avoid exponential-time behavior. + // + // In the case where key_comparator is set, we need to do the compare here to + // fulfill the interface contract that keys will be compared even if the user + // asked to ignore that field. The code will only compare the key fields + // which (hopefully) do not contain further repeated fields. 
+ if (count1 == 1 && count2 == 1 && reporter_ == nullptr && + key_comparator == nullptr) { + match_list1->at(0) = 0; + match_list2->at(0) = 0; + return true; + } + // Ensure that we don't report differences during the matching process. Since // field comparators could potentially use this message differencer object to // perform further comparisons, turn off reporting here and re-enable it diff --git a/src/google/protobuf/util/message_differencer_unittest.cc b/src/google/protobuf/util/message_differencer_unittest.cc index a77f819518..8807d1dfe0 100644 --- a/src/google/protobuf/util/message_differencer_unittest.cc +++ b/src/google/protobuf/util/message_differencer_unittest.cc @@ -1912,6 +1912,91 @@ TEST(MessageDifferencerTest, RepeatedFieldSetTest_Combination) { EXPECT_TRUE(differencer2.Compare(msg1, msg2)); } +// This class is a comparator that uses the default comparator, but counts how +// many times it was called. +class CountingComparator : public util::SimpleFieldComparator { + public: + ComparisonResult Compare(const Message& message_1, const Message& message_2, + const FieldDescriptor* field, int index_1, + int index_2, + const util::FieldContext* field_context) override { + ++compare_count_; + return SimpleCompare(message_1, message_2, field, index_1, index_2, + field_context); + } + + int compare_count() const { return compare_count_; } + + private: + int compare_count_ = 0; +}; + +TEST(MessageDifferencerTest, RepeatedFieldSet_RecursivePerformance) { + constexpr int kDepth = 20; + + protobuf_unittest::TestField left; + protobuf_unittest::TestField* p = &left; + for (int i = 0; i < kDepth; ++i) { + p = p->add_rm(); + } + + protobuf_unittest::TestField right = left; + util::MessageDifferencer differencer; + differencer.set_repeated_field_comparison( + util::MessageDifferencer::RepeatedFieldComparison::AS_SET); + CountingComparator comparator; + differencer.set_field_comparator(&comparator); + std::string report; 
differencer.ReportDifferencesToString(&report); + EXPECT_TRUE(differencer.Compare(left, right)); + + EXPECT_LE(comparator.compare_count(), kDepth * kDepth); +} + +TEST(MessageDifferencerTest, RepeatedFieldSmartSet_RecursivePerformance) { + constexpr int kDepth = 20; + + protobuf_unittest::TestField left; + protobuf_unittest::TestField* p = &left; + for (int i = 0; i < kDepth; ++i) { + p = p->add_rm(); + } + + protobuf_unittest::TestField right = left; + util::MessageDifferencer differencer; + differencer.set_repeated_field_comparison( + util::MessageDifferencer::RepeatedFieldComparison::AS_SMART_SET); + CountingComparator comparator; + differencer.set_field_comparator(&comparator); + std::string report; + differencer.ReportDifferencesToString(&report); + EXPECT_TRUE(differencer.Compare(left, right)); + + EXPECT_LE(comparator.compare_count(), kDepth * kDepth); +} + +TEST(MessageDifferencerTest, RepeatedFieldSmartList_RecursivePerformance) { + constexpr int kDepth = 20; + + protobuf_unittest::TestField left; + protobuf_unittest::TestField* p = &left; + for (int i = 0; i < kDepth; ++i) { + p = p->add_rm(); + } + + protobuf_unittest::TestField right = left; + util::MessageDifferencer differencer; + differencer.set_repeated_field_comparison( + util::MessageDifferencer::RepeatedFieldComparison::AS_SMART_LIST); + CountingComparator comparator; + differencer.set_field_comparator(&comparator); + std::string report; + differencer.ReportDifferencesToString(&report); + EXPECT_TRUE(differencer.Compare(left, right)); + + EXPECT_LE(comparator.compare_count(), kDepth * kDepth); +} + TEST(MessageDifferencerTest, RepeatedFieldMapTest_Partial) { protobuf_unittest::TestDiffMessage msg1; // message msg1 { diff --git a/src/google/protobuf/util/message_differencer_unittest.proto b/src/google/protobuf/util/message_differencer_unittest.proto index 032f4b9b7f..702b3e3549 100644 --- a/src/google/protobuf/util/message_differencer_unittest.proto +++ b/src/google/protobuf/util/message_differencer_unittest.proto @@ -26,6 +26,7 @@ 
message TestField { optional int32 c = 1; repeated int32 rc = 2; optional TestField m = 5; + repeated TestField rm = 6; extend TestDiffMessage { optional TestField tf = 100;