Add ability for custom a MessagePrinter to default back to the default implementation.

Prior to the CL we cannot use the printer as that would trigger a recursion. So the only way out would be to create a new printer that has the same abilities as the current printer.

PiperOrigin-RevId: 489024852
pull/11008/head
Protobuf Team Bot 2 years ago committed by Copybara-Service
parent c862c1cab4
commit 72977f951c
  1. 57
      src/google/protobuf/text_format.cc
  2. 52
      src/google/protobuf/text_format.h
  3. 19
      src/google/protobuf/text_format_unittest.cc

@ -1470,15 +1470,15 @@ class TextFormat::Printer::TextGenerator
// error.) // error.)
bool failed() const { return failed_; } bool failed() const { return failed_; }
void PrintMaybeWithMarker(absl::string_view text) { void PrintMaybeWithMarker(MarkerToken, absl::string_view text) override {
Print(text.data(), text.size()); Print(text.data(), text.size());
if (ConsumeInsertSilentMarker()) { if (ConsumeInsertSilentMarker()) {
PrintLiteral(internal::kDebugStringSilentMarker); PrintLiteral(internal::kDebugStringSilentMarker);
} }
} }
void PrintMaybeWithMarker(absl::string_view text_head, void PrintMaybeWithMarker(MarkerToken, absl::string_view text_head,
absl::string_view text_tail) { absl::string_view text_tail) override {
Print(text_head.data(), text_head.size()); Print(text_head.data(), text_head.size());
if (ConsumeInsertSilentMarker()) { if (ConsumeInsertSilentMarker()) {
PrintLiteral(internal::kDebugStringSilentMarker); PrintLiteral(internal::kDebugStringSilentMarker);
@ -1577,13 +1577,10 @@ class TextFormat::Printer::DebugStringFieldValuePrinter
void PrintMessageStart(const Message& /*message*/, int /*field_index*/, void PrintMessageStart(const Message& /*message*/, int /*field_index*/,
int /*field_count*/, bool single_line_mode, int /*field_count*/, bool single_line_mode,
BaseTextGenerator* generator) const override { BaseTextGenerator* generator) const override {
// This is safe as only TextGenerator is used with
// DebugStringFieldValuePrinter.
TextGenerator* text_generator = static_cast<TextGenerator*>(generator);
if (single_line_mode) { if (single_line_mode) {
text_generator->PrintMaybeWithMarker(" ", "{ "); generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
} else { } else {
text_generator->PrintMaybeWithMarker(" ", "{\n"); generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
} }
} }
}; };
@ -2174,7 +2171,7 @@ struct FieldIndexSorter {
} // namespace } // namespace
bool TextFormat::Printer::PrintAny(const Message& message, bool TextFormat::Printer::PrintAny(const Message& message,
TextGenerator* generator) const { BaseTextGenerator* generator) const {
const FieldDescriptor* type_url_field; const FieldDescriptor* type_url_field;
const FieldDescriptor* value_field; const FieldDescriptor* value_field;
if (!internal::GetAnyFieldDescriptors(message, &type_url_field, if (!internal::GetAnyFieldDescriptors(message, &type_url_field,
@ -2222,7 +2219,7 @@ bool TextFormat::Printer::PrintAny(const Message& message,
} }
void TextFormat::Printer::Print(const Message& message, void TextFormat::Printer::Print(const Message& message,
TextGenerator* generator) const { BaseTextGenerator* generator) const {
const Reflection* reflection = message.GetReflection(); const Reflection* reflection = message.GetReflection();
if (!reflection) { if (!reflection) {
// This message does not provide any way to describe its structure. // This message does not provide any way to describe its structure.
@ -2242,10 +2239,20 @@ void TextFormat::Printer::Print(const Message& message,
itr->second->Print(message, single_line_mode_, generator); itr->second->Print(message, single_line_mode_, generator);
return; return;
} }
PrintMessage(message, generator);
}
void TextFormat::Printer::PrintMessage(const Message& message,
BaseTextGenerator* generator) const {
if (generator == nullptr) {
return;
}
const Descriptor* descriptor = message.GetDescriptor();
if (descriptor->full_name() == internal::kAnyFullTypeName && expand_any_ && if (descriptor->full_name() == internal::kAnyFullTypeName && expand_any_ &&
PrintAny(message, generator)) { PrintAny(message, generator)) {
return; return;
} }
const Reflection* reflection = message.GetReflection();
std::vector<const FieldDescriptor*> fields; std::vector<const FieldDescriptor*> fields;
if (descriptor->options().map_entry()) { if (descriptor->options().map_entry()) {
fields.push_back(descriptor->field(0)); fields.push_back(descriptor->field(0));
@ -2460,7 +2467,7 @@ void MapFieldPrinterHelper::CopyValue(const MapValueRef& value,
void TextFormat::Printer::PrintField(const Message& message, void TextFormat::Printer::PrintField(const Message& message,
const Reflection* reflection, const Reflection* reflection,
const FieldDescriptor* field, const FieldDescriptor* field,
TextGenerator* generator) const { BaseTextGenerator* generator) const {
if (use_short_repeated_primitives_ && field->is_repeated() && if (use_short_repeated_primitives_ && field->is_repeated() &&
field->cpp_type() != FieldDescriptor::CPPTYPE_STRING && field->cpp_type() != FieldDescriptor::CPPTYPE_STRING &&
field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
@ -2508,7 +2515,7 @@ void TextFormat::Printer::PrintField(const Message& message,
printer->PrintMessageEnd(sub_message, field_index, count, printer->PrintMessageEnd(sub_message, field_index, count,
single_line_mode_, generator); single_line_mode_, generator);
} else { } else {
generator->PrintMaybeWithMarker(": "); generator->PrintMaybeWithMarker(MarkerToken(), ": ");
// Write the field value. // Write the field value.
PrintFieldValue(message, reflection, field, field_index, generator); PrintFieldValue(message, reflection, field, field_index, generator);
if (single_line_mode_) { if (single_line_mode_) {
@ -2528,12 +2535,12 @@ void TextFormat::Printer::PrintField(const Message& message,
void TextFormat::Printer::PrintShortRepeatedField( void TextFormat::Printer::PrintShortRepeatedField(
const Message& message, const Reflection* reflection, const Message& message, const Reflection* reflection,
const FieldDescriptor* field, TextGenerator* generator) const { const FieldDescriptor* field, BaseTextGenerator* generator) const {
// Print primitive repeated field in short form. // Print primitive repeated field in short form.
int size = reflection->FieldSize(message, field); int size = reflection->FieldSize(message, field);
PrintFieldName(message, /*field_index=*/-1, /*field_count=*/size, reflection, PrintFieldName(message, /*field_index=*/-1, /*field_count=*/size, reflection,
field, generator); field, generator);
generator->PrintMaybeWithMarker(": ", "["); generator->PrintMaybeWithMarker(MarkerToken(), ": ", "[");
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
if (i > 0) generator->PrintLiteral(", "); if (i > 0) generator->PrintLiteral(", ");
PrintFieldValue(message, reflection, field, i, generator); PrintFieldValue(message, reflection, field, i, generator);
@ -2549,7 +2556,7 @@ void TextFormat::Printer::PrintFieldName(const Message& message,
int field_index, int field_count, int field_index, int field_count,
const Reflection* reflection, const Reflection* reflection,
const FieldDescriptor* field, const FieldDescriptor* field,
TextGenerator* generator) const { BaseTextGenerator* generator) const {
// if use_field_number_ is true, prints field number instead // if use_field_number_ is true, prints field number instead
// of field name. // of field name.
if (use_field_number_) { if (use_field_number_) {
@ -2566,7 +2573,7 @@ void TextFormat::Printer::PrintFieldValue(const Message& message,
const Reflection* reflection, const Reflection* reflection,
const FieldDescriptor* field, const FieldDescriptor* field,
int index, int index,
TextGenerator* generator) const { BaseTextGenerator* generator) const {
GOOGLE_DCHECK(field->is_repeated() || (index == -1)) GOOGLE_DCHECK(field->is_repeated() || (index == -1))
<< "Index must be -1 for non-repeated fields"; << "Index must be -1 for non-repeated fields";
@ -2687,7 +2694,7 @@ void TextFormat::Printer::PrintFieldValue(const Message& message,
} }
void TextFormat::Printer::PrintUnknownFields( void TextFormat::Printer::PrintUnknownFields(
const UnknownFieldSet& unknown_fields, TextGenerator* generator, const UnknownFieldSet& unknown_fields, BaseTextGenerator* generator,
int recursion_budget) const { int recursion_budget) const {
for (int i = 0; i < unknown_fields.field_count(); i++) { for (int i = 0; i < unknown_fields.field_count(); i++) {
const UnknownField& field = unknown_fields.field(i); const UnknownField& field = unknown_fields.field(i);
@ -2696,7 +2703,7 @@ void TextFormat::Printer::PrintUnknownFields(
switch (field.type()) { switch (field.type()) {
case UnknownField::TYPE_VARINT: case UnknownField::TYPE_VARINT:
generator->PrintString(field_number); generator->PrintString(field_number);
generator->PrintMaybeWithMarker(": "); generator->PrintMaybeWithMarker(MarkerToken(), ": ");
generator->PrintString(absl::StrCat(field.varint())); generator->PrintString(absl::StrCat(field.varint()));
if (single_line_mode_) { if (single_line_mode_) {
generator->PrintLiteral(" "); generator->PrintLiteral(" ");
@ -2706,7 +2713,7 @@ void TextFormat::Printer::PrintUnknownFields(
break; break;
case UnknownField::TYPE_FIXED32: { case UnknownField::TYPE_FIXED32: {
generator->PrintString(field_number); generator->PrintString(field_number);
generator->PrintMaybeWithMarker(": ", "0x"); generator->PrintMaybeWithMarker(MarkerToken(), ": ", "0x");
generator->PrintString( generator->PrintString(
absl::StrCat(absl::Hex(field.fixed32(), absl::kZeroPad8))); absl::StrCat(absl::Hex(field.fixed32(), absl::kZeroPad8)));
if (single_line_mode_) { if (single_line_mode_) {
@ -2718,7 +2725,7 @@ void TextFormat::Printer::PrintUnknownFields(
} }
case UnknownField::TYPE_FIXED64: { case UnknownField::TYPE_FIXED64: {
generator->PrintString(field_number); generator->PrintString(field_number);
generator->PrintMaybeWithMarker(": ", "0x"); generator->PrintMaybeWithMarker(MarkerToken(), ": ", "0x");
generator->PrintString( generator->PrintString(
absl::StrCat(absl::Hex(field.fixed64(), absl::kZeroPad16))); absl::StrCat(absl::Hex(field.fixed64(), absl::kZeroPad16)));
if (single_line_mode_) { if (single_line_mode_) {
@ -2743,9 +2750,9 @@ void TextFormat::Printer::PrintUnknownFields(
// This field is parseable as a Message. // This field is parseable as a Message.
// So it is probably an embedded message. // So it is probably an embedded message.
if (single_line_mode_) { if (single_line_mode_) {
generator->PrintMaybeWithMarker(" ", "{ "); generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
} else { } else {
generator->PrintMaybeWithMarker(" ", "{\n"); generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
generator->Indent(); generator->Indent();
} }
PrintUnknownFields(embedded_unknown_fields, generator, PrintUnknownFields(embedded_unknown_fields, generator,
@ -2759,7 +2766,7 @@ void TextFormat::Printer::PrintUnknownFields(
} else { } else {
// This field is not parseable as a Message (or we ran out of // This field is not parseable as a Message (or we ran out of
// recursion budget). So it is probably just a plain string. // recursion budget). So it is probably just a plain string.
generator->PrintMaybeWithMarker(": ", "\""); generator->PrintMaybeWithMarker(MarkerToken(), ": ", "\"");
generator->PrintString(absl::CEscape(value)); generator->PrintString(absl::CEscape(value));
if (single_line_mode_) { if (single_line_mode_) {
generator->PrintLiteral("\" "); generator->PrintLiteral("\" ");
@ -2772,9 +2779,9 @@ void TextFormat::Printer::PrintUnknownFields(
case UnknownField::TYPE_GROUP: case UnknownField::TYPE_GROUP:
generator->PrintString(field_number); generator->PrintString(field_number);
if (single_line_mode_) { if (single_line_mode_) {
generator->PrintMaybeWithMarker(" ", "{ "); generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ ");
} else { } else {
generator->PrintMaybeWithMarker(" ", "{\n"); generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n");
generator->Indent(); generator->Indent();
} }
// For groups, we recurse without checking the budget. This is OK, // For groups, we recurse without checking the budget. This is OK,

@ -115,7 +115,22 @@ class PROTOBUF_EXPORT TextFormat {
const FieldDescriptor* field, int index, const FieldDescriptor* field, int index,
std::string* output); std::string* output);
// Forward declare `Printer` for `BaseTextGenerator::MarkerToken` which
// restricts some methods of `BaseTextGenerator` to the class `Printer`.
class Printer;
class PROTOBUF_EXPORT BaseTextGenerator { class PROTOBUF_EXPORT BaseTextGenerator {
private:
// Passkey (go/totw/134#what-about-stdshared-ptr) that allows `Printer`
// (but not derived classes) to call `PrintMaybeWithMarker` and its
// `Printer::TextGenerator` to overload it.
// This prevents users from bypassing the marker generation.
class MarkerToken {
private:
explicit MarkerToken() = default; // 'explicit' prevents aggregate init.
friend class Printer;
};
public: public:
virtual ~BaseTextGenerator(); virtual ~BaseTextGenerator();
@ -133,6 +148,20 @@ class PROTOBUF_EXPORT TextFormat {
void PrintLiteral(const char (&text)[n]) { void PrintLiteral(const char (&text)[n]) {
Print(text, n - 1); // n includes the terminating zero character. Print(text, n - 1); // n includes the terminating zero character.
} }
// Internal to Printer, access regulated by `MarkerToken`.
virtual void PrintMaybeWithMarker(MarkerToken, absl::string_view text) {
Print(text.data(), text.size());
}
// Internal to Printer, access regulated by `MarkerToken`.
virtual void PrintMaybeWithMarker(MarkerToken, absl::string_view text_head,
absl::string_view text_tail) {
Print(text_head.data(), text_head.size());
Print(text_tail.data(), text_tail.size());
}
friend class Printer;
}; };
// The default printer that converts scalar values from fields into their // The default printer that converts scalar values from fields into their
@ -381,6 +410,14 @@ class PROTOBUF_EXPORT TextFormat {
bool RegisterMessagePrinter(const Descriptor* descriptor, bool RegisterMessagePrinter(const Descriptor* descriptor,
const MessagePrinter* printer); const MessagePrinter* printer);
// Default printing for messages, which allows registered message printers
// to fall back to default printing without losing the ability to control
// sub-messages or fields.
// NOTE: If the passed in `text_generaor` is not actually the current
// `TextGenerator`, then no output will be produced.
void PrintMessage(const Message& message,
BaseTextGenerator* generator) const;
private: private:
friend std::string Message::DebugString() const; friend std::string Message::DebugString() const;
friend std::string Message::ShortDebugString() const; friend std::string Message::ShortDebugString() const;
@ -404,6 +441,7 @@ class PROTOBUF_EXPORT TextFormat {
// Forward declaration of an internal class used to print the text // Forward declaration of an internal class used to print the text
// output to the OutputStream (see text_format.cc for implementation). // output to the OutputStream (see text_format.cc for implementation).
class TextGenerator; class TextGenerator;
using MarkerToken = BaseTextGenerator::MarkerToken;
// Forward declaration of an internal class used to print field values for // Forward declaration of an internal class used to print field values for
// DebugString APIs (see text_format.cc for implementation). // DebugString APIs (see text_format.cc for implementation).
@ -417,40 +455,40 @@ class PROTOBUF_EXPORT TextFormat {
// Internal Print method, used for writing to the OutputStream via // Internal Print method, used for writing to the OutputStream via
// the TextGenerator class. // the TextGenerator class.
void Print(const Message& message, TextGenerator* generator) const; void Print(const Message& message, BaseTextGenerator* generator) const;
// Print a single field. // Print a single field.
void PrintField(const Message& message, const Reflection* reflection, void PrintField(const Message& message, const Reflection* reflection,
const FieldDescriptor* field, const FieldDescriptor* field,
TextGenerator* generator) const; BaseTextGenerator* generator) const;
// Print a repeated primitive field in short form. // Print a repeated primitive field in short form.
void PrintShortRepeatedField(const Message& message, void PrintShortRepeatedField(const Message& message,
const Reflection* reflection, const Reflection* reflection,
const FieldDescriptor* field, const FieldDescriptor* field,
TextGenerator* generator) const; BaseTextGenerator* generator) const;
// Print the name of a field -- i.e. everything that comes before the // Print the name of a field -- i.e. everything that comes before the
// ':' for a single name/value pair. // ':' for a single name/value pair.
void PrintFieldName(const Message& message, int field_index, void PrintFieldName(const Message& message, int field_index,
int field_count, const Reflection* reflection, int field_count, const Reflection* reflection,
const FieldDescriptor* field, const FieldDescriptor* field,
TextGenerator* generator) const; BaseTextGenerator* generator) const;
// Outputs a textual representation of the value of the field supplied on // Outputs a textual representation of the value of the field supplied on
// the message supplied or the default value if not set. // the message supplied or the default value if not set.
void PrintFieldValue(const Message& message, const Reflection* reflection, void PrintFieldValue(const Message& message, const Reflection* reflection,
const FieldDescriptor* field, int index, const FieldDescriptor* field, int index,
TextGenerator* generator) const; BaseTextGenerator* generator) const;
// Print the fields in an UnknownFieldSet. They are printed by tag number // Print the fields in an UnknownFieldSet. They are printed by tag number
// only. Embedded messages are heuristically identified by attempting to // only. Embedded messages are heuristically identified by attempting to
// parse them (subject to the recursion budget). // parse them (subject to the recursion budget).
void PrintUnknownFields(const UnknownFieldSet& unknown_fields, void PrintUnknownFields(const UnknownFieldSet& unknown_fields,
TextGenerator* generator, BaseTextGenerator* generator,
int recursion_budget) const; int recursion_budget) const;
bool PrintAny(const Message& message, TextGenerator* generator) const; bool PrintAny(const Message& message, BaseTextGenerator* generator) const;
const FastFieldValuePrinter* GetFieldPrinter( const FastFieldValuePrinter* GetFieldPrinter(
const FieldDescriptor* field) const { const FieldDescriptor* field) const {

@ -780,15 +780,24 @@ class CustomNestedMessagePrinter : public TextFormat::MessagePrinter {
~CustomNestedMessagePrinter() override {} ~CustomNestedMessagePrinter() override {}
void Print(const Message& message, bool single_line_mode, void Print(const Message& message, bool single_line_mode,
TextFormat::BaseTextGenerator* generator) const override { TextFormat::BaseTextGenerator* generator) const override {
generator->PrintLiteral("custom"); generator->PrintLiteral("// custom\n");
if (printer_ != nullptr) {
printer_->PrintMessage(message, generator);
}
} }
void SetPrinter(TextFormat::Printer* printer) { printer_ = printer; }
private:
TextFormat::Printer* printer_ = nullptr;
}; };
TEST_F(TextFormatTest, CustomMessagePrinter) { TEST_F(TextFormatTest, CustomMessagePrinter) {
TextFormat::Printer printer; TextFormat::Printer printer;
auto* custom_printer = new CustomNestedMessagePrinter;
printer.RegisterMessagePrinter( printer.RegisterMessagePrinter(
unittest::TestAllTypes::NestedMessage::default_instance().descriptor(), unittest::TestAllTypes::NestedMessage::default_instance().descriptor(),
new CustomNestedMessagePrinter); custom_printer);
unittest::TestAllTypes message; unittest::TestAllTypes message;
std::string text; std::string text;
@ -797,7 +806,11 @@ TEST_F(TextFormatTest, CustomMessagePrinter) {
message.mutable_optional_nested_message()->set_bb(1); message.mutable_optional_nested_message()->set_bb(1);
EXPECT_TRUE(printer.PrintToString(message, &text)); EXPECT_TRUE(printer.PrintToString(message, &text));
EXPECT_EQ("optional_nested_message {\n custom}\n", text); EXPECT_EQ("optional_nested_message {\n // custom\n}\n", text);
custom_printer->SetPrinter(&printer);
EXPECT_TRUE(printer.PrintToString(message, &text));
EXPECT_EQ("optional_nested_message {\n // custom\n bb: 1\n}\n", text);
} }
TEST_F(TextFormatTest, ParseBasic) { TEST_F(TextFormatTest, ParseBasic) {

Loading…
Cancel
Save