diff --git a/src/google/protobuf/reflection_visit_fields.h b/src/google/protobuf/reflection_visit_fields.h index e3178c5fbe..8f1cae36c7 100644 --- a/src/google/protobuf/reflection_visit_fields.h +++ b/src/google/protobuf/reflection_visit_fields.h @@ -60,6 +60,12 @@ class ReflectionVisit final { template static void VisitFields(MessageT& message, CallbackFn&& func, FieldMask mask); + template + static void VisitMessageFields(const Message& message, CallbackFn&& func); + + template + static void VisitMessageFields(Message& message, CallbackFn&& func); + private: static const internal::ReflectionSchema& GetSchema( const Reflection* reflection) { @@ -396,10 +402,95 @@ void ReflectionVisit::VisitFields(MessageT& message, CallbackFn&& func, }); } +template +void ReflectionVisit::VisitMessageFields(const Message& message, + CallbackFn&& func) { + ReflectionVisit::VisitFields( + message, + [&](auto info) { + if constexpr (info.is_map) { + auto value_type = info.value_type(); + if (value_type != FieldDescriptor::TYPE_MESSAGE && + value_type != FieldDescriptor::TYPE_GROUP) { + return; + } + info.VisitElements([&](auto key, auto val) { + if constexpr (val.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) { + func(val.Get()); + } + }); + } else if constexpr (info.cpp_type == + FieldDescriptor::CPPTYPE_MESSAGE) { + if constexpr (info.is_repeated) { + for (const auto& it : info.Get()) { + func(DownCast(it)); + } + } else { + func(info.Get()); + } + } + }, + FieldMask::kMessage); +} + +template +void ReflectionVisit::VisitMessageFields(Message& message, CallbackFn&& func) { + ReflectionVisit::VisitFields( + message, + [&](auto info) { + if constexpr (info.is_map) { + auto value_type = info.value_type(); + if (value_type != FieldDescriptor::TYPE_MESSAGE && + value_type != FieldDescriptor::TYPE_GROUP) { + return; + } + info.VisitElements([&](auto key, auto val) { + if constexpr (val.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) { + func(*val.Mutable()); + } + }); + } else if constexpr (info.cpp_type == + FieldDescriptor::CPPTYPE_MESSAGE) { + if constexpr (info.is_repeated) { + for (auto& it : info.Mutable()) { + func(DownCast(it)); + } + } else { + func(info.Mutable()); + } + } + }, + FieldMask::kMessage); +} + +// Visits present fields of "message" and calls the callback function "func". +// Skips fields whose ctypes are missing in "mask". template void VisitFields(MessageT& message, CallbackFn&& func, FieldMask mask) { - internal::ReflectionVisit::VisitFields(message, - std::forward(func), mask); + ReflectionVisit::VisitFields(message, std::forward(func), mask); +} + +// Visits message fields of "message" and calls "func". Expects "func" to +// accept const Message&. Note the following divergence from VisitFields. +// +// --Each of N elements of a repeated message field is visited (total N). +// --Each of M elements of a map field whose value type is message are visited +// (total M). +// --A map field whose value type is not message is ignored. +// +// This is a helper API built on top of VisitFields to hide specifics about +// extensions, repeated fields, etc. +template +void VisitMessageFields(const Message& message, CallbackFn&& func) { + ReflectionVisit::VisitMessageFields(message, std::forward(func)); +} + +// Same as VisitMessageFields above but expects "func" to accept Message&. This +// is useful when mutable access is required. As mutable access can be +// expensive, use it only if it's necessary. +template +void VisitMutableMessageFields(Message& message, CallbackFn&& func) { + ReflectionVisit::VisitMessageFields(message, std::forward(func)); } #endif // __cpp_if_constexpr diff --git a/src/google/protobuf/reflection_visit_fields_test.cc b/src/google/protobuf/reflection_visit_fields_test.cc index 130bdf7d8b..eb69d9bb0a 100644 --- a/src/google/protobuf/reflection_visit_fields_test.cc +++ b/src/google/protobuf/reflection_visit_fields_test.cc @@ -80,6 +80,47 @@ TEST_P(VisitFieldsTest, VisitedMessageFieldsCountMatchesListFields) { EXPECT_EQ(count, message_count); } +// Counts present message fields using ListFields() where: +// --N elements in a repeated message field are counted N times +// --M message values in a map field are counted M times. +// --A map field whose value type is not message is ignored. +int CountAllMessageFieldsViaListFields(const Reflection* reflection, + const Message& message) { + std::vector fields; + reflection->ListFields(message, &fields); + + int message_count = 0; + for (auto field : fields) { + if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue; + if (field->is_map()) { + if (field->message_type()->map_value()->cpp_type() != + FieldDescriptor::CPPTYPE_MESSAGE) + continue; + } + if (field->is_repeated()) { + message_count += reflection->FieldSize(message, field); + } else { + ++message_count; + } + } + return message_count; +} + +TEST_P(VisitFieldsTest, VisitMessageFieldsCountIncludesRepeatedElements) { + int count = 0; + VisitMessageFields(*message_, [&](const Message& msg) { ++count; }); + + EXPECT_EQ(count, CountAllMessageFieldsViaListFields(reflection_, *message_)); +} + +TEST_P(VisitFieldsTest, + VisitMutableMessageFieldsCountIncludesRepeatedElements) { + int count = 0; + VisitMutableMessageFields(*message_, [&](Message& msg) { ++count; }); + + EXPECT_EQ(count, CountAllMessageFieldsViaListFields(reflection_, *message_)); +} + TEST_P(VisitFieldsTest, ClearByVisitFieldsMustBeEmpty) { VisitFields(*message_, [](auto info) { info.Clear(); });