From b81d2cc8c5c0210414f80c4d4e8883037b613859 Mon Sep 17 00:00:00 2001 From: Mike Kruskal Date: Wed, 12 Jul 2023 16:02:16 -0700 Subject: [PATCH] Expand VisitDescriptor to support mutable access to the proto. PiperOrigin-RevId: 547625871 --- src/google/protobuf/descriptor_visitor.h | 74 +++++++++++++------ .../protobuf/descriptor_visitor_test.cc | 18 +++++ 2 files changed, 70 insertions(+), 22 deletions(-) diff --git a/src/google/protobuf/descriptor_visitor.h b/src/google/protobuf/descriptor_visitor.h index a8b2f04b04..267ffab703 100644 --- a/src/google/protobuf/descriptor_visitor.h +++ b/src/google/protobuf/descriptor_visitor.h @@ -46,6 +46,10 @@ template void VisitDescriptors(const FileDescriptor& file, const FileDescriptorProto& proto, Visitor visitor); +template +void VisitDescriptors(const FileDescriptor& file, FileDescriptorProto& proto, + Visitor visitor); + // Visit just the descriptors, without a corresponding proto tree. template void VisitDescriptors(const FileDescriptor& file, Visitor visitor); @@ -54,92 +58,111 @@ template struct VisitImpl { Visitor visitor; template - void Visit(const FieldDescriptor& descriptor, const Proto&... proto) { + void Visit(const FieldDescriptor& descriptor, Proto&... proto) { visitor(descriptor, proto...); } template - void Visit(const EnumValueDescriptor& descriptor, const Proto&... proto) { + void Visit(const EnumValueDescriptor& descriptor, Proto&... proto) { visitor(descriptor, proto...); } template - void Visit(const EnumDescriptor& descriptor, const Proto&... proto) { + void Visit(const EnumDescriptor& descriptor, Proto&... proto) { visitor(descriptor, proto...); for (int i = 0; i < descriptor.value_count(); i++) { - Visit(*descriptor.value(i), proto.value(i)...); + Visit(*descriptor.value(i), value(proto, i)...); } } template - void Visit(const Descriptor::ExtensionRange& descriptor, - const Proto&... proto) { + void Visit(const Descriptor::ExtensionRange& descriptor, Proto&... proto) { visitor(descriptor, proto...); } template - void Visit(const OneofDescriptor& descriptor, const Proto&... proto) { + void Visit(const OneofDescriptor& descriptor, Proto&... proto) { visitor(descriptor, proto...); } template - void Visit(const Descriptor& descriptor, const Proto&... proto) { + void Visit(const Descriptor& descriptor, Proto&... proto) { visitor(descriptor, proto...); for (int i = 0; i < descriptor.enum_type_count(); i++) { - Visit(*descriptor.enum_type(i), proto.enum_type(i)...); + Visit(*descriptor.enum_type(i), enum_type(proto, i)...); } for (int i = 0; i < descriptor.oneof_decl_count(); i++) { - Visit(*descriptor.oneof_decl(i), proto.oneof_decl(i)...); + Visit(*descriptor.oneof_decl(i), oneof_decl(proto, i)...); } for (int i = 0; i < descriptor.field_count(); i++) { - Visit(*descriptor.field(i), proto.field(i)...); + Visit(*descriptor.field(i), field(proto, i)...); } for (int i = 0; i < descriptor.nested_type_count(); i++) { - Visit(*descriptor.nested_type(i), proto.nested_type(i)...); + Visit(*descriptor.nested_type(i), nested_type(proto, i)...); } for (int i = 0; i < descriptor.extension_count(); i++) { - Visit(*descriptor.extension(i), proto.extension(i)...); + Visit(*descriptor.extension(i), extension(proto, i)...); } for (int i = 0; i < descriptor.extension_range_count(); i++) { - Visit(*descriptor.extension_range(i), proto.extension_range(i)...); + Visit(*descriptor.extension_range(i), extension_range(proto, i)...); } } template - void Visit(const MethodDescriptor& method, const Proto&... proto) { + void Visit(const MethodDescriptor& method, Proto&... proto) { visitor(method, proto...); } template - void Visit(const ServiceDescriptor& descriptor, const Proto&... proto) { + void Visit(const ServiceDescriptor& descriptor, Proto&... proto) { visitor(descriptor, proto...); for (int i = 0; i < descriptor.method_count(); i++) { - Visit(*descriptor.method(i), proto.method(i)...); + Visit(*descriptor.method(i), method(proto, i)...); } } template - void Visit(const FileDescriptor& descriptor, const Proto&... proto) { + void Visit(const FileDescriptor& descriptor, Proto&... proto) { visitor(descriptor, proto...); for (int i = 0; i < descriptor.message_type_count(); i++) { - Visit(*descriptor.message_type(i), proto.message_type(i)...); + Visit(*descriptor.message_type(i), message_type(proto, i)...); } for (int i = 0; i < descriptor.enum_type_count(); i++) { - Visit(*descriptor.enum_type(i), proto.enum_type(i)...); + Visit(*descriptor.enum_type(i), enum_type(proto, i)...); } for (int i = 0; i < descriptor.extension_count(); i++) { - Visit(*descriptor.extension(i), proto.extension(i)...); + Visit(*descriptor.extension(i), extension(proto, i)...); } for (int i = 0; i < descriptor.service_count(); i++) { - Visit(*descriptor.service(i), proto.service(i)...); + Visit(*descriptor.service(i), service(proto, i)...); } } + + private: +#define CREATE_NESTED_GETTER(TYPE, NESTED) \ + inline auto& NESTED(TYPE& desc, int i) { return *desc.mutable_##NESTED(i); } \ + inline auto& NESTED(const TYPE& desc, int i) { return desc.NESTED(i); } + + CREATE_NESTED_GETTER(DescriptorProto, enum_type); + CREATE_NESTED_GETTER(DescriptorProto, extension); + CREATE_NESTED_GETTER(DescriptorProto, extension_range); + CREATE_NESTED_GETTER(DescriptorProto, field); + CREATE_NESTED_GETTER(DescriptorProto, nested_type); + CREATE_NESTED_GETTER(DescriptorProto, oneof_decl); + CREATE_NESTED_GETTER(EnumDescriptorProto, value); + CREATE_NESTED_GETTER(FileDescriptorProto, enum_type); + CREATE_NESTED_GETTER(FileDescriptorProto, extension); + CREATE_NESTED_GETTER(FileDescriptorProto, message_type); + CREATE_NESTED_GETTER(FileDescriptorProto, service); + CREATE_NESTED_GETTER(ServiceDescriptorProto, method); + +#undef CREATE_NESTED_GETTER }; // Provide a fallback to ignore all the nodes that are not interesting to the @@ -167,6 +190,13 @@ void VisitDescriptors(const FileDescriptor& file, internal::VisitImpl{VisitorImpl(visitor)}.Visit(file, proto); } +template +void VisitDescriptors(const FileDescriptor& file, FileDescriptorProto& proto, + Visitor visitor) { + using VisitorImpl = internal::VisitorImpl; + internal::VisitImpl{VisitorImpl(visitor)}.Visit(file, proto); +} + template void VisitDescriptors(const FileDescriptor& file, Visitor visitor) { using VisitorImpl = internal::VisitorImpl; diff --git a/src/google/protobuf/descriptor_visitor_test.cc b/src/google/protobuf/descriptor_visitor_test.cc index 020c3754d0..2e7bb0165e 100644 --- a/src/google/protobuf/descriptor_visitor_test.cc +++ b/src/google/protobuf/descriptor_visitor_test.cc @@ -78,6 +78,24 @@ TEST(VisitDescriptorsTest, SingleTypeWithProto) { "protobuf_unittest.TestAllTypes.NestedMessage"})); } +TEST(VisitDescriptorsTest, SingleTypeMutableProto) { + const FileDescriptor& file = + *protobuf_unittest::TestAllTypes::GetDescriptor()->file(); + FileDescriptorProto proto; + file.CopyTo(&proto); + std::vector descriptors; + VisitDescriptors(file, proto, + [&](const Descriptor& descriptor, DescriptorProto& proto) { + descriptors.push_back(descriptor.full_name()); + EXPECT_EQ(descriptor.name(), proto.name()); + proto.set_name(""); + }); + EXPECT_THAT(descriptors, + IsSupersetOf({"protobuf_unittest.TestAllTypes", + "protobuf_unittest.TestAllTypes.NestedMessage"})); + EXPECT_EQ(proto.message_type(0).name(), ""); +} + TEST(VisitDescriptorsTest, AllTypesDeduce) { const FileDescriptor& file = *protobuf_unittest::TestAllTypes::GetDescriptor()->file();