Refactor map_field.cc to use Emit().

PiperOrigin-RevId: 528634293
pull/12630/head
Matt Kulukundis 2 years ago committed by Copybara-Service
parent 874e291c00
commit 80b0609666
  1. 471
      src/google/protobuf/compiler/cpp/field_generators/map_field.cc

@ -30,317 +30,270 @@
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "google/protobuf/compiler/cpp/field_generators/generators.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "google/protobuf/compiler/cpp/field.h"
#include "google/protobuf/compiler/cpp/helpers.h"
#include "google/protobuf/wire_format.h"
#include "google/protobuf/compiler/cpp/options.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/printer.h"
namespace google {
namespace protobuf {
namespace compiler {
namespace cpp {
namespace {
void SetMessageVariables(
const FieldDescriptor* descriptor,
absl::flat_hash_map<absl::string_view, std::string>* variables,
const Options& options) {
(*variables)["type"] = ClassName(descriptor->message_type(), false);
(*variables)["full_name"] = descriptor->full_name();
const FieldDescriptor* key = descriptor->message_type()->map_key();
const FieldDescriptor* val = descriptor->message_type()->map_value();
(*variables)["key_cpp"] = PrimitiveTypeName(options, key->cpp_type());
using Sub = ::google::protobuf::io::Printer::Sub;
std::vector<Sub> Vars(const FieldDescriptor* field, const Options& opts,
bool lite) {
const auto* key = field->message_type()->map_key();
const auto* val = field->message_type()->map_value();
std::string key_type = PrimitiveTypeName(opts, key->cpp_type());
std::string val_type;
switch (val->cpp_type()) {
case FieldDescriptor::CPPTYPE_MESSAGE:
(*variables)["val_cpp"] = FieldMessageTypeName(val, options);
val_type = FieldMessageTypeName(val, opts);
break;
case FieldDescriptor::CPPTYPE_ENUM:
(*variables)["val_cpp"] = ClassName(val->enum_type(), true);
val_type = ClassName(val->enum_type(), true);
break;
default:
(*variables)["val_cpp"] = PrimitiveTypeName(options, val->cpp_type());
}
(*variables)["key_wire_type"] = absl::StrCat(
"TYPE_", absl::AsciiStrToUpper(DeclaredTypeMethodName(key->type())));
(*variables)["val_wire_type"] = absl::StrCat(
"TYPE_", absl::AsciiStrToUpper(DeclaredTypeMethodName(val->type())));
(*variables)["map_classname"] = ClassName(descriptor->message_type(), false);
(*variables)["number"] = absl::StrCat(descriptor->number());
(*variables)["tag"] = absl::StrCat(internal::WireFormat::MakeTag(descriptor));
if (HasDescriptorMethods(descriptor->file(), options)) {
(*variables)["lite"] = "";
} else {
(*variables)["lite"] = "Lite";
val_type = PrimitiveTypeName(opts, val->cpp_type());
break;
}
return {
{"Map", absl::Substitute("::PROTOBUF_NAMESPACE_ID::Map<$0, $1>", key_type,
val_type)},
{"Entry", ClassName(field->message_type(), false)},
{"Key", PrimitiveTypeName(opts, key->cpp_type())},
{"Val", val_type},
{"MapField", lite ? "MapFieldLite" : "MapField"},
};
}
class MapFieldGenerator : public FieldGeneratorBase {
class Map : public FieldGeneratorBase {
public:
MapFieldGenerator(const FieldDescriptor* descriptor, const Options& options,
MessageSCCAnalyzer* scc_analyzer);
~MapFieldGenerator() override = default;
// implements FieldGeneratorBase ---------------------------------------
void GeneratePrivateMembers(io::Printer* printer) const override;
void GenerateAccessorDeclarations(io::Printer* printer) const override;
void GenerateInlineAccessorDefinitions(io::Printer* printer) const override;
void GenerateClearingCode(io::Printer* printer) const override;
void GenerateMergingCode(io::Printer* printer) const override;
void GenerateSwappingCode(io::Printer* printer) const override;
void GenerateConstructorCode(io::Printer* printer) const override {}
void GenerateCopyConstructorCode(io::Printer* printer) const override;
void GenerateSerializeWithCachedSizesToArray(
io::Printer* printer) const override;
void GenerateByteSize(io::Printer* printer) const override;
void GenerateIsInitialized(io::Printer* printer) const override;
void GenerateConstexprAggregateInitializer(
io::Printer* printer) const override;
void GenerateCopyAggregateInitializer(io::Printer* printer) const override;
void GenerateAggregateInitializer(io::Printer* printer) const override;
void GenerateDestructorCode(io::Printer* printer) const override;
private:
bool has_required_fields_;
};
Map(const FieldDescriptor* field, const Options& opts,
MessageSCCAnalyzer* scc)
: FieldGeneratorBase(field, opts),
field_(field),
key_(field->message_type()->map_key()),
val_(field->message_type()->map_value()),
opts_(&opts),
has_required_(scc->HasRequiredFields(field->message_type())),
lite_(!HasDescriptorMethods(field->file(), opts)) {}
~Map() override = default;
MapFieldGenerator::MapFieldGenerator(const FieldDescriptor* descriptor,
const Options& options,
MessageSCCAnalyzer* scc_analyzer)
: FieldGeneratorBase(descriptor, options),
has_required_fields_(
scc_analyzer->HasRequiredFields(descriptor->message_type())) {
SetMessageVariables(descriptor, &variables_, options);
}
std::vector<Sub> MakeVars() const override {
return Vars(field_, *opts_, lite_);
}
void MapFieldGenerator::GeneratePrivateMembers(io::Printer* printer) const {
Formatter format(printer, variables_);
format(
"::$proto_ns$::internal::MapField$lite$<\n"
" $map_classname$,\n"
" $key_cpp$, $val_cpp$,\n"
" ::$proto_ns$::internal::WireFormatLite::$key_wire_type$,\n"
" ::$proto_ns$::internal::WireFormatLite::$val_wire_type$> "
"$name$_;\n");
}
void GenerateClearingCode(io::Printer* p) const override {
p->Emit(R"cc(
$field_$.Clear();
)cc");
}
void MapFieldGenerator::GenerateAccessorDeclarations(
io::Printer* printer) const {
Formatter format(printer, variables_);
format(
"private:\n"
"const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
" ${1$_internal_$name$$}$() const;\n"
"::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
" ${1$_internal_mutable_$name$$}$();\n"
"public:\n"
"$deprecated_attr$const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
" ${1$$name$$}$() const;\n",
descriptor_);
format(
"$deprecated_attr$::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
" ${1$mutable_$name$$}$();\n",
std::make_tuple(descriptor_, GeneratedCodeInfo::Annotation::ALIAS));
}
void GenerateMergingCode(io::Printer* p) const override {
p->Emit(R"cc(
_this->$field_$.MergeFrom(from.$field_$);
)cc");
}
void MapFieldGenerator::GenerateInlineAccessorDefinitions(
io::Printer* printer) const {
Formatter format(printer, variables_);
format(
"inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
"$classname$::_internal_$name$() const {\n"
" return $field$.GetMap();\n"
"}\n"
"inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
"$classname$::$name$() const {\n"
"$annotate_get$"
" // @@protoc_insertion_point(field_map:$full_name$)\n"
" return _internal_$name$();\n"
"}\n"
"inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
"$classname$::_internal_mutable_$name$() {\n"
"$PrepareSplitMessageForWrite$"
" return $field$.MutableMap();\n"
"}\n"
"inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
"$classname$::mutable_$name$() {\n"
"$annotate_mutable$"
" // @@protoc_insertion_point(field_mutable_map:$full_name$)\n"
" return _internal_mutable_$name$();\n"
"}\n");
}
void GenerateSwappingCode(io::Printer* p) const override {
p->Emit(R"cc(
$field_$.InternalSwap(&other->$field_$);
)cc");
}
void MapFieldGenerator::GenerateClearingCode(io::Printer* printer) const {
Formatter format(printer, variables_);
format("$field$.Clear();\n");
}
void GenerateCopyConstructorCode(io::Printer* p) const override {
GenerateConstructorCode(p);
GenerateMergingCode(p);
}
void MapFieldGenerator::GenerateMergingCode(io::Printer* printer) const {
Formatter format(printer, variables_);
format("_this->$field$.MergeFrom(from.$field$);\n");
}
void GenerateIsInitialized(io::Printer* p) const override {
if (!has_required_) return;
void MapFieldGenerator::GenerateSwappingCode(io::Printer* printer) const {
Formatter format(printer, variables_);
format("$field$.InternalSwap(&other->$field$);\n");
}
p->Emit(R"cc(
if (!$pbi$::AllAreInitialized($field_$)) {
return false;
}
)cc");
}
void MapFieldGenerator::GenerateCopyConstructorCode(
io::Printer* printer) const {
GenerateConstructorCode(printer);
GenerateMergingCode(printer);
}
void GenerateConstexprAggregateInitializer(io::Printer* p) const override {
p->Emit(R"cc(/* decltype($field_$) */ {})cc");
}
static void GenerateSerializationLoop(Formatter& format, bool string_key,
bool string_value,
bool is_deterministic) {
if (is_deterministic) {
format(
"for (const auto& entry : "
"::_pbi::MapSorter$1$<MapType>(map_field)) {\n",
(string_key ? "Ptr" : "Flat"));
} else {
format("for (const auto& entry : map_field) {\n");
void GenerateCopyAggregateInitializer(io::Printer* p) const override {
// MapField has no move constructor, which prevents explicit aggregate
// initialization pre-C++17.
p->Emit(R"cc(/* decltype($field_$) */ {})cc");
}
{
auto loop_scope = format.ScopedIndent();
format(
"target = WireHelper::InternalSerialize($number$, "
"entry.first, entry.second, target, stream);\n");
if (string_key || string_value) {
format("check_utf8(entry);\n");
void GenerateAggregateInitializer(io::Printer* p) const override {
if (ShouldSplit(field_, *opts_)) {
p->Emit(R"cc(
/* decltype($Msg$::Split::$name$_) */ {
$pbi$::ArenaInitialized(), arena
}
)cc");
return;
}
p->Emit(R"cc(
/* decltype($field_$) */ { $pbi$::ArenaInitialized(), arena }
)cc");
}
format("}\n");
}
void MapFieldGenerator::GenerateSerializeWithCachedSizesToArray(
io::Printer* printer) const {
Formatter format(printer, variables_);
format("if (!this->_internal_$name$().empty()) {\n");
format.Indent();
const FieldDescriptor* key_field = descriptor_->message_type()->map_key();
const FieldDescriptor* value_field = descriptor_->message_type()->map_value();
const bool string_key = key_field->type() == FieldDescriptor::TYPE_STRING;
const bool string_value = value_field->type() == FieldDescriptor::TYPE_STRING;
format(
"using MapType = ::_pb::Map<$key_cpp$, $val_cpp$>;\n"
"using WireHelper = $map_classname$::Funcs;\n"
"const auto& map_field = this->_internal_$name$();\n");
bool utf8_check = string_key || string_value;
if (utf8_check) {
format("auto check_utf8 = [](const MapType::value_type& entry) {\n");
{
auto check_scope = format.ScopedIndent();
// p may be unused when GetUtf8CheckMode evaluates to kNone,
// thus disabling the validation.
format("(void)entry;\n");
if (string_key) {
GenerateUtf8CheckCodeForString(
key_field, options_, false,
"entry.first.data(), static_cast<int>(entry.first.length()),\n",
format);
}
if (string_value) {
GenerateUtf8CheckCodeForString(
value_field, options_, false,
"entry.second.data(), static_cast<int>(entry.second.length()),\n",
format);
}
void GenerateConstructorCode(io::Printer* p) const override {}
void GenerateDestructorCode(io::Printer* p) const override {
if (ShouldSplit(field_, *opts_)) {
p->Emit(R"cc(
$cached_split_ptr$->$name$_.~$MapField$();
)cc");
return;
}
format("};\n");
}
format(
"\n"
"if (stream->IsSerializationDeterministic() && "
"map_field.size() > 1) {\n");
{
auto deterministic_scope = format.ScopedIndent();
GenerateSerializationLoop(format, string_key, string_value, true);
}
format("} else {\n");
{
auto map_order_scope = format.ScopedIndent();
GenerateSerializationLoop(format, string_key, string_value, false);
p->Emit(R"cc(
$field_$.~$MapField$();
)cc");
}
format("}\n");
format.Outdent();
format("}\n");
}
void MapFieldGenerator::GenerateByteSize(io::Printer* printer) const {
Formatter format(printer, variables_);
format(
"total_size += $tag_size$ *\n"
" "
"::$proto_ns$::internal::FromIntSize(this->_internal_$name$_size());\n"
"for (::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::const_iterator\n"
" it = this->_internal_$name$().begin();\n"
" it != this->_internal_$name$().end(); ++it) {\n"
" total_size += $map_classname$::Funcs::ByteSizeLong(it->first, "
"it->second);\n"
"}\n");
}
void GeneratePrivateMembers(io::Printer* p) const override;
void GenerateAccessorDeclarations(io::Printer* p) const override;
void GenerateInlineAccessorDefinitions(io::Printer* p) const override;
void GenerateSerializeWithCachedSizesToArray(io::Printer* p) const override;
void GenerateByteSize(io::Printer* p) const override;
void MapFieldGenerator::GenerateIsInitialized(io::Printer* printer) const {
if (!has_required_fields_) return;
private:
const FieldDescriptor* field_;
const FieldDescriptor* key_;
const FieldDescriptor* val_;
const Options* opts_;
bool has_required_;
bool lite_;
};
Formatter format(printer, variables_);
format(
"if (!::$proto_ns$::internal::AllAreInitialized($field$)) return "
"false;\n");
void Map::GeneratePrivateMembers(io::Printer* p) const {
p->Emit({{"kKeyType",
absl::AsciiStrToUpper(DeclaredTypeMethodName(key_->type()))},
{"kValType",
absl::AsciiStrToUpper(DeclaredTypeMethodName(val_->type()))}},
R"cc(
$pbi$::$MapField$<$Entry$, $Key$, $Val$,
$pbi$::WireFormatLite::TYPE_$kKeyType$,
$pbi$::WireFormatLite::TYPE_$kValType$>
$name$_;
)cc");
}
void MapFieldGenerator::GenerateConstexprAggregateInitializer(
io::Printer* printer) const {
Formatter format(printer, variables_);
format("/*decltype($field$)*/{}");
}
void Map::GenerateAccessorDeclarations(io::Printer* p) const {
auto v1 = p->WithVars(
AnnotatedAccessors(field_, {"", "_internal_", "_internal_mutable_"}));
auto v2 = p->WithVars(AnnotatedAccessors(field_, {"mutable_"},
io::AnnotationCollector::kAlias));
p->Emit(R"cc(
$DEPRECATED$ const $Map$& $name$() const;
$DEPRECATED$ $Map$* $mutable_name$();
private:
const $Map$& $_internal_name$() const;
$Map$* $_internal_mutable_name$();
void MapFieldGenerator::GenerateCopyAggregateInitializer(
io::Printer* printer) const {
Formatter format(printer, variables_);
// MapField has no move constructor, which prevents explicit aggregate
// initialization pre-C++17.
format("/*decltype($field$)*/{}");
public:
)cc");
}
void MapFieldGenerator::GenerateAggregateInitializer(
io::Printer* printer) const {
Formatter format(printer, variables_);
if (ShouldSplit(descriptor_, options_)) {
format(
"/*decltype($classname$::Split::$name$_)*/"
"{::_pbi::ArenaInitialized(), arena}");
return;
}
// MapField has no move constructor.
format("/*decltype($field$)*/{::_pbi::ArenaInitialized(), arena}");
void Map::GenerateInlineAccessorDefinitions(io::Printer* p) const {
p->Emit(R"cc(
inline const $Map$& $Msg$::_internal_$name$() const {
return $field_$.GetMap();
}
inline const $Map$& $Msg$::$name$() const {
$annotate_get$;
// @@protoc_insertion_point(field_map:$pkg.Msg.field$)
return _internal_$name$();
}
inline $Map$* $Msg$::_internal_mutable_$name$() {
$PrepareSplitMessageForWrite$;
return $field_$.MutableMap();
}
inline $Map$* $Msg$::mutable_$name$() {
$annotate_mutable$;
// @@protoc_insertion_point(field_mutable_map:$pkg.Msg.field$)
return _internal_mutable_$name$();
}
)cc");
}
void MapFieldGenerator::GenerateDestructorCode(io::Printer* printer) const {
Formatter format(printer, variables_);
if (ShouldSplit(descriptor_, options_)) {
format("$cached_split_ptr$->$name$_.~MapField$lite$();\n");
return;
}
format("$field$.~MapField$lite$();\n");
void Map::GenerateSerializeWithCachedSizesToArray(io::Printer* p) const {
bool string_key = key_->type() == FieldDescriptor::TYPE_STRING;
bool string_val = val_->type() == FieldDescriptor::TYPE_STRING;
p->Emit(
{
{"Sorter", string_key ? "MapSorterPtr" : "MapSorterFlat"},
{"CheckUtf8",
[&] {
if (string_key) {
GenerateUtf8CheckCodeForString(
p, key_, *opts_, /*for_parse=*/false,
"entry.first.data(), "
"static_cast<int>(entry.first.length()),\n");
}
if (string_val) {
GenerateUtf8CheckCodeForString(
p, val_, *opts_, /*for_parse=*/false,
"entry.second.data(), "
"static_cast<int>(entry.second.length()),\n");
}
}},
},
R"cc(
if (!_internal_$name$().empty()) {
using MapType = $Map$;
using WireHelper = $Entry$::Funcs;
const auto& field = _internal_$name$();
if (stream->IsSerializationDeterministic() && field.size() > 1) {
for (const auto& entry : $pbi$::$Sorter$<MapType>(field)) {
target = WireHelper::InternalSerialize(
$number$, entry.first, entry.second, target, stream);
$CheckUtf8$;
}
} else {
for (const auto& entry : field) {
target = WireHelper::InternalSerialize(
$number$, entry.first, entry.second, target, stream);
$CheckUtf8$;
}
}
}
)cc");
}
void Map::GenerateByteSize(io::Printer* p) const {
p->Emit(R"cc(
total_size += $kTagBytes$ * $pbi$::FromIntSize(_internal_$name$_size());
for (const auto& entry : _internal_$name$()) {
total_size += $Entry$::Funcs::ByteSizeLong(entry.first, entry.second);
}
)cc");
}
} // namespace
std::unique_ptr<FieldGeneratorBase> MakeMapGenerator(
const FieldDescriptor* desc, const Options& options,
MessageSCCAnalyzer* scc) {
return std::make_unique<MapFieldGenerator>(desc, options, scc);
return std::make_unique<Map>(desc, options, scc);
}
} // namespace cpp

Loading…
Cancel
Save