From 80b060966636eecd8e7e2c4432f9f76a1df55cc2 Mon Sep 17 00:00:00 2001 From: Matt Kulukundis Date: Mon, 1 May 2023 18:47:51 -0700 Subject: [PATCH] Refactor map_field.cc to use Emit(). PiperOrigin-RevId: 528634293 --- .../cpp/field_generators/map_field.cc | 471 ++++++++---------- 1 file changed, 212 insertions(+), 259 deletions(-) diff --git a/src/google/protobuf/compiler/cpp/field_generators/map_field.cc b/src/google/protobuf/compiler/cpp/field_generators/map_field.cc index 97bdb68b0c..3d3456f2e6 100644 --- a/src/google/protobuf/compiler/cpp/field_generators/map_field.cc +++ b/src/google/protobuf/compiler/cpp/field_generators/map_field.cc @@ -30,317 +30,270 @@ #include #include -#include +#include -#include "absl/container/flat_hash_map.h" -#include "absl/log/absl_check.h" #include "absl/strings/ascii.h" -#include "absl/strings/str_cat.h" -#include "google/protobuf/compiler/cpp/field_generators/generators.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "google/protobuf/compiler/cpp/field.h" #include "google/protobuf/compiler/cpp/helpers.h" -#include "google/protobuf/wire_format.h" +#include "google/protobuf/compiler/cpp/options.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/printer.h" namespace google { namespace protobuf { namespace compiler { namespace cpp { namespace { -void SetMessageVariables( - const FieldDescriptor* descriptor, - absl::flat_hash_map* variables, - const Options& options) { - (*variables)["type"] = ClassName(descriptor->message_type(), false); - (*variables)["full_name"] = descriptor->full_name(); - - const FieldDescriptor* key = descriptor->message_type()->map_key(); - const FieldDescriptor* val = descriptor->message_type()->map_value(); - (*variables)["key_cpp"] = PrimitiveTypeName(options, key->cpp_type()); +using Sub = ::google::protobuf::io::Printer::Sub; + +std::vector Vars(const FieldDescriptor* field, const Options& opts, + bool lite) { + const auto* key = field->message_type()->map_key(); + const auto* val = field->message_type()->map_value(); + + std::string key_type = PrimitiveTypeName(opts, key->cpp_type()); + std::string val_type; switch (val->cpp_type()) { case FieldDescriptor::CPPTYPE_MESSAGE: - (*variables)["val_cpp"] = FieldMessageTypeName(val, options); + val_type = FieldMessageTypeName(val, opts); break; case FieldDescriptor::CPPTYPE_ENUM: - (*variables)["val_cpp"] = ClassName(val->enum_type(), true); + val_type = ClassName(val->enum_type(), true); break; default: - (*variables)["val_cpp"] = PrimitiveTypeName(options, val->cpp_type()); - } - (*variables)["key_wire_type"] = absl::StrCat( - "TYPE_", absl::AsciiStrToUpper(DeclaredTypeMethodName(key->type()))); - (*variables)["val_wire_type"] = absl::StrCat( - "TYPE_", absl::AsciiStrToUpper(DeclaredTypeMethodName(val->type()))); - (*variables)["map_classname"] = ClassName(descriptor->message_type(), false); - (*variables)["number"] = absl::StrCat(descriptor->number()); - (*variables)["tag"] = absl::StrCat(internal::WireFormat::MakeTag(descriptor)); - - if (HasDescriptorMethods(descriptor->file(), options)) { - (*variables)["lite"] = ""; - } else { - (*variables)["lite"] = "Lite"; + val_type = PrimitiveTypeName(opts, val->cpp_type()); + break; } + + return { + {"Map", absl::Substitute("::PROTOBUF_NAMESPACE_ID::Map<$0, $1>", key_type, + val_type)}, + {"Entry", ClassName(field->message_type(), false)}, + {"Key", PrimitiveTypeName(opts, key->cpp_type())}, + {"Val", val_type}, + {"MapField", lite ? "MapFieldLite" : "MapField"}, + }; } -class MapFieldGenerator : public FieldGeneratorBase { +class Map : public FieldGeneratorBase { public: - MapFieldGenerator(const FieldDescriptor* descriptor, const Options& options, - MessageSCCAnalyzer* scc_analyzer); - ~MapFieldGenerator() override = default; - - // implements FieldGeneratorBase --------------------------------------- - void GeneratePrivateMembers(io::Printer* printer) const override; - void GenerateAccessorDeclarations(io::Printer* printer) const override; - void GenerateInlineAccessorDefinitions(io::Printer* printer) const override; - void GenerateClearingCode(io::Printer* printer) const override; - void GenerateMergingCode(io::Printer* printer) const override; - void GenerateSwappingCode(io::Printer* printer) const override; - void GenerateConstructorCode(io::Printer* printer) const override {} - void GenerateCopyConstructorCode(io::Printer* printer) const override; - void GenerateSerializeWithCachedSizesToArray( - io::Printer* printer) const override; - void GenerateByteSize(io::Printer* printer) const override; - void GenerateIsInitialized(io::Printer* printer) const override; - void GenerateConstexprAggregateInitializer( - io::Printer* printer) const override; - void GenerateCopyAggregateInitializer(io::Printer* printer) const override; - void GenerateAggregateInitializer(io::Printer* printer) const override; - void GenerateDestructorCode(io::Printer* printer) const override; - - private: - bool has_required_fields_; -}; + Map(const FieldDescriptor* field, const Options& opts, + MessageSCCAnalyzer* scc) + : FieldGeneratorBase(field, opts), + field_(field), + key_(field->message_type()->map_key()), + val_(field->message_type()->map_value()), + opts_(&opts), + has_required_(scc->HasRequiredFields(field->message_type())), + lite_(!HasDescriptorMethods(field->file(), opts)) {} + ~Map() override = default; -MapFieldGenerator::MapFieldGenerator(const FieldDescriptor* descriptor, - const Options& options, - MessageSCCAnalyzer* scc_analyzer) - : FieldGeneratorBase(descriptor, options), - has_required_fields_( - scc_analyzer->HasRequiredFields(descriptor->message_type())) { - SetMessageVariables(descriptor, &variables_, options); -} + std::vector MakeVars() const override { + return Vars(field_, *opts_, lite_); + } -void MapFieldGenerator::GeneratePrivateMembers(io::Printer* printer) const { - Formatter format(printer, variables_); - format( - "::$proto_ns$::internal::MapField$lite$<\n" - " $map_classname$,\n" - " $key_cpp$, $val_cpp$,\n" - " ::$proto_ns$::internal::WireFormatLite::$key_wire_type$,\n" - " ::$proto_ns$::internal::WireFormatLite::$val_wire_type$> " - "$name$_;\n"); -} + void GenerateClearingCode(io::Printer* p) const override { + p->Emit(R"cc( + $field_$.Clear(); + )cc"); + } -void MapFieldGenerator::GenerateAccessorDeclarations( - io::Printer* printer) const { - Formatter format(printer, variables_); - format( - "private:\n" - "const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n" - " ${1$_internal_$name$$}$() const;\n" - "::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n" - " ${1$_internal_mutable_$name$$}$();\n" - "public:\n" - "$deprecated_attr$const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n" - " ${1$$name$$}$() const;\n", - descriptor_); - format( - "$deprecated_attr$::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n" - " ${1$mutable_$name$$}$();\n", - std::make_tuple(descriptor_, GeneratedCodeInfo::Annotation::ALIAS)); -} + void GenerateMergingCode(io::Printer* p) const override { + p->Emit(R"cc( + _this->$field_$.MergeFrom(from.$field_$); + )cc"); + } -void MapFieldGenerator::GenerateInlineAccessorDefinitions( - io::Printer* printer) const { - Formatter format(printer, variables_); - format( - "inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n" - "$classname$::_internal_$name$() const {\n" - " return $field$.GetMap();\n" - "}\n" - "inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n" - "$classname$::$name$() const {\n" - "$annotate_get$" - " // @@protoc_insertion_point(field_map:$full_name$)\n" - " return _internal_$name$();\n" - "}\n" - "inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n" - "$classname$::_internal_mutable_$name$() {\n" - "$PrepareSplitMessageForWrite$" - " return $field$.MutableMap();\n" - "}\n" - "inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n" - "$classname$::mutable_$name$() {\n" - "$annotate_mutable$" - " // @@protoc_insertion_point(field_mutable_map:$full_name$)\n" - " return _internal_mutable_$name$();\n" - "}\n"); -} + void GenerateSwappingCode(io::Printer* p) const override { + p->Emit(R"cc( + $field_$.InternalSwap(&other->$field_$); + )cc"); + } -void MapFieldGenerator::GenerateClearingCode(io::Printer* printer) const { - Formatter format(printer, variables_); - format("$field$.Clear();\n"); -} + void GenerateCopyConstructorCode(io::Printer* p) const override { + GenerateConstructorCode(p); + GenerateMergingCode(p); + } -void MapFieldGenerator::GenerateMergingCode(io::Printer* printer) const { - Formatter format(printer, variables_); - format("_this->$field$.MergeFrom(from.$field$);\n"); -} + void GenerateIsInitialized(io::Printer* p) const override { + if (!has_required_) return; -void MapFieldGenerator::GenerateSwappingCode(io::Printer* printer) const { - Formatter format(printer, variables_); - format("$field$.InternalSwap(&other->$field$);\n"); -} + p->Emit(R"cc( + if (!$pbi$::AllAreInitialized($field_$)) { + return false; + } + )cc"); + } -void MapFieldGenerator::GenerateCopyConstructorCode( - io::Printer* printer) const { - GenerateConstructorCode(printer); - GenerateMergingCode(printer); -} + void GenerateConstexprAggregateInitializer(io::Printer* p) const override { + p->Emit(R"cc(/* decltype($field_$) */ {})cc"); + } -static void GenerateSerializationLoop(Formatter& format, bool string_key, - bool string_value, - bool is_deterministic) { - if (is_deterministic) { - format( - "for (const auto& entry : " - "::_pbi::MapSorter$1$(map_field)) {\n", - (string_key ? "Ptr" : "Flat")); - } else { - format("for (const auto& entry : map_field) {\n"); + void GenerateCopyAggregateInitializer(io::Printer* p) const override { + // MapField has no move constructor, which prevents explicit aggregate + // initialization pre-C++17. + p->Emit(R"cc(/* decltype($field_$) */ {})cc"); } - { - auto loop_scope = format.ScopedIndent(); - format( - "target = WireHelper::InternalSerialize($number$, " - "entry.first, entry.second, target, stream);\n"); - - if (string_key || string_value) { - format("check_utf8(entry);\n"); + + void GenerateAggregateInitializer(io::Printer* p) const override { + if (ShouldSplit(field_, *opts_)) { + p->Emit(R"cc( + /* decltype($Msg$::Split::$name$_) */ { + $pbi$::ArenaInitialized(), arena + } + )cc"); + return; } + + p->Emit(R"cc( + /* decltype($field_$) */ { $pbi$::ArenaInitialized(), arena } + )cc"); } - format("}\n"); -} -void MapFieldGenerator::GenerateSerializeWithCachedSizesToArray( - io::Printer* printer) const { - Formatter format(printer, variables_); - format("if (!this->_internal_$name$().empty()) {\n"); - format.Indent(); - const FieldDescriptor* key_field = descriptor_->message_type()->map_key(); - const FieldDescriptor* value_field = descriptor_->message_type()->map_value(); - const bool string_key = key_field->type() == FieldDescriptor::TYPE_STRING; - const bool string_value = value_field->type() == FieldDescriptor::TYPE_STRING; - - format( - "using MapType = ::_pb::Map<$key_cpp$, $val_cpp$>;\n" - "using WireHelper = $map_classname$::Funcs;\n" - "const auto& map_field = this->_internal_$name$();\n"); - bool utf8_check = string_key || string_value; - if (utf8_check) { - format("auto check_utf8 = [](const MapType::value_type& entry) {\n"); - { - auto check_scope = format.ScopedIndent(); - // p may be unused when GetUtf8CheckMode evaluates to kNone, - // thus disabling the validation. - format("(void)entry;\n"); - if (string_key) { - GenerateUtf8CheckCodeForString( - key_field, options_, false, - "entry.first.data(), static_cast(entry.first.length()),\n", - format); - } - if (string_value) { - GenerateUtf8CheckCodeForString( - value_field, options_, false, - "entry.second.data(), static_cast(entry.second.length()),\n", - format); - } + void GenerateConstructorCode(io::Printer* p) const override {} + + void GenerateDestructorCode(io::Printer* p) const override { + if (ShouldSplit(field_, *opts_)) { + p->Emit(R"cc( + $cached_split_ptr$->$name$_.~$MapField$(); + )cc"); + return; } - format("};\n"); - } - format( - "\n" - "if (stream->IsSerializationDeterministic() && " - "map_field.size() > 1) {\n"); - { - auto deterministic_scope = format.ScopedIndent(); - GenerateSerializationLoop(format, string_key, string_value, true); - } - format("} else {\n"); - { - auto map_order_scope = format.ScopedIndent(); - GenerateSerializationLoop(format, string_key, string_value, false); + p->Emit(R"cc( + $field_$.~$MapField$(); + )cc"); } - format("}\n"); - format.Outdent(); - format("}\n"); -} -void MapFieldGenerator::GenerateByteSize(io::Printer* printer) const { - Formatter format(printer, variables_); - format( - "total_size += $tag_size$ *\n" - " " - "::$proto_ns$::internal::FromIntSize(this->_internal_$name$_size());\n" - "for (::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::const_iterator\n" - " it = this->_internal_$name$().begin();\n" - " it != this->_internal_$name$().end(); ++it) {\n" - " total_size += $map_classname$::Funcs::ByteSizeLong(it->first, " - "it->second);\n" - "}\n"); -} + void GeneratePrivateMembers(io::Printer* p) const override; + void GenerateAccessorDeclarations(io::Printer* p) const override; + void GenerateInlineAccessorDefinitions(io::Printer* p) const override; + void GenerateSerializeWithCachedSizesToArray(io::Printer* p) const override; + void GenerateByteSize(io::Printer* p) const override; -void MapFieldGenerator::GenerateIsInitialized(io::Printer* printer) const { - if (!has_required_fields_) return; + private: + const FieldDescriptor* field_; + const FieldDescriptor* key_; + const FieldDescriptor* val_; + const Options* opts_; + bool has_required_; + bool lite_; +}; - Formatter format(printer, variables_); - format( - "if (!::$proto_ns$::internal::AllAreInitialized($field$)) return " - "false;\n"); +void Map::GeneratePrivateMembers(io::Printer* p) const { + p->Emit({{"kKeyType", + absl::AsciiStrToUpper(DeclaredTypeMethodName(key_->type()))}, + {"kValType", + absl::AsciiStrToUpper(DeclaredTypeMethodName(val_->type()))}}, + R"cc( + $pbi$::$MapField$<$Entry$, $Key$, $Val$, + $pbi$::WireFormatLite::TYPE_$kKeyType$, + $pbi$::WireFormatLite::TYPE_$kValType$> + $name$_; + )cc"); } -void MapFieldGenerator::GenerateConstexprAggregateInitializer( - io::Printer* printer) const { - Formatter format(printer, variables_); - format("/*decltype($field$)*/{}"); -} +void Map::GenerateAccessorDeclarations(io::Printer* p) const { + auto v1 = p->WithVars( + AnnotatedAccessors(field_, {"", "_internal_", "_internal_mutable_"})); + auto v2 = p->WithVars(AnnotatedAccessors(field_, {"mutable_"}, + io::AnnotationCollector::kAlias)); + p->Emit(R"cc( + $DEPRECATED$ const $Map$& $name$() const; + $DEPRECATED$ $Map$* $mutable_name$(); + + private: + const $Map$& $_internal_name$() const; + $Map$* $_internal_mutable_name$(); -void MapFieldGenerator::GenerateCopyAggregateInitializer( - io::Printer* printer) const { - Formatter format(printer, variables_); - // MapField has no move constructor, which prevents explicit aggregate - // initialization pre-C++17. - format("/*decltype($field$)*/{}"); + public: + )cc"); } -void MapFieldGenerator::GenerateAggregateInitializer( - io::Printer* printer) const { - Formatter format(printer, variables_); - if (ShouldSplit(descriptor_, options_)) { - format( - "/*decltype($classname$::Split::$name$_)*/" - "{::_pbi::ArenaInitialized(), arena}"); - return; - } - // MapField has no move constructor. - format("/*decltype($field$)*/{::_pbi::ArenaInitialized(), arena}"); +void Map::GenerateInlineAccessorDefinitions(io::Printer* p) const { + p->Emit(R"cc( + inline const $Map$& $Msg$::_internal_$name$() const { + return $field_$.GetMap(); + } + inline const $Map$& $Msg$::$name$() const { + $annotate_get$; + // @@protoc_insertion_point(field_map:$pkg.Msg.field$) + return _internal_$name$(); + } + inline $Map$* $Msg$::_internal_mutable_$name$() { + $PrepareSplitMessageForWrite$; + return $field_$.MutableMap(); + } + inline $Map$* $Msg$::mutable_$name$() { + $annotate_mutable$; + // @@protoc_insertion_point(field_mutable_map:$pkg.Msg.field$) + return _internal_mutable_$name$(); + } + )cc"); } -void MapFieldGenerator::GenerateDestructorCode(io::Printer* printer) const { - Formatter format(printer, variables_); - if (ShouldSplit(descriptor_, options_)) { - format("$cached_split_ptr$->$name$_.~MapField$lite$();\n"); - return; - } - format("$field$.~MapField$lite$();\n"); +void Map::GenerateSerializeWithCachedSizesToArray(io::Printer* p) const { + bool string_key = key_->type() == FieldDescriptor::TYPE_STRING; + bool string_val = val_->type() == FieldDescriptor::TYPE_STRING; + + p->Emit( + { + {"Sorter", string_key ? "MapSorterPtr" : "MapSorterFlat"}, + {"CheckUtf8", + [&] { + if (string_key) { + GenerateUtf8CheckCodeForString( + p, key_, *opts_, /*for_parse=*/false, + "entry.first.data(), " + "static_cast(entry.first.length()),\n"); + } + if (string_val) { + GenerateUtf8CheckCodeForString( + p, val_, *opts_, /*for_parse=*/false, + "entry.second.data(), " + "static_cast(entry.second.length()),\n"); + } + }}, + }, + R"cc( + if (!_internal_$name$().empty()) { + using MapType = $Map$; + using WireHelper = $Entry$::Funcs; + const auto& field = _internal_$name$(); + + if (stream->IsSerializationDeterministic() && field.size() > 1) { + for (const auto& entry : $pbi$::$Sorter$(field)) { + target = WireHelper::InternalSerialize( + $number$, entry.first, entry.second, target, stream); + $CheckUtf8$; + } + } else { + for (const auto& entry : field) { + target = WireHelper::InternalSerialize( + $number$, entry.first, entry.second, target, stream); + $CheckUtf8$; + } + } + } + )cc"); } +void Map::GenerateByteSize(io::Printer* p) const { + p->Emit(R"cc( + total_size += $kTagBytes$ * $pbi$::FromIntSize(_internal_$name$_size()); + for (const auto& entry : _internal_$name$()) { + total_size += $Entry$::Funcs::ByteSizeLong(entry.first, entry.second); + } + )cc"); +} } // namespace std::unique_ptr MakeMapGenerator( const FieldDescriptor* desc, const Options& options, MessageSCCAnalyzer* scc) { - return std::make_unique(desc, options, scc); + return std::make_unique(desc, options, scc); } } // namespace cpp