Add ability for custom a MessagePrinter to default back to the default implementation.

Prior to the CL we cannot use the printer as that would trigger a recursion. So the only way out would be to create a new printer that has the same abilities as the current printer.

PiperOrigin-RevId: 489024852
pull/11008/head
Protobuf Team Bot 2 years ago committed by Copybara-Service
parent c862c1cab4
commit 72977f951c
  1. 57
      src/google/protobuf/text_format.cc
  2. 52
      src/google/protobuf/text_format.h
  3. 19
      src/google/protobuf/text_format_unittest.cc

@ -1470,15 +1470,15 @@ class TextFormat::Printer::TextGenerator
// error.)
bool failed() const { return failed_; }
void PrintMaybeWithMarker(absl::string_view text) {
void PrintMaybeWithMarker(MarkerToken, absl::string_view text) override {
Print(text.data(), text.size());
if (ConsumeInsertSilentMarker()) {
PrintLiteral(internal::kDebugStringSilentMarker);
}
}
void PrintMaybeWithMarker(absl::string_view text_head,
absl::string_view text_tail) {
void PrintMaybeWithMarker(MarkerToken, absl::string_view text_head,
absl::string_view text_tail) override {
Print(text_head.data(), text_head.size());
if (ConsumeInsertSilentMarker()) {
PrintLiteral(internal::kDebugStringSilentMarker);
@ -1577,13 +1577,10 @@ class TextFormat::Printer::DebugStringFieldValuePrinter
void PrintMessageStart(const Message& /*message*/, int /*field_index*/,
int /*field_count*/, bool single_line_mode,
BaseTextGenerator* generator) const override {
// This is safe as only TextGenerator is used with
// DebugStringFieldValuePrinter.
TextGenerator* text_generator = static_cast<TextGenerator*>(generator);
if (single_line_mode) {
text_generator->PrintMaybeWithMarker(" ", "{ ");
generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
} else {
text_generator->PrintMaybeWithMarker(" ", "{\n");
generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
}
}
};
@ -2174,7 +2171,7 @@ struct FieldIndexSorter {
} // namespace
bool TextFormat::Printer::PrintAny(const Message& message,
TextGenerator* generator) const {
BaseTextGenerator* generator) const {
const FieldDescriptor* type_url_field;
const FieldDescriptor* value_field;
if (!internal::GetAnyFieldDescriptors(message, &type_url_field,
@ -2222,7 +2219,7 @@ bool TextFormat::Printer::PrintAny(const Message& message,
}
void TextFormat::Printer::Print(const Message& message,
TextGenerator* generator) const {
BaseTextGenerator* generator) const {
const Reflection* reflection = message.GetReflection();
if (!reflection) {
// This message does not provide any way to describe its structure.
@ -2242,10 +2239,20 @@ void TextFormat::Printer::Print(const Message& message,
itr->second->Print(message, single_line_mode_, generator);
return;
}
PrintMessage(message, generator);
}
void TextFormat::Printer::PrintMessage(const Message& message,
BaseTextGenerator* generator) const {
if (generator == nullptr) {
return;
}
const Descriptor* descriptor = message.GetDescriptor();
if (descriptor->full_name() == internal::kAnyFullTypeName && expand_any_ &&
PrintAny(message, generator)) {
return;
}
const Reflection* reflection = message.GetReflection();
std::vector<const FieldDescriptor*> fields;
if (descriptor->options().map_entry()) {
fields.push_back(descriptor->field(0));
@ -2460,7 +2467,7 @@ void MapFieldPrinterHelper::CopyValue(const MapValueRef& value,
void TextFormat::Printer::PrintField(const Message& message,
const Reflection* reflection,
const FieldDescriptor* field,
TextGenerator* generator) const {
BaseTextGenerator* generator) const {
if (use_short_repeated_primitives_ && field->is_repeated() &&
field->cpp_type() != FieldDescriptor::CPPTYPE_STRING &&
field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
@ -2508,7 +2515,7 @@ void TextFormat::Printer::PrintField(const Message& message,
printer->PrintMessageEnd(sub_message, field_index, count,
single_line_mode_, generator);
} else {
generator->PrintMaybeWithMarker(": ");
generator->PrintMaybeWithMarker(MarkerToken(), ": ");
// Write the field value.
PrintFieldValue(message, reflection, field, field_index, generator);
if (single_line_mode_) {
@ -2528,12 +2535,12 @@ void TextFormat::Printer::PrintField(const Message& message,
void TextFormat::Printer::PrintShortRepeatedField(
const Message& message, const Reflection* reflection,
const FieldDescriptor* field, TextGenerator* generator) const {
const FieldDescriptor* field, BaseTextGenerator* generator) const {
// Print primitive repeated field in short form.
int size = reflection->FieldSize(message, field);
PrintFieldName(message, /*field_index=*/-1, /*field_count=*/size, reflection,
field, generator);
generator->PrintMaybeWithMarker(": ", "[");
generator->PrintMaybeWithMarker(MarkerToken(), ": ", "[");
for (int i = 0; i < size; i++) {
if (i > 0) generator->PrintLiteral(", ");
PrintFieldValue(message, reflection, field, i, generator);
@ -2549,7 +2556,7 @@ void TextFormat::Printer::PrintFieldName(const Message& message,
int field_index, int field_count,
const Reflection* reflection,
const FieldDescriptor* field,
TextGenerator* generator) const {
BaseTextGenerator* generator) const {
// if use_field_number_ is true, prints field number instead
// of field name.
if (use_field_number_) {
@ -2566,7 +2573,7 @@ void TextFormat::Printer::PrintFieldValue(const Message& message,
const Reflection* reflection,
const FieldDescriptor* field,
int index,
TextGenerator* generator) const {
BaseTextGenerator* generator) const {
GOOGLE_DCHECK(field->is_repeated() || (index == -1))
<< "Index must be -1 for non-repeated fields";
@ -2687,7 +2694,7 @@ void TextFormat::Printer::PrintFieldValue(const Message& message,
}
void TextFormat::Printer::PrintUnknownFields(
const UnknownFieldSet& unknown_fields, TextGenerator* generator,
const UnknownFieldSet& unknown_fields, BaseTextGenerator* generator,
int recursion_budget) const {
for (int i = 0; i < unknown_fields.field_count(); i++) {
const UnknownField& field = unknown_fields.field(i);
@ -2696,7 +2703,7 @@ void TextFormat::Printer::PrintUnknownFields(
switch (field.type()) {
case UnknownField::TYPE_VARINT:
generator->PrintString(field_number);
generator->PrintMaybeWithMarker(": ");
generator->PrintMaybeWithMarker(MarkerToken(), ": ");
generator->PrintString(absl::StrCat(field.varint()));
if (single_line_mode_) {
generator->PrintLiteral(" ");
@ -2706,7 +2713,7 @@ void TextFormat::Printer::PrintUnknownFields(
break;
case UnknownField::TYPE_FIXED32: {
generator->PrintString(field_number);
generator->PrintMaybeWithMarker(": ", "0x");
generator->PrintMaybeWithMarker(MarkerToken(), ": ", "0x");
generator->PrintString(
absl::StrCat(absl::Hex(field.fixed32(), absl::kZeroPad8)));
if (single_line_mode_) {
@ -2718,7 +2725,7 @@ void TextFormat::Printer::PrintUnknownFields(
}
case UnknownField::TYPE_FIXED64: {
generator->PrintString(field_number);
generator->PrintMaybeWithMarker(": ", "0x");
generator->PrintMaybeWithMarker(MarkerToken(), ": ", "0x");
generator->PrintString(
absl::StrCat(absl::Hex(field.fixed64(), absl::kZeroPad16)));
if (single_line_mode_) {
@ -2743,9 +2750,9 @@ void TextFormat::Printer::PrintUnknownFields(
// This field is parseable as a Message.
// So it is probably an embedded message.
if (single_line_mode_) {
generator->PrintMaybeWithMarker(" ", "{ ");
generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
} else {
generator->PrintMaybeWithMarker(" ", "{\n");
generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
generator->Indent();
}
PrintUnknownFields(embedded_unknown_fields, generator,
@ -2759,7 +2766,7 @@ void TextFormat::Printer::PrintUnknownFields(
} else {
// This field is not parseable as a Message (or we ran out of
// recursion budget). So it is probably just a plain string.
generator->PrintMaybeWithMarker(": ", "\"");
generator->PrintMaybeWithMarker(MarkerToken(), ": ", "\"");
generator->PrintString(absl::CEscape(value));
if (single_line_mode_) {
generator->PrintLiteral("\" ");
@ -2772,9 +2779,9 @@ void TextFormat::Printer::PrintUnknownFields(
case UnknownField::TYPE_GROUP:
generator->PrintString(field_number);
if (single_line_mode_) {
generator->PrintMaybeWithMarker(" ", "{ ");
generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
} else {
generator->PrintMaybeWithMarker(" ", "{\n");
generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
generator->Indent();
}
// For groups, we recurse without checking the budget. This is OK,

@ -115,7 +115,22 @@ class PROTOBUF_EXPORT TextFormat {
const FieldDescriptor* field, int index,
std::string* output);
// Forward declare `Printer` for `BaseTextGenerator::MarkerToken` which
// restricts some methods of `BaseTextGenerator` to the class `Printer`.
class Printer;
class PROTOBUF_EXPORT BaseTextGenerator {
private:
// Passkey (go/totw/134#what-about-stdshared-ptr) that allows `Printer`
// (but not derived classes) to call `PrintMaybeWithMarker` and its
// `Printer::TextGenerator` to overload it.
// This prevents users from bypassing the marker generation.
class MarkerToken {
private:
explicit MarkerToken() = default; // 'explicit' prevents aggregate init.
friend class Printer;
};
public:
virtual ~BaseTextGenerator();
@ -133,6 +148,20 @@ class PROTOBUF_EXPORT TextFormat {
void PrintLiteral(const char (&text)[n]) {
Print(text, n - 1); // n includes the terminating zero character.
}
// Internal to Printer, access regulated by `MarkerToken`.
virtual void PrintMaybeWithMarker(MarkerToken, absl::string_view text) {
Print(text.data(), text.size());
}
// Internal to Printer, access regulated by `MarkerToken`.
virtual void PrintMaybeWithMarker(MarkerToken, absl::string_view text_head,
absl::string_view text_tail) {
Print(text_head.data(), text_head.size());
Print(text_tail.data(), text_tail.size());
}
friend class Printer;
};
// The default printer that converts scalar values from fields into their
@ -381,6 +410,14 @@ class PROTOBUF_EXPORT TextFormat {
bool RegisterMessagePrinter(const Descriptor* descriptor,
const MessagePrinter* printer);
// Default printing for messages, which allows registered message printers
// to fall back to default printing without losing the ability to control
// sub-messages or fields.
// NOTE: If the passed in `text_generaor` is not actually the current
// `TextGenerator`, then no output will be produced.
void PrintMessage(const Message& message,
BaseTextGenerator* generator) const;
private:
friend std::string Message::DebugString() const;
friend std::string Message::ShortDebugString() const;
@ -404,6 +441,7 @@ class PROTOBUF_EXPORT TextFormat {
// Forward declaration of an internal class used to print the text
// output to the OutputStream (see text_format.cc for implementation).
class TextGenerator;
using MarkerToken = BaseTextGenerator::MarkerToken;
// Forward declaration of an internal class used to print field values for
// DebugString APIs (see text_format.cc for implementation).
@ -417,40 +455,40 @@ class PROTOBUF_EXPORT TextFormat {
// Internal Print method, used for writing to the OutputStream via
// the TextGenerator class.
void Print(const Message& message, TextGenerator* generator) const;
void Print(const Message& message, BaseTextGenerator* generator) const;
// Print a single field.
void PrintField(const Message& message, const Reflection* reflection,
const FieldDescriptor* field,
TextGenerator* generator) const;
BaseTextGenerator* generator) const;
// Print a repeated primitive field in short form.
void PrintShortRepeatedField(const Message& message,
const Reflection* reflection,
const FieldDescriptor* field,
TextGenerator* generator) const;
BaseTextGenerator* generator) const;
// Print the name of a field -- i.e. everything that comes before the
// ':' for a single name/value pair.
void PrintFieldName(const Message& message, int field_index,
int field_count, const Reflection* reflection,
const FieldDescriptor* field,
TextGenerator* generator) const;
BaseTextGenerator* generator) const;
// Outputs a textual representation of the value of the field supplied on
// the message supplied or the default value if not set.
void PrintFieldValue(const Message& message, const Reflection* reflection,
const FieldDescriptor* field, int index,
TextGenerator* generator) const;
BaseTextGenerator* generator) const;
// Print the fields in an UnknownFieldSet. They are printed by tag number
// only. Embedded messages are heuristically identified by attempting to
// parse them (subject to the recursion budget).
void PrintUnknownFields(const UnknownFieldSet& unknown_fields,
TextGenerator* generator,
BaseTextGenerator* generator,
int recursion_budget) const;
bool PrintAny(const Message& message, TextGenerator* generator) const;
bool PrintAny(const Message& message, BaseTextGenerator* generator) const;
const FastFieldValuePrinter* GetFieldPrinter(
const FieldDescriptor* field) const {

@ -780,15 +780,24 @@ class CustomNestedMessagePrinter : public TextFormat::MessagePrinter {
~CustomNestedMessagePrinter() override {}
void Print(const Message& message, bool single_line_mode,
TextFormat::BaseTextGenerator* generator) const override {
generator->PrintLiteral("custom");
generator->PrintLiteral("// custom\n");
if (printer_ != nullptr) {
printer_->PrintMessage(message, generator);
}
}
void SetPrinter(TextFormat::Printer* printer) { printer_ = printer; }
private:
TextFormat::Printer* printer_ = nullptr;
};
TEST_F(TextFormatTest, CustomMessagePrinter) {
TextFormat::Printer printer;
auto* custom_printer = new CustomNestedMessagePrinter;
printer.RegisterMessagePrinter(
unittest::TestAllTypes::NestedMessage::default_instance().descriptor(),
new CustomNestedMessagePrinter);
custom_printer);
unittest::TestAllTypes message;
std::string text;
@ -797,7 +806,11 @@ TEST_F(TextFormatTest, CustomMessagePrinter) {
message.mutable_optional_nested_message()->set_bb(1);
EXPECT_TRUE(printer.PrintToString(message, &text));
EXPECT_EQ("optional_nested_message {\n custom}\n", text);
EXPECT_EQ("optional_nested_message {\n // custom\n}\n", text);
custom_printer->SetPrinter(&printer);
EXPECT_TRUE(printer.PrintToString(message, &text));
EXPECT_EQ("optional_nested_message {\n // custom\n bb: 1\n}\n", text);
}
TEST_F(TextFormatTest, ParseBasic) {

Loading…
Cancel
Save