From 72977f951c1f20807534a9d9af4e9583c6a9049b Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Wed, 16 Nov 2022 13:33:05 -0800 Subject: [PATCH] Add ability for custom a MessagePrinter to default back to the default implementation. Prior to the CL we cannot use the printer as that would trigger a recursion. So the only way out would be to create a new printer that has the same abilities as the current printer. PiperOrigin-RevId: 489024852 --- src/google/protobuf/text_format.cc | 57 ++++++++++++--------- src/google/protobuf/text_format.h | 52 ++++++++++++++++--- src/google/protobuf/text_format_unittest.cc | 19 +++++-- 3 files changed, 93 insertions(+), 35 deletions(-) diff --git a/src/google/protobuf/text_format.cc b/src/google/protobuf/text_format.cc index 7ab3d488e7..b8ddd086e3 100644 --- a/src/google/protobuf/text_format.cc +++ b/src/google/protobuf/text_format.cc @@ -1470,15 +1470,15 @@ class TextFormat::Printer::TextGenerator // error.) bool failed() const { return failed_; } - void PrintMaybeWithMarker(absl::string_view text) { + void PrintMaybeWithMarker(MarkerToken, absl::string_view text) override { Print(text.data(), text.size()); if (ConsumeInsertSilentMarker()) { PrintLiteral(internal::kDebugStringSilentMarker); } } - void PrintMaybeWithMarker(absl::string_view text_head, - absl::string_view text_tail) { + void PrintMaybeWithMarker(MarkerToken, absl::string_view text_head, + absl::string_view text_tail) override { Print(text_head.data(), text_head.size()); if (ConsumeInsertSilentMarker()) { PrintLiteral(internal::kDebugStringSilentMarker); @@ -1577,13 +1577,10 @@ class TextFormat::Printer::DebugStringFieldValuePrinter void PrintMessageStart(const Message& /*message*/, int /*field_index*/, int /*field_count*/, bool single_line_mode, BaseTextGenerator* generator) const override { - // This is safe as only TextGenerator is used with - // DebugStringFieldValuePrinter. - TextGenerator* text_generator = static_cast(generator); if (single_line_mode) { - text_generator->PrintMaybeWithMarker(" ", "{ "); + generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ "); } else { - text_generator->PrintMaybeWithMarker(" ", "{\n"); + generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n"); } } }; @@ -2174,7 +2171,7 @@ struct FieldIndexSorter { } // namespace bool TextFormat::Printer::PrintAny(const Message& message, - TextGenerator* generator) const { + BaseTextGenerator* generator) const { const FieldDescriptor* type_url_field; const FieldDescriptor* value_field; if (!internal::GetAnyFieldDescriptors(message, &type_url_field, @@ -2222,7 +2219,7 @@ bool TextFormat::Printer::PrintAny(const Message& message, } void TextFormat::Printer::Print(const Message& message, - TextGenerator* generator) const { + BaseTextGenerator* generator) const { const Reflection* reflection = message.GetReflection(); if (!reflection) { // This message does not provide any way to describe its structure. @@ -2242,10 +2239,20 @@ void TextFormat::Printer::Print(const Message& message, itr->second->Print(message, single_line_mode_, generator); return; } + PrintMessage(message, generator); +} + +void TextFormat::Printer::PrintMessage(const Message& message, + BaseTextGenerator* generator) const { + if (generator == nullptr) { + return; + } + const Descriptor* descriptor = message.GetDescriptor(); if (descriptor->full_name() == internal::kAnyFullTypeName && expand_any_ && PrintAny(message, generator)) { return; } + const Reflection* reflection = message.GetReflection(); std::vector fields; if (descriptor->options().map_entry()) { fields.push_back(descriptor->field(0)); @@ -2460,7 +2467,7 @@ void MapFieldPrinterHelper::CopyValue(const MapValueRef& value, void TextFormat::Printer::PrintField(const Message& message, const Reflection* reflection, const FieldDescriptor* field, - TextGenerator* generator) const { + BaseTextGenerator* generator) const { if (use_short_repeated_primitives_ && field->is_repeated() && field->cpp_type() != FieldDescriptor::CPPTYPE_STRING && field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { @@ -2508,7 +2515,7 @@ void TextFormat::Printer::PrintField(const Message& message, printer->PrintMessageEnd(sub_message, field_index, count, single_line_mode_, generator); } else { - generator->PrintMaybeWithMarker(": "); + generator->PrintMaybeWithMarker(MarkerToken(), ": "); // Write the field value. PrintFieldValue(message, reflection, field, field_index, generator); if (single_line_mode_) { @@ -2528,12 +2535,12 @@ void TextFormat::Printer::PrintField(const Message& message, void TextFormat::Printer::PrintShortRepeatedField( const Message& message, const Reflection* reflection, - const FieldDescriptor* field, TextGenerator* generator) const { + const FieldDescriptor* field, BaseTextGenerator* generator) const { // Print primitive repeated field in short form. int size = reflection->FieldSize(message, field); PrintFieldName(message, /*field_index=*/-1, /*field_count=*/size, reflection, field, generator); - generator->PrintMaybeWithMarker(": ", "["); + generator->PrintMaybeWithMarker(MarkerToken(), ": ", "["); for (int i = 0; i < size; i++) { if (i > 0) generator->PrintLiteral(", "); PrintFieldValue(message, reflection, field, i, generator); @@ -2549,7 +2556,7 @@ void TextFormat::Printer::PrintFieldName(const Message& message, int field_index, int field_count, const Reflection* reflection, const FieldDescriptor* field, - TextGenerator* generator) const { + BaseTextGenerator* generator) const { // if use_field_number_ is true, prints field number instead // of field name. if (use_field_number_) { @@ -2566,7 +2573,7 @@ void TextFormat::Printer::PrintFieldValue(const Message& message, const Reflection* reflection, const FieldDescriptor* field, int index, - TextGenerator* generator) const { + BaseTextGenerator* generator) const { GOOGLE_DCHECK(field->is_repeated() || (index == -1)) << "Index must be -1 for non-repeated fields"; @@ -2687,7 +2694,7 @@ void TextFormat::Printer::PrintFieldValue(const Message& message, } void TextFormat::Printer::PrintUnknownFields( - const UnknownFieldSet& unknown_fields, TextGenerator* generator, + const UnknownFieldSet& unknown_fields, BaseTextGenerator* generator, int recursion_budget) const { for (int i = 0; i < unknown_fields.field_count(); i++) { const UnknownField& field = unknown_fields.field(i); @@ -2696,7 +2703,7 @@ void TextFormat::Printer::PrintUnknownFields( switch (field.type()) { case UnknownField::TYPE_VARINT: generator->PrintString(field_number); - generator->PrintMaybeWithMarker(": "); + generator->PrintMaybeWithMarker(MarkerToken(), ": "); generator->PrintString(absl::StrCat(field.varint())); if (single_line_mode_) { generator->PrintLiteral(" "); @@ -2706,7 +2713,7 @@ void TextFormat::Printer::PrintUnknownFields( break; case UnknownField::TYPE_FIXED32: { generator->PrintString(field_number); - generator->PrintMaybeWithMarker(": ", "0x"); + generator->PrintMaybeWithMarker(MarkerToken(), ": ", "0x"); generator->PrintString( absl::StrCat(absl::Hex(field.fixed32(), absl::kZeroPad8))); if (single_line_mode_) { @@ -2718,7 +2725,7 @@ void TextFormat::Printer::PrintUnknownFields( } case UnknownField::TYPE_FIXED64: { generator->PrintString(field_number); - generator->PrintMaybeWithMarker(": ", "0x"); + generator->PrintMaybeWithMarker(MarkerToken(), ": ", "0x"); generator->PrintString( absl::StrCat(absl::Hex(field.fixed64(), absl::kZeroPad16))); if (single_line_mode_) { @@ -2743,9 +2750,9 @@ void TextFormat::Printer::PrintUnknownFields( // This field is parseable as a Message. // So it is probably an embedded message. if (single_line_mode_) { - generator->PrintMaybeWithMarker(" ", "{ "); + generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ "); } else { - generator->PrintMaybeWithMarker(" ", "{\n"); + generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n"); generator->Indent(); } PrintUnknownFields(embedded_unknown_fields, generator, @@ -2759,7 +2766,7 @@ void TextFormat::Printer::PrintUnknownFields( } else { // This field is not parseable as a Message (or we ran out of // recursion budget). So it is probably just a plain string. - generator->PrintMaybeWithMarker(": ", "\""); + generator->PrintMaybeWithMarker(MarkerToken(), ": ", "\""); generator->PrintString(absl::CEscape(value)); if (single_line_mode_) { generator->PrintLiteral("\" "); @@ -2772,9 +2779,9 @@ void TextFormat::Printer::PrintUnknownFields( case UnknownField::TYPE_GROUP: generator->PrintString(field_number); if (single_line_mode_) { - generator->PrintMaybeWithMarker(" ", "{ "); + generator->PrintMaybeWithMarker(MarkerToken(), " ", "{ "); } else { - generator->PrintMaybeWithMarker(" ", "{\n"); + generator->PrintMaybeWithMarker(MarkerToken(), " ", "{\n"); generator->Indent(); } // For groups, we recurse without checking the budget. This is OK, diff --git a/src/google/protobuf/text_format.h b/src/google/protobuf/text_format.h index fbcbea4593..2cd4024d65 100644 --- a/src/google/protobuf/text_format.h +++ b/src/google/protobuf/text_format.h @@ -115,7 +115,22 @@ class PROTOBUF_EXPORT TextFormat { const FieldDescriptor* field, int index, std::string* output); + // Forward declare `Printer` for `BaseTextGenerator::MarkerToken` which + // restricts some methods of `BaseTextGenerator` to the class `Printer`. + class Printer; + class PROTOBUF_EXPORT BaseTextGenerator { + private: + // Passkey (go/totw/134#what-about-stdshared-ptr) that allows `Printer` + // (but not derived classes) to call `PrintMaybeWithMarker` and its + // `Printer::TextGenerator` to overload it. + // This prevents users from bypassing the marker generation. + class MarkerToken { + private: + explicit MarkerToken() = default; // 'explicit' prevents aggregate init. + friend class Printer; + }; + public: virtual ~BaseTextGenerator(); @@ -133,6 +148,20 @@ class PROTOBUF_EXPORT TextFormat { void PrintLiteral(const char (&text)[n]) { Print(text, n - 1); // n includes the terminating zero character. } + + // Internal to Printer, access regulated by `MarkerToken`. + virtual void PrintMaybeWithMarker(MarkerToken, absl::string_view text) { + Print(text.data(), text.size()); + } + + // Internal to Printer, access regulated by `MarkerToken`. + virtual void PrintMaybeWithMarker(MarkerToken, absl::string_view text_head, + absl::string_view text_tail) { + Print(text_head.data(), text_head.size()); + Print(text_tail.data(), text_tail.size()); + } + + friend class Printer; }; // The default printer that converts scalar values from fields into their @@ -381,6 +410,14 @@ class PROTOBUF_EXPORT TextFormat { bool RegisterMessagePrinter(const Descriptor* descriptor, const MessagePrinter* printer); + // Default printing for messages, which allows registered message printers + // to fall back to default printing without losing the ability to control + // sub-messages or fields. + // NOTE: If the passed in `text_generaor` is not actually the current + // `TextGenerator`, then no output will be produced. + void PrintMessage(const Message& message, + BaseTextGenerator* generator) const; + private: friend std::string Message::DebugString() const; friend std::string Message::ShortDebugString() const; @@ -404,6 +441,7 @@ class PROTOBUF_EXPORT TextFormat { // Forward declaration of an internal class used to print the text // output to the OutputStream (see text_format.cc for implementation). class TextGenerator; + using MarkerToken = BaseTextGenerator::MarkerToken; // Forward declaration of an internal class used to print field values for // DebugString APIs (see text_format.cc for implementation). @@ -417,40 +455,40 @@ class PROTOBUF_EXPORT TextFormat { // Internal Print method, used for writing to the OutputStream via // the TextGenerator class. - void Print(const Message& message, TextGenerator* generator) const; + void Print(const Message& message, BaseTextGenerator* generator) const; // Print a single field. void PrintField(const Message& message, const Reflection* reflection, const FieldDescriptor* field, - TextGenerator* generator) const; + BaseTextGenerator* generator) const; // Print a repeated primitive field in short form. void PrintShortRepeatedField(const Message& message, const Reflection* reflection, const FieldDescriptor* field, - TextGenerator* generator) const; + BaseTextGenerator* generator) const; // Print the name of a field -- i.e. everything that comes before the // ':' for a single name/value pair. void PrintFieldName(const Message& message, int field_index, int field_count, const Reflection* reflection, const FieldDescriptor* field, - TextGenerator* generator) const; + BaseTextGenerator* generator) const; // Outputs a textual representation of the value of the field supplied on // the message supplied or the default value if not set. void PrintFieldValue(const Message& message, const Reflection* reflection, const FieldDescriptor* field, int index, - TextGenerator* generator) const; + BaseTextGenerator* generator) const; // Print the fields in an UnknownFieldSet. They are printed by tag number // only. Embedded messages are heuristically identified by attempting to // parse them (subject to the recursion budget). void PrintUnknownFields(const UnknownFieldSet& unknown_fields, - TextGenerator* generator, + BaseTextGenerator* generator, int recursion_budget) const; - bool PrintAny(const Message& message, TextGenerator* generator) const; + bool PrintAny(const Message& message, BaseTextGenerator* generator) const; const FastFieldValuePrinter* GetFieldPrinter( const FieldDescriptor* field) const { diff --git a/src/google/protobuf/text_format_unittest.cc b/src/google/protobuf/text_format_unittest.cc index 6dcc044722..1bca2dfcb2 100644 --- a/src/google/protobuf/text_format_unittest.cc +++ b/src/google/protobuf/text_format_unittest.cc @@ -780,15 +780,24 @@ class CustomNestedMessagePrinter : public TextFormat::MessagePrinter { ~CustomNestedMessagePrinter() override {} void Print(const Message& message, bool single_line_mode, TextFormat::BaseTextGenerator* generator) const override { - generator->PrintLiteral("custom"); + generator->PrintLiteral("// custom\n"); + if (printer_ != nullptr) { + printer_->PrintMessage(message, generator); + } } + + void SetPrinter(TextFormat::Printer* printer) { printer_ = printer; } + + private: + TextFormat::Printer* printer_ = nullptr; }; TEST_F(TextFormatTest, CustomMessagePrinter) { TextFormat::Printer printer; + auto* custom_printer = new CustomNestedMessagePrinter; printer.RegisterMessagePrinter( unittest::TestAllTypes::NestedMessage::default_instance().descriptor(), - new CustomNestedMessagePrinter); + custom_printer); unittest::TestAllTypes message; std::string text; @@ -797,7 +806,11 @@ TEST_F(TextFormatTest, CustomMessagePrinter) { message.mutable_optional_nested_message()->set_bb(1); EXPECT_TRUE(printer.PrintToString(message, &text)); - EXPECT_EQ("optional_nested_message {\n custom}\n", text); + EXPECT_EQ("optional_nested_message {\n // custom\n}\n", text); + + custom_printer->SetPrinter(&printer); + EXPECT_TRUE(printer.PrintToString(message, &text)); + EXPECT_EQ("optional_nested_message {\n // custom\n bb: 1\n}\n", text); } TEST_F(TextFormatTest, ParseBasic) {