diff --git a/src/google/protobuf/BUILD.bazel b/src/google/protobuf/BUILD.bazel index b9d3f7e32b..5655b4b9bc 100644 --- a/src/google/protobuf/BUILD.bazel +++ b/src/google/protobuf/BUILD.bazel @@ -278,6 +278,36 @@ cc_test( ], ) +cc_test( + name = "reflection_visit_fields_test", + size = "small", + srcs = ["reflection_visit_fields_test.cc"], + copts = COPTS + select({ + "//build_defs:config_msvc": [], + "//conditions:default": [ + "-Wno-error=sign-compare", + ], + }), + deps = [ + ":arena", + ":cc_test_protos", + ":lite_test_util", + ":protobuf", + ":protobuf_lite", + ":test_util", + "//src/google/protobuf/io", + "//src/google/protobuf/stubs", + "//src/google/protobuf/testing", + "//src/google/protobuf/testing:file", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "port_test", srcs = ["port_test.cc"], diff --git a/src/google/protobuf/reflection_visit_fields_test.cc b/src/google/protobuf/reflection_visit_fields_test.cc index b5a797dc51..813004962d 100644 --- a/src/google/protobuf/reflection_visit_fields_test.cc +++ b/src/google/protobuf/reflection_visit_fields_test.cc @@ -11,6 +11,7 @@ #include "absl/log/absl_check.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" +#include "google/protobuf/arena.h" #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/map_test_util.h" #include "google/protobuf/map_unittest.pb.h" @@ -28,79 +29,48 @@ namespace { #ifdef __cpp_if_constexpr -TEST(ReflectionVisitTest, VisitedFieldCountMatchesListFields) { - protobuf_unittest::TestAllTypes message; - TestUtil::SetAllFields(&message); - const Reflection* reflection = message.GetReflection(); - - uint32_t count = 0; - VisitFields(message, [&](auto info) { count++; }); - - std::vector fields; - reflection->ListFields(message, &fields); - - EXPECT_EQ(count, fields.size()); -} - -TEST(ReflectionVisitTest, VisitedFieldCountMatchesListFieldsForExtension) { - protobuf_unittest::TestAllExtensions message; - TestUtil::SetAllExtensions(&message); - const Reflection* reflection = message.GetReflection(); - - uint32_t count = 0; - VisitFields(message, [&](auto info) { count++; }); - - std::vector fields; - reflection->ListFields(message, &fields); - - EXPECT_EQ(count, fields.size()); -} - -TEST(ReflectionVisitTest, VisitedFieldCountMatchesListFieldsForMessageType) { - protobuf_unittest::TestAllTypes message; - TestUtil::SetAllFields(&message); - const Reflection* reflection = message.GetReflection(); - - uint32_t count = 0; - VisitFields(message, [&](auto info) { count++; }, FieldMask::kMessage); - - std::vector fields; - reflection->ListFields(message, &fields); - int message_count = 0; - for (auto field : fields) { - if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) ++message_count; +using ::protobuf_unittest::NestedTestAllTypes; +using ::protobuf_unittest::TestAllExtensions; +using ::protobuf_unittest::TestAllTypes; +using ::protobuf_unittest::TestMap; +using ::protobuf_unittest::TestPackedExtensions; +using ::protobuf_unittest::TestPackedTypes; +using ::proto2_wireformat_unittest::TestMessageSet; + +struct TestParam { + absl::string_view name; + Message* (*create_message)(Arena& arena); +}; + +class VisitFieldsTest : public testing::TestWithParam { + public: + VisitFieldsTest() { + message_ = GetParam().create_message(arena_); + reflection_ = message_->GetReflection(); } - EXPECT_EQ(count, message_count); -} - -TEST(ReflectionVisitTest, VisitedFieldCountMatchesListFieldsForLazy) { - protobuf_unittest::NestedTestAllTypes original, parsed; - TestUtil::SetAllFields(original.mutable_payload()); - TestUtil::SetAllFields(original.mutable_lazy_child()->mutable_payload()); - ASSERT_TRUE(parsed.ParseFromString(original.SerializeAsString())); - const Reflection* reflection = parsed.GetReflection(); + protected: + Arena arena_; + Message* message_; + const Reflection* reflection_; +}; +TEST_P(VisitFieldsTest, VisitedFieldsCountMatchesListFields) { uint32_t count = 0; - VisitFields(parsed, [&](auto info) { count++; }); + VisitFields(*message_, [&](auto info) { ++count; }); std::vector fields; - reflection->ListFields(parsed, &fields); + reflection_->ListFields(*message_, &fields); EXPECT_EQ(count, fields.size()); } -TEST(ReflectionVisitTest, - VisitedFieldCountMatchesListFieldsForExtensionMessageType) { - protobuf_unittest::TestAllExtensions message; - TestUtil::SetAllExtensions(&message); - const Reflection* reflection = message.GetReflection(); - +TEST_P(VisitFieldsTest, VisitedMessageFieldsCountMatchesListFields) { uint32_t count = 0; - VisitFields(message, [&](auto info) { count++; }, FieldMask::kMessage); + VisitFields(*message_, [&](auto info) { ++count; }, FieldMask::kMessage); std::vector fields; - reflection->ListFields(message, &fields); + reflection_->ListFields(*message_, &fields); int message_count = 0; for (auto field : fields) { if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) ++message_count; @@ -109,126 +79,19 @@ TEST(ReflectionVisitTest, EXPECT_EQ(count, message_count); } -TEST(ReflectionVisitTest, VisitedFieldCountMatchesListFieldsForMap) { - protobuf_unittest::TestMap message; - MapTestUtil::SetMapFields(&message); - MapTestUtil::ExpectMapFieldsSet(message); - const Reflection* reflection = message.GetReflection(); +TEST_P(VisitFieldsTest, ClearByVisitFieldsMustBeEmpty) { + VisitFields(*message_, [](auto info) { info.Clear(); }); - uint32_t count = 0; - VisitFields(message, [&](auto info) { count++; }); - - std::vector fields; - reflection->ListFields(message, &fields); - - EXPECT_EQ(count, fields.size()); -} - -TEST(ReflectionVisitTest, ClearByVisitIsEmpty) { - protobuf_unittest::TestAllTypes message; - TestUtil::SetAllFields(&message); - TestUtil::ExpectAllFieldsSet(message); - - VisitFields(message, [&](auto info) { info.Clear(); }); - - TestUtil::ExpectClear(message); + EXPECT_EQ(message_->ByteSizeLong(), 0); } -TEST(ReflectionVisitTest, ClearByVisitIsEmptyForExtension) { - protobuf_unittest::TestAllExtensions message; - TestUtil::SetAllExtensions(&message); - TestUtil::ExpectAllExtensionsSet(message); +TEST_P(VisitFieldsTest, ClearByVisitFieldsRevisitNone) { + VisitFields(*message_, [](auto info) { info.Clear(); }); - VisitFields(message, [&](auto info) { info.Clear(); }); - - TestUtil::ExpectExtensionsClear(message); -} - -TEST(ReflectionVisitTest, ClearByVisitHasZeroRevisitForExtension) { - protobuf_unittest::TestAllExtensions message; - TestUtil::SetAllExtensions(&message); - TestUtil::ExpectAllExtensionsSet(message); - - // Clear all fields. - VisitFields(message, [&](auto info) { info.Clear(); }); - // Visiting clear message should yields no fields. uint32_t count = 0; - VisitFields(message, [&](auto info) { ++count; }); + VisitFields(*message_, [&](auto info) { ++count; }); EXPECT_EQ(count, 0); - TestUtil::ExpectExtensionsClear(message); -} - -TEST(ReflectionVisitTest, ClearByVisitHasZeroRevisitForLazy) { - protobuf_unittest::NestedTestAllTypes original, parsed; - TestUtil::SetAllFields(original.mutable_payload()); - TestUtil::SetAllFields(original.mutable_lazy_child()->mutable_payload()); - ASSERT_TRUE(parsed.ParseFromString(original.SerializeAsString())); - - VisitFields(parsed, [&](auto info) { info.Clear(); }); - // Visiting clear message should yields no fields. - uint32_t count = 0; - VisitFields(parsed, [&](auto info) { ++count; }); - - EXPECT_EQ(count, 0); -} - -TEST(ReflectionVisitTest, ClearByVisitIsEmptyForMap) { - protobuf_unittest::TestMap message; - MapTestUtil::SetMapFields(&message); - MapTestUtil::ExpectMapFieldsSet(message); - - VisitFields(message, [&](auto info) { info.Clear(); }); - - MapTestUtil::ExpectClear(message); -} - -template -void MutateMapValue(protobuf_unittest::TestMap& message, absl::string_view name, - int index, T&& callback) { - const Reflection* reflection = message.GetReflection(); - const Descriptor* descriptor = message.GetDescriptor(); - const FieldDescriptor* field = descriptor->FindFieldByName(name); - - auto* map_entry = reflection->MutableRepeatedMessage(&message, field, index); - const FieldDescriptor* val_field = map_entry->GetDescriptor()->map_value(); - ABSL_CHECK_NE(val_field, nullptr); - callback(map_entry->GetReflection(), map_entry, val_field); -} - -TEST(ReflectionVisitTest, VisitMapAfterMutableRepeated) { - protobuf_unittest::TestMap message; - auto& map = *message.mutable_map_int32_int32(); - map[0] = 0; - map[1] = 0; - - // Reflectively overwrites values to 200 for all entries. This forces - // conversion to a mutable repeated field. - auto set_int32_val = [&](const Reflection* reflection, Message* msg, - const FieldDescriptor* field) { - reflection->SetInt32(msg, field, 200); - }; - MutateMapValue(message, "map_int32_int32", 0, set_int32_val); - MutateMapValue(message, "map_int32_int32", 1, set_int32_val); - - // Later visit fields must be map fields synced with the change. - std::vector> key_val_pairs; - VisitFields(message, [&](auto info) { - if constexpr (info.is_map) { - ASSERT_EQ(info.key_type(), FieldDescriptor::TYPE_INT32); - ASSERT_EQ(info.value_type(), FieldDescriptor::TYPE_INT32); - - info.VisitElements([&](auto key, auto val) { - if constexpr (key.cpp_type == FieldDescriptor::CPPTYPE_INT32 && - val.cpp_type == FieldDescriptor::CPPTYPE_INT32) { - key_val_pairs.emplace_back(key.Get(), val.Get()); - } - }); - } - }); - - EXPECT_THAT(key_val_pairs, testing::UnorderedElementsAre( - testing::Pair(0, 200), testing::Pair(1, 200))); } void MutateNothingByVisit(Message& message) { @@ -288,22 +151,16 @@ void MutateNothingByVisit(Message& message) { }); } -TEST(ReflectionVisitTest, ReadAndWriteBackIdempotent) { - protobuf_unittest::TestAllTypes message; - TestUtil::SetAllFields(&message); - - MutateNothingByVisit(message); - - TestUtil::ExpectAllFieldsSet(message); -} - -TEST(ReflectionVisitTest, ReadAndWriteBackIdempotentForExtension) { - protobuf_unittest::TestAllExtensions message; - TestUtil::SetAllExtensions(&message); +TEST_P(VisitFieldsTest, MutateNothingByVisitIdempotent) { + std::string data; + ASSERT_TRUE(message_->SerializeToString(&data)); - MutateNothingByVisit(message); + MutateNothingByVisit(*message_); - TestUtil::ExpectAllExtensionsSet(message); + // Checking the identity by comparing serialize bytes is discouraged, but this + // allows us to be type-agnositc for this test. Also, the back to back + // serialization should be stable. + EXPECT_EQ(data, message_->SerializeAsString()); } template @@ -406,101 +263,145 @@ size_t ByteSizeLongByVisit(const Message& message) { return byte_size; } -TEST(ReflectionVisitTest, ByteSizeByVisitMatchesCodegen) { - protobuf_unittest::TestAllTypes message; - TestUtil::SetAllFields(&message); - TestUtil::ExpectAllFieldsSet(message); - - EXPECT_EQ(ByteSizeLongByVisit(message), message.ByteSizeLong()); +TEST_P(VisitFieldsTest, ByteSizeByVisitFieldsMatchesCodegen) { + EXPECT_EQ(ByteSizeLongByVisit(*message_), message_->ByteSizeLong()); } -TEST(ReflectionVisitTest, ByteSizeByVisitMatchesCodegenForPacked) { - protobuf_unittest::TestPackedTypes message; - TestUtil::SetPackedFields(&message); - TestUtil::ExpectPackedFieldsSet(message); - - EXPECT_EQ(ByteSizeLongByVisit(message), message.ByteSizeLong()); -} - -TEST(ReflectionVisitTest, ByteSizeByVisitMatchesCodegenForExtension) { - protobuf_unittest::TestAllExtensions message; - TestUtil::SetAllExtensions(&message); - TestUtil::ExpectAllExtensionsSet(message); - - EXPECT_EQ(ByteSizeLongByVisit(message), message.ByteSizeLong()); -} - -TEST(ReflectionVisitTest, ByteSizeByVisitMatchesCodegenForPackedExtensions) { - protobuf_unittest::TestPackedExtensions message; - TestUtil::SetPackedExtensions(&message); - TestUtil::ExpectPackedExtensionsSet(message); - - EXPECT_EQ(ByteSizeLongByVisit(message), message.ByteSizeLong()); -} - -TEST(ReflectionVisitTest, ByteSizeByVisitMatchesCodegenForLazyExtension) { - protobuf_unittest::TestAllExtensions original, parsed; - TestUtil::SetAllExtensions(&original); - TestUtil::ExpectAllExtensionsSet(original); - std::string data; - ASSERT_TRUE(original.SerializeToString(&data)); - ASSERT_TRUE(parsed.ParseFromString(data)); - - EXPECT_EQ(ByteSizeLongByVisit(parsed), parsed.ByteSizeLong()); -} - -TEST(ReflectionVisitTest, ByteSizeByVisitMatchesCodegenForMessageSet) { - proto2_wireformat_unittest::TestMessageSet message; - auto* ext1 = message.MutableExtension( - unittest::TestMessageSetExtension1::message_set_extension); - ext1->set_i(-1); - ext1->mutable_recursive() - ->MutableExtension( - unittest::TestMessageSetExtension3::message_set_extension) - ->mutable_msg() - ->set_b(0); - - EXPECT_EQ(ByteSizeLongByVisit(message), message.ByteSizeLong()); -} - -TEST(ReflectionVisitTest, ByteSizeByVisitMatchesCodegenForLazyMessageSet) { - proto2_wireformat_unittest::TestMessageSet original, parsed; - auto* ext1 = original.MutableExtension( +TestMessageSet* CreateTestMessageSet(Arena& arena) { + auto* msg = Arena::Create(&arena); + auto* ext1 = msg->MutableExtension( unittest::TestMessageSetExtension1::message_set_extension); ext1->set_i(-1); - auto* ext3 = ext1->mutable_recursive()->MutableExtension( unittest::TestMessageSetExtension3::message_set_extension); - ext3->mutable_msg()->set_b(0); ext3->set_required_int(-1); + ext3->mutable_msg()->set_b(0); + return msg; +} - std::string data; - ASSERT_TRUE(original.SerializeToString(&data)); - ASSERT_TRUE(parsed.ParseFromString(data)); - - EXPECT_EQ(ByteSizeLongByVisit(parsed), parsed.ByteSizeLong()); +NestedTestAllTypes* CreateNestedTestAllTypes(Arena& arena) { + auto* msg = Arena::Create(&arena); + TestUtil::SetAllFields(msg->mutable_payload()); + TestUtil::SetAllFields(msg->mutable_lazy_child()->mutable_payload()); + return msg; } -TEST(ReflectionVisitTest, ByteSizeByVisitMatchesCodegenForLazy) { - protobuf_unittest::NestedTestAllTypes original, parsed; - TestUtil::SetAllFields(original.mutable_payload()); - TestUtil::SetAllFields(original.mutable_lazy_child()->mutable_payload()); - std::string data; - ASSERT_TRUE(original.SerializeToString(&data)); - ASSERT_TRUE(parsed.ParseFromString(data)); +INSTANTIATE_TEST_SUITE_P( + ReflectionVisitFieldsTest, VisitFieldsTest, + testing::Values( + TestParam{"TestAllTypes", + [](Arena& arena) -> Message* { + auto* msg = Arena::Create(&arena); + TestUtil::SetAllFields(msg); + return msg; + }}, + TestParam{"TestAllExtensions", + [](Arena& arena) -> Message* { + auto* msg = Arena::Create(&arena); + TestUtil::SetAllExtensions(msg); + return msg; + }}, + TestParam{"TestAllExtensionsLazy", + [](Arena& arena) -> Message* { + TestAllExtensions original; + TestUtil::SetAllExtensions(&original); + auto* parsed = Arena::Create(&arena); + ABSL_CHECK( + parsed->ParseFromString(original.SerializeAsString())); + return parsed; + }}, + TestParam{"TestMap", + [](Arena& arena) -> Message* { + auto* msg = Arena::Create(&arena); + MapTestUtil::SetMapFields(msg); + return msg; + }}, + TestParam{"TestMessageSet", + [](Arena& arena) -> Message* { + return CreateTestMessageSet(arena); + }}, + TestParam{"TestMessageSetLazy", + [](Arena& arena) -> Message* { + auto* original = CreateTestMessageSet(arena); + auto* parsed = Arena::Create(&arena); + ABSL_CHECK( + parsed->ParseFromString(original->SerializeAsString())); + return parsed; + }}, + TestParam{"TestPacked", + [](Arena& arena) -> Message* { + auto* msg = Arena::Create(&arena); + TestUtil::SetPackedFields(msg); + return msg; + }}, + TestParam{"TestPackedExtensions", + [](Arena& arena) -> Message* { + auto* msg = Arena::Create(&arena); + TestUtil::SetPackedExtensions(msg); + return msg; + }}, + TestParam{"NestedTestAllTypes", + [](Arena& arena) -> Message* { + return CreateNestedTestAllTypes(arena); + }}, + TestParam{"NestedTestAllTypesLazy", + [](Arena& arena) -> Message* { + auto* original = CreateNestedTestAllTypes(arena); + auto* parsed = Arena::Create(&arena); + ABSL_CHECK( + parsed->ParseFromString(original->SerializeAsString())); + return parsed; + }}), + [](const testing::TestParamInfo& info) { + return std::string(info.param.name); + }); - size_t byte_size_visit = ByteSizeLongByVisit(parsed); +template +void MutateMapValue(TestMap& message, absl::string_view name, int index, + T&& callback) { + const Reflection* reflection = message.GetReflection(); + const Descriptor* descriptor = message.GetDescriptor(); + const FieldDescriptor* field = descriptor->FindFieldByName(name); - EXPECT_EQ(byte_size_visit, parsed.ByteSizeLong()); - EXPECT_EQ(byte_size_visit, data.size()); + auto* map_entry = reflection->MutableRepeatedMessage(&message, field, index); + const FieldDescriptor* val_field = map_entry->GetDescriptor()->map_value(); + ABSL_CHECK_NE(val_field, nullptr); + callback(map_entry->GetReflection(), map_entry, val_field); } -TEST(ReflectionVisitTest, ByteSizeByVisitMatchesCodegenForMap) { - protobuf_unittest::TestMap message; - MapTestUtil::SetMapFields(&message); - MapTestUtil::ExpectMapFieldsSet(message); +TEST(ReflectionVisitTest, VisitMapAfterMutableRepeated) { + TestMap message; + auto& map = *message.mutable_map_int32_int32(); + map[0] = 0; + map[1] = 0; + + // Reflectively overwrites values to 200 for all entries. This forces + // conversion to a mutable repeated field. + auto set_int32_val = [&](const Reflection* reflection, Message* msg, + const FieldDescriptor* field) { + reflection->SetInt32(msg, field, 200); + }; + MutateMapValue(message, "map_int32_int32", 0, set_int32_val); + MutateMapValue(message, "map_int32_int32", 1, set_int32_val); + + // Later visit fields must be map fields synced with the change. + std::vector> key_val_pairs; + VisitFields(message, [&](auto info) { + if constexpr (info.is_map) { + ASSERT_EQ(info.key_type(), FieldDescriptor::TYPE_INT32); + ASSERT_EQ(info.value_type(), FieldDescriptor::TYPE_INT32); - EXPECT_EQ(ByteSizeLongByVisit(message), message.ByteSizeLong()); + info.VisitElements([&](auto key, auto val) { + if constexpr (key.cpp_type == FieldDescriptor::CPPTYPE_INT32 && + val.cpp_type == FieldDescriptor::CPPTYPE_INT32) { + key_val_pairs.emplace_back(key.Get(), val.Get()); + } + }); + } + }); + + EXPECT_THAT(key_val_pairs, testing::UnorderedElementsAre( + testing::Pair(0, 200), testing::Pair(1, 200))); } #endif // __cpp_if_constexpr