Refactor map_field.cc to use Emit().

PiperOrigin-RevId: 528634293
pull/12630/head
Matt Kulukundis 2 years ago committed by Copybara-Service
parent 874e291c00
commit 80b0609666
  1. 471
      src/google/protobuf/compiler/cpp/field_generators/map_field.cc

@ -30,317 +30,270 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <tuple> #include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/strings/ascii.h" #include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h"
#include "google/protobuf/compiler/cpp/field_generators/generators.h" #include "absl/strings/substitute.h"
#include "google/protobuf/compiler/cpp/field.h"
#include "google/protobuf/compiler/cpp/helpers.h" #include "google/protobuf/compiler/cpp/helpers.h"
#include "google/protobuf/wire_format.h" #include "google/protobuf/compiler/cpp/options.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/printer.h"
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
namespace cpp { namespace cpp {
namespace { namespace {
void SetMessageVariables( using Sub = ::google::protobuf::io::Printer::Sub;
const FieldDescriptor* descriptor,
absl::flat_hash_map<absl::string_view, std::string>* variables, std::vector<Sub> Vars(const FieldDescriptor* field, const Options& opts,
const Options& options) { bool lite) {
(*variables)["type"] = ClassName(descriptor->message_type(), false); const auto* key = field->message_type()->map_key();
(*variables)["full_name"] = descriptor->full_name(); const auto* val = field->message_type()->map_value();
const FieldDescriptor* key = descriptor->message_type()->map_key(); std::string key_type = PrimitiveTypeName(opts, key->cpp_type());
const FieldDescriptor* val = descriptor->message_type()->map_value(); std::string val_type;
(*variables)["key_cpp"] = PrimitiveTypeName(options, key->cpp_type());
switch (val->cpp_type()) { switch (val->cpp_type()) {
case FieldDescriptor::CPPTYPE_MESSAGE: case FieldDescriptor::CPPTYPE_MESSAGE:
(*variables)["val_cpp"] = FieldMessageTypeName(val, options); val_type = FieldMessageTypeName(val, opts);
break; break;
case FieldDescriptor::CPPTYPE_ENUM: case FieldDescriptor::CPPTYPE_ENUM:
(*variables)["val_cpp"] = ClassName(val->enum_type(), true); val_type = ClassName(val->enum_type(), true);
break; break;
default: default:
(*variables)["val_cpp"] = PrimitiveTypeName(options, val->cpp_type()); val_type = PrimitiveTypeName(opts, val->cpp_type());
} break;
(*variables)["key_wire_type"] = absl::StrCat(
"TYPE_", absl::AsciiStrToUpper(DeclaredTypeMethodName(key->type())));
(*variables)["val_wire_type"] = absl::StrCat(
"TYPE_", absl::AsciiStrToUpper(DeclaredTypeMethodName(val->type())));
(*variables)["map_classname"] = ClassName(descriptor->message_type(), false);
(*variables)["number"] = absl::StrCat(descriptor->number());
(*variables)["tag"] = absl::StrCat(internal::WireFormat::MakeTag(descriptor));
if (HasDescriptorMethods(descriptor->file(), options)) {
(*variables)["lite"] = "";
} else {
(*variables)["lite"] = "Lite";
} }
return {
{"Map", absl::Substitute("::PROTOBUF_NAMESPACE_ID::Map<$0, $1>", key_type,
val_type)},
{"Entry", ClassName(field->message_type(), false)},
{"Key", PrimitiveTypeName(opts, key->cpp_type())},
{"Val", val_type},
{"MapField", lite ? "MapFieldLite" : "MapField"},
};
} }
class MapFieldGenerator : public FieldGeneratorBase { class Map : public FieldGeneratorBase {
public: public:
MapFieldGenerator(const FieldDescriptor* descriptor, const Options& options, Map(const FieldDescriptor* field, const Options& opts,
MessageSCCAnalyzer* scc_analyzer); MessageSCCAnalyzer* scc)
~MapFieldGenerator() override = default; : FieldGeneratorBase(field, opts),
field_(field),
// implements FieldGeneratorBase --------------------------------------- key_(field->message_type()->map_key()),
void GeneratePrivateMembers(io::Printer* printer) const override; val_(field->message_type()->map_value()),
void GenerateAccessorDeclarations(io::Printer* printer) const override; opts_(&opts),
void GenerateInlineAccessorDefinitions(io::Printer* printer) const override; has_required_(scc->HasRequiredFields(field->message_type())),
void GenerateClearingCode(io::Printer* printer) const override; lite_(!HasDescriptorMethods(field->file(), opts)) {}
void GenerateMergingCode(io::Printer* printer) const override; ~Map() override = default;
void GenerateSwappingCode(io::Printer* printer) const override;
void GenerateConstructorCode(io::Printer* printer) const override {}
void GenerateCopyConstructorCode(io::Printer* printer) const override;
void GenerateSerializeWithCachedSizesToArray(
io::Printer* printer) const override;
void GenerateByteSize(io::Printer* printer) const override;
void GenerateIsInitialized(io::Printer* printer) const override;
void GenerateConstexprAggregateInitializer(
io::Printer* printer) const override;
void GenerateCopyAggregateInitializer(io::Printer* printer) const override;
void GenerateAggregateInitializer(io::Printer* printer) const override;
void GenerateDestructorCode(io::Printer* printer) const override;
private:
bool has_required_fields_;
};
MapFieldGenerator::MapFieldGenerator(const FieldDescriptor* descriptor, std::vector<Sub> MakeVars() const override {
const Options& options, return Vars(field_, *opts_, lite_);
MessageSCCAnalyzer* scc_analyzer) }
: FieldGeneratorBase(descriptor, options),
has_required_fields_(
scc_analyzer->HasRequiredFields(descriptor->message_type())) {
SetMessageVariables(descriptor, &variables_, options);
}
void MapFieldGenerator::GeneratePrivateMembers(io::Printer* printer) const { void GenerateClearingCode(io::Printer* p) const override {
Formatter format(printer, variables_); p->Emit(R"cc(
format( $field_$.Clear();
"::$proto_ns$::internal::MapField$lite$<\n" )cc");
" $map_classname$,\n" }
" $key_cpp$, $val_cpp$,\n"
" ::$proto_ns$::internal::WireFormatLite::$key_wire_type$,\n"
" ::$proto_ns$::internal::WireFormatLite::$val_wire_type$> "
"$name$_;\n");
}
void MapFieldGenerator::GenerateAccessorDeclarations( void GenerateMergingCode(io::Printer* p) const override {
io::Printer* printer) const { p->Emit(R"cc(
Formatter format(printer, variables_); _this->$field_$.MergeFrom(from.$field_$);
format( )cc");
"private:\n" }
"const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
" ${1$_internal_$name$$}$() const;\n"
"::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
" ${1$_internal_mutable_$name$$}$();\n"
"public:\n"
"$deprecated_attr$const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
" ${1$$name$$}$() const;\n",
descriptor_);
format(
"$deprecated_attr$::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
" ${1$mutable_$name$$}$();\n",
std::make_tuple(descriptor_, GeneratedCodeInfo::Annotation::ALIAS));
}
void MapFieldGenerator::GenerateInlineAccessorDefinitions( void GenerateSwappingCode(io::Printer* p) const override {
io::Printer* printer) const { p->Emit(R"cc(
Formatter format(printer, variables_); $field_$.InternalSwap(&other->$field_$);
format( )cc");
"inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n" }
"$classname$::_internal_$name$() const {\n"
" return $field$.GetMap();\n"
"}\n"
"inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
"$classname$::$name$() const {\n"
"$annotate_get$"
" // @@protoc_insertion_point(field_map:$full_name$)\n"
" return _internal_$name$();\n"
"}\n"
"inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
"$classname$::_internal_mutable_$name$() {\n"
"$PrepareSplitMessageForWrite$"
" return $field$.MutableMap();\n"
"}\n"
"inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
"$classname$::mutable_$name$() {\n"
"$annotate_mutable$"
" // @@protoc_insertion_point(field_mutable_map:$full_name$)\n"
" return _internal_mutable_$name$();\n"
"}\n");
}
void MapFieldGenerator::GenerateClearingCode(io::Printer* printer) const { void GenerateCopyConstructorCode(io::Printer* p) const override {
Formatter format(printer, variables_); GenerateConstructorCode(p);
format("$field$.Clear();\n"); GenerateMergingCode(p);
} }
void MapFieldGenerator::GenerateMergingCode(io::Printer* printer) const { void GenerateIsInitialized(io::Printer* p) const override {
Formatter format(printer, variables_); if (!has_required_) return;
format("_this->$field$.MergeFrom(from.$field$);\n");
}
void MapFieldGenerator::GenerateSwappingCode(io::Printer* printer) const { p->Emit(R"cc(
Formatter format(printer, variables_); if (!$pbi$::AllAreInitialized($field_$)) {
format("$field$.InternalSwap(&other->$field$);\n"); return false;
} }
)cc");
}
void MapFieldGenerator::GenerateCopyConstructorCode( void GenerateConstexprAggregateInitializer(io::Printer* p) const override {
io::Printer* printer) const { p->Emit(R"cc(/* decltype($field_$) */ {})cc");
GenerateConstructorCode(printer); }
GenerateMergingCode(printer);
}
static void GenerateSerializationLoop(Formatter& format, bool string_key, void GenerateCopyAggregateInitializer(io::Printer* p) const override {
bool string_value, // MapField has no move constructor, which prevents explicit aggregate
bool is_deterministic) { // initialization pre-C++17.
if (is_deterministic) { p->Emit(R"cc(/* decltype($field_$) */ {})cc");
format(
"for (const auto& entry : "
"::_pbi::MapSorter$1$<MapType>(map_field)) {\n",
(string_key ? "Ptr" : "Flat"));
} else {
format("for (const auto& entry : map_field) {\n");
} }
{
auto loop_scope = format.ScopedIndent(); void GenerateAggregateInitializer(io::Printer* p) const override {
format( if (ShouldSplit(field_, *opts_)) {
"target = WireHelper::InternalSerialize($number$, " p->Emit(R"cc(
"entry.first, entry.second, target, stream);\n"); /* decltype($Msg$::Split::$name$_) */ {
$pbi$::ArenaInitialized(), arena
if (string_key || string_value) { }
format("check_utf8(entry);\n"); )cc");
return;
} }
p->Emit(R"cc(
/* decltype($field_$) */ { $pbi$::ArenaInitialized(), arena }
)cc");
} }
format("}\n");
}
void MapFieldGenerator::GenerateSerializeWithCachedSizesToArray( void GenerateConstructorCode(io::Printer* p) const override {}
io::Printer* printer) const {
Formatter format(printer, variables_); void GenerateDestructorCode(io::Printer* p) const override {
format("if (!this->_internal_$name$().empty()) {\n"); if (ShouldSplit(field_, *opts_)) {
format.Indent(); p->Emit(R"cc(
const FieldDescriptor* key_field = descriptor_->message_type()->map_key(); $cached_split_ptr$->$name$_.~$MapField$();
const FieldDescriptor* value_field = descriptor_->message_type()->map_value(); )cc");
const bool string_key = key_field->type() == FieldDescriptor::TYPE_STRING; return;
const bool string_value = value_field->type() == FieldDescriptor::TYPE_STRING;
format(
"using MapType = ::_pb::Map<$key_cpp$, $val_cpp$>;\n"
"using WireHelper = $map_classname$::Funcs;\n"
"const auto& map_field = this->_internal_$name$();\n");
bool utf8_check = string_key || string_value;
if (utf8_check) {
format("auto check_utf8 = [](const MapType::value_type& entry) {\n");
{
auto check_scope = format.ScopedIndent();
// p may be unused when GetUtf8CheckMode evaluates to kNone,
// thus disabling the validation.
format("(void)entry;\n");
if (string_key) {
GenerateUtf8CheckCodeForString(
key_field, options_, false,
"entry.first.data(), static_cast<int>(entry.first.length()),\n",
format);
}
if (string_value) {
GenerateUtf8CheckCodeForString(
value_field, options_, false,
"entry.second.data(), static_cast<int>(entry.second.length()),\n",
format);
}
} }
format("};\n");
}
format( p->Emit(R"cc(
"\n" $field_$.~$MapField$();
"if (stream->IsSerializationDeterministic() && " )cc");
"map_field.size() > 1) {\n");
{
auto deterministic_scope = format.ScopedIndent();
GenerateSerializationLoop(format, string_key, string_value, true);
}
format("} else {\n");
{
auto map_order_scope = format.ScopedIndent();
GenerateSerializationLoop(format, string_key, string_value, false);
} }
format("}\n");
format.Outdent();
format("}\n");
}
void MapFieldGenerator::GenerateByteSize(io::Printer* printer) const { void GeneratePrivateMembers(io::Printer* p) const override;
Formatter format(printer, variables_); void GenerateAccessorDeclarations(io::Printer* p) const override;
format( void GenerateInlineAccessorDefinitions(io::Printer* p) const override;
"total_size += $tag_size$ *\n" void GenerateSerializeWithCachedSizesToArray(io::Printer* p) const override;
" " void GenerateByteSize(io::Printer* p) const override;
"::$proto_ns$::internal::FromIntSize(this->_internal_$name$_size());\n"
"for (::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::const_iterator\n"
" it = this->_internal_$name$().begin();\n"
" it != this->_internal_$name$().end(); ++it) {\n"
" total_size += $map_classname$::Funcs::ByteSizeLong(it->first, "
"it->second);\n"
"}\n");
}
void MapFieldGenerator::GenerateIsInitialized(io::Printer* printer) const { private:
if (!has_required_fields_) return; const FieldDescriptor* field_;
const FieldDescriptor* key_;
const FieldDescriptor* val_;
const Options* opts_;
bool has_required_;
bool lite_;
};
Formatter format(printer, variables_); void Map::GeneratePrivateMembers(io::Printer* p) const {
format( p->Emit({{"kKeyType",
"if (!::$proto_ns$::internal::AllAreInitialized($field$)) return " absl::AsciiStrToUpper(DeclaredTypeMethodName(key_->type()))},
"false;\n"); {"kValType",
absl::AsciiStrToUpper(DeclaredTypeMethodName(val_->type()))}},
R"cc(
$pbi$::$MapField$<$Entry$, $Key$, $Val$,
$pbi$::WireFormatLite::TYPE_$kKeyType$,
$pbi$::WireFormatLite::TYPE_$kValType$>
$name$_;
)cc");
} }
void MapFieldGenerator::GenerateConstexprAggregateInitializer( void Map::GenerateAccessorDeclarations(io::Printer* p) const {
io::Printer* printer) const { auto v1 = p->WithVars(
Formatter format(printer, variables_); AnnotatedAccessors(field_, {"", "_internal_", "_internal_mutable_"}));
format("/*decltype($field$)*/{}"); auto v2 = p->WithVars(AnnotatedAccessors(field_, {"mutable_"},
} io::AnnotationCollector::kAlias));
p->Emit(R"cc(
$DEPRECATED$ const $Map$& $name$() const;
$DEPRECATED$ $Map$* $mutable_name$();
private:
const $Map$& $_internal_name$() const;
$Map$* $_internal_mutable_name$();
void MapFieldGenerator::GenerateCopyAggregateInitializer( public:
io::Printer* printer) const { )cc");
Formatter format(printer, variables_);
// MapField has no move constructor, which prevents explicit aggregate
// initialization pre-C++17.
format("/*decltype($field$)*/{}");
} }
void MapFieldGenerator::GenerateAggregateInitializer( void Map::GenerateInlineAccessorDefinitions(io::Printer* p) const {
io::Printer* printer) const { p->Emit(R"cc(
Formatter format(printer, variables_); inline const $Map$& $Msg$::_internal_$name$() const {
if (ShouldSplit(descriptor_, options_)) { return $field_$.GetMap();
format( }
"/*decltype($classname$::Split::$name$_)*/" inline const $Map$& $Msg$::$name$() const {
"{::_pbi::ArenaInitialized(), arena}"); $annotate_get$;
return; // @@protoc_insertion_point(field_map:$pkg.Msg.field$)
} return _internal_$name$();
// MapField has no move constructor. }
format("/*decltype($field$)*/{::_pbi::ArenaInitialized(), arena}"); inline $Map$* $Msg$::_internal_mutable_$name$() {
$PrepareSplitMessageForWrite$;
return $field_$.MutableMap();
}
inline $Map$* $Msg$::mutable_$name$() {
$annotate_mutable$;
// @@protoc_insertion_point(field_mutable_map:$pkg.Msg.field$)
return _internal_mutable_$name$();
}
)cc");
} }
void MapFieldGenerator::GenerateDestructorCode(io::Printer* printer) const { void Map::GenerateSerializeWithCachedSizesToArray(io::Printer* p) const {
Formatter format(printer, variables_); bool string_key = key_->type() == FieldDescriptor::TYPE_STRING;
if (ShouldSplit(descriptor_, options_)) { bool string_val = val_->type() == FieldDescriptor::TYPE_STRING;
format("$cached_split_ptr$->$name$_.~MapField$lite$();\n");
return; p->Emit(
} {
format("$field$.~MapField$lite$();\n"); {"Sorter", string_key ? "MapSorterPtr" : "MapSorterFlat"},
{"CheckUtf8",
[&] {
if (string_key) {
GenerateUtf8CheckCodeForString(
p, key_, *opts_, /*for_parse=*/false,
"entry.first.data(), "
"static_cast<int>(entry.first.length()),\n");
}
if (string_val) {
GenerateUtf8CheckCodeForString(
p, val_, *opts_, /*for_parse=*/false,
"entry.second.data(), "
"static_cast<int>(entry.second.length()),\n");
}
}},
},
R"cc(
if (!_internal_$name$().empty()) {
using MapType = $Map$;
using WireHelper = $Entry$::Funcs;
const auto& field = _internal_$name$();
if (stream->IsSerializationDeterministic() && field.size() > 1) {
for (const auto& entry : $pbi$::$Sorter$<MapType>(field)) {
target = WireHelper::InternalSerialize(
$number$, entry.first, entry.second, target, stream);
$CheckUtf8$;
}
} else {
for (const auto& entry : field) {
target = WireHelper::InternalSerialize(
$number$, entry.first, entry.second, target, stream);
$CheckUtf8$;
}
}
}
)cc");
} }
void Map::GenerateByteSize(io::Printer* p) const {
p->Emit(R"cc(
total_size += $kTagBytes$ * $pbi$::FromIntSize(_internal_$name$_size());
for (const auto& entry : _internal_$name$()) {
total_size += $Entry$::Funcs::ByteSizeLong(entry.first, entry.second);
}
)cc");
}
} // namespace } // namespace
std::unique_ptr<FieldGeneratorBase> MakeMapGenerator( std::unique_ptr<FieldGeneratorBase> MakeMapGenerator(
const FieldDescriptor* desc, const Options& options, const FieldDescriptor* desc, const Options& options,
MessageSCCAnalyzer* scc) { MessageSCCAnalyzer* scc) {
return std::make_unique<MapFieldGenerator>(desc, options, scc); return std::make_unique<Map>(desc, options, scc);
} }
} // namespace cpp } // namespace cpp

Loading…
Cancel
Save