From 02248cbe38d1b585d1f2bee623d3d83c99dfbc4d Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Fri, 28 Jul 2023 13:22:20 -0700 Subject: [PATCH] Unify LazyParseMode and LazyVerifyOption to ensure eager parsing on verification failure. PiperOrigin-RevId: 551934633 --- .../compiler/cpp/parse_function_generator.cc | 28 +++++++++++-------- src/google/protobuf/extension_set.h | 5 ++-- src/google/protobuf/message_unittest.inc | 23 ++++++++++++++- src/google/protobuf/unittest_mset.proto | 1 + 4 files changed, 41 insertions(+), 16 deletions(-) diff --git a/src/google/protobuf/compiler/cpp/parse_function_generator.cc b/src/google/protobuf/compiler/cpp/parse_function_generator.cc index 66d2ea5493..2b82b154a4 100644 --- a/src/google/protobuf/compiler/cpp/parse_function_generator.cc +++ b/src/google/protobuf/compiler/cpp/parse_function_generator.cc @@ -1043,12 +1043,15 @@ void ParseFunctionGenerator::GenerateLengthDelim(Formatter& format, bool eager_verify = IsEagerlyVerifiedLazy(field, options_, scc_analyzer_); if (ShouldVerify(descriptor_, options_, scc_analyzer_)) { - format( - "ctx->set_lazy_eager_verify_func($1$);\n", - eager_verify - ? absl::StrCat("&", ClassName(field->message_type(), true), - "::InternalVerify") - : "nullptr"); + if (eager_verify) { + format("ctx->set_lazy_eager_verify_func(&$1$::InternalVerify);\n", + ClassName(field->message_type(), true)); + } else { + format( + "ctx->set_lazy_eager_verify_func(nullptr);\n" + "auto old_mode = " + "ctx->set_lazy_parse_mode(::_pbi::ParseContext::kLazy);\n"); + } } if (field->real_containing_oneof()) { format( @@ -1074,14 +1077,15 @@ void ParseFunctionGenerator::GenerateLengthDelim(Formatter& format, " ::$proto_ns$::internal::LazyField> parse_helper(\n" " $1$::default_instance(),\n" " $msg$GetArenaForAllocation(),\n" - " ::google::protobuf::internal::LazyVerifyOption::$2$,\n" " lazy_field);\n" "ptr = ctx->ParseMessage(&parse_helper, ptr);\n", - FieldMessageTypeName(field, options_), - eager_verify ? "kEager" : "kLazy"); - if (ShouldVerify(descriptor_, options_, scc_analyzer_) && - eager_verify) { - format("ctx->set_lazy_eager_verify_func(nullptr);\n"); + FieldMessageTypeName(field, options_)); + if (ShouldVerify(descriptor_, options_, scc_analyzer_)) { + if (eager_verify) { + format("ctx->set_lazy_eager_verify_func(nullptr);\n"); + } else { + format("(void)ctx->set_lazy_parse_mode(old_mode);\n"); + } } } else if (IsImplicitWeakField(field, options_, scc_analyzer_)) { if (!field->is_repeated()) { diff --git a/src/google/protobuf/extension_set.h b/src/google/protobuf/extension_set.h index f6070ff2ab..cfa415ef8d 100644 --- a/src/google/protobuf/extension_set.h +++ b/src/google/protobuf/extension_set.h @@ -81,7 +81,6 @@ class FeatureSet; namespace internal { class FieldSkipper; // wire_format_lite.h class WireFormat; -enum class LazyVerifyOption; } // namespace internal } // namespace protobuf } // namespace google @@ -594,8 +593,8 @@ class PROTOBUF_EXPORT ExtensionSet { virtual void Clear() = 0; virtual const char* _InternalParse(const MessageLite& prototype, - Arena* arena, LazyVerifyOption option, - const char* ptr, ParseContext* ctx) = 0; + Arena* arena, const char* ptr, + ParseContext* ctx) = 0; virtual uint8_t* WriteMessageToArray( const MessageLite* prototype, int number, uint8_t* target, io::EpsCopyOutputStream* stream) const = 0; diff --git a/src/google/protobuf/message_unittest.inc b/src/google/protobuf/message_unittest.inc index 4aaf7ee53f..7b3a3120f6 100644 --- a/src/google/protobuf/message_unittest.inc +++ b/src/google/protobuf/message_unittest.inc @@ -328,7 +328,7 @@ TEST(MESSAGE_TEST_NAME, ExplicitLazyExceedRecursionLimit) { ->mutable_payload() ->set_optional_int32(-1); std::string serialized; - EXPECT_TRUE(original.SerializeToString(&serialized)); + ASSERT_TRUE(original.SerializeToString(&serialized)); // User annotated LazyField ([lazy = true]) is eagerly verified and should // catch the recursion limit violation. @@ -342,6 +342,27 @@ TEST(MESSAGE_TEST_NAME, ExplicitLazyExceedRecursionLimit) { EXPECT_NE(parsed.lazy_child().child().payload().optional_int32(), -1); } +TEST(MESSAGE_TEST_NAME, NestedLazyRecursionLimit) { + UNITTEST::NestedTestAllTypes original, parsed; + original.mutable_lazy_child() + ->mutable_lazy_child() + ->mutable_lazy_child() + ->mutable_payload() + ->set_optional_int32(-1); + std::string serialized; + ASSERT_TRUE(original.SerializeToString(&serialized)); + ASSERT_TRUE(parsed.ParseFromString(serialized)); + + io::ArrayInputStream array_stream(serialized.data(), serialized.size()); + io::CodedInputStream input_stream(&array_stream); + input_stream.SetRecursionLimit(2); + EXPECT_FALSE(parsed.ParseFromCodedStream(&input_stream)); + EXPECT_TRUE(parsed.has_lazy_child()); + EXPECT_TRUE(parsed.lazy_child().has_lazy_child()); + EXPECT_TRUE(parsed.lazy_child().lazy_child().has_lazy_child()); + EXPECT_FALSE(parsed.lazy_child().lazy_child().lazy_child().has_payload()); +} + TEST(MESSAGE_TEST_NAME, UnparsedEmpty) { // lazy_child, LEN=100 with no payload. const char encoded[] = {'\042', 100}; diff --git a/src/google/protobuf/unittest_mset.proto b/src/google/protobuf/unittest_mset.proto index b95acace58..72c49827e9 100644 --- a/src/google/protobuf/unittest_mset.proto +++ b/src/google/protobuf/unittest_mset.proto @@ -51,6 +51,7 @@ message TestMessageSetContainer { message NestedTestMessageSetContainer { optional TestMessageSetContainer container = 1; optional NestedTestMessageSetContainer child = 2; + optional NestedTestMessageSetContainer lazy_child = 3 [lazy = true]; } message NestedTestInt {