diff --git a/src/google/protobuf/compiler/BUILD.bazel b/src/google/protobuf/compiler/BUILD.bazel index cee2678cff..1631d0c8f1 100644 --- a/src/google/protobuf/compiler/BUILD.bazel +++ b/src/google/protobuf/compiler/BUILD.bazel @@ -52,6 +52,7 @@ cc_library( "//src/google/protobuf/io:io_win32", "//src/google/protobuf/io:tokenizer", "@com_google_absl//absl/base", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:absl_check", diff --git a/src/google/protobuf/compiler/parser.cc b/src/google/protobuf/compiler/parser.cc index 18b525512a..f6f7d1d509 100644 --- a/src/google/protobuf/compiler/parser.cc +++ b/src/google/protobuf/compiler/parser.cc @@ -24,6 +24,7 @@ #include #include "absl/base/casts.h" +#include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/absl_check.h" @@ -774,6 +775,8 @@ bool Parser::ParseTopLevelStatement(FileDescriptorProto* file, LocationRecorder location(root_location, FileDescriptorProto::kMessageTypeFieldNumber, file->message_type_size()); + // Start each top-level message with the DescriptorPool's maximum depth. 
+ recursion_depth_ = internal::cpp::MaxMessageDeclarationNestingDepth(); return ParseMessageDefinition(file->add_message_type(), location, file); } else if (LookingAt("enum")) { LocationRecorder location(root_location, @@ -851,6 +854,12 @@ PROTOBUF_NOINLINE static void GenerateSyntheticOneofs( bool Parser::ParseMessageDefinition( DescriptorProto* message, const LocationRecorder& message_location, const FileDescriptorProto* containing_file) { + const auto undo_depth = absl::MakeCleanup([&] { ++recursion_depth_; }); + if (--recursion_depth_ <= 0) { + RecordError("Reached maximum recursion limit for nested messages."); + return false; + } + DO(Consume("message")); { LocationRecorder location(message_location, diff --git a/src/google/protobuf/compiler/parser.h b/src/google/protobuf/compiler/parser.h index 0762b9acac..f1e709d1bc 100644 --- a/src/google/protobuf/compiler/parser.h +++ b/src/google/protobuf/compiler/parser.h @@ -563,6 +563,9 @@ class PROTOBUF_EXPORT Parser final { bool stop_after_syntax_identifier_; std::string syntax_identifier_; Edition edition_ = Edition::EDITION_UNKNOWN; + // Remaining nesting budget for message declarations. Initialized here so + // that no parse path can read it before it has a value. + int recursion_depth_ = internal::cpp::MaxMessageDeclarationNestingDepth(); // Leading doc comments for the next declaration. These are not complete // yet; use ConsumeEndOfDeclaration() to get the complete comments. 
diff --git a/src/google/protobuf/compiler/parser_unittest.cc b/src/google/protobuf/compiler/parser_unittest.cc index ee2f6b4bb0..46c1e2e5ff 100644 --- a/src/google/protobuf/compiler/parser_unittest.cc +++ b/src/google/protobuf/compiler/parser_unittest.cc @@ -19,11 +19,13 @@ #include "google/protobuf/any.pb.h" #include "google/protobuf/descriptor.pb.h" +#include <gmock/gmock.h> #include "google/protobuf/testing/googletest.h" #include #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/substitute.h" #include "google/protobuf/compiler/retention.h" @@ -131,7 +133,8 @@ class ParserTest : public testing::Test { } // Parse the text and expect that the given errors are reported. - void ExpectHasErrors(const char* text, const char* expected_errors) { + void ExpectHasErrors(absl::string_view text, + const testing::Matcher<std::string>& expected_errors) { ExpectHasEarlyExitErrors(text, expected_errors); EXPECT_EQ(io::Tokenizer::TYPE_END, input_->current().type); } @@ -148,20 +151,22 @@ class ParserTest : public testing::Test { // Same as above but does not expect that the parser parses the complete // input. - void ExpectHasEarlyExitErrors(absl::string_view text, - absl::string_view expected_errors) { + void ExpectHasEarlyExitErrors( + absl::string_view text, + const testing::Matcher<std::string>& expected_errors) { SetupParser(text); SourceLocationTable source_locations; parser_->RecordSourceLocationsTo(&source_locations); FileDescriptorProto file; EXPECT_FALSE(parser_->Parse(input_.get(), &file)); - EXPECT_EQ(expected_errors, error_collector_.text_); + EXPECT_THAT(error_collector_.text_, expected_errors); } // Parse the text as a file and validate it (with a DescriptorPool), and // expect that the validation step reports the given errors. 
- void ExpectHasValidationErrors(const char* text, - const char* expected_errors) { + void ExpectHasValidationErrors( + absl::string_view text, + const testing::Matcher<std::string>& expected_errors) { SetupParser(text); SourceLocationTable source_locations; parser_->RecordSourceLocationsTo(&source_locations); @@ -176,7 +181,7 @@ class ParserTest : public testing::Test { &error_collector_); EXPECT_TRUE(pool_.BuildFileCollectingErrors( file, &validation_error_collector) == nullptr); - EXPECT_EQ(expected_errors, error_collector_.text_); + EXPECT_THAT(error_collector_.text_, expected_errors); } MockErrorCollector error_collector_; @@ -1454,6 +1459,41 @@ TEST_F(ParseErrorTest, EofInMessage) { "0:21: Reached end of input in message definition (missing '}').\n"); } +TEST_F(ParseErrorTest, NestingIsLimitedWithoutCrashing) { + std::string start = "syntax = \"proto2\";\n"; + std::string end; + + const auto add = [&] { + absl::StrAppend(&start, "message M {"); + absl::StrAppend(&end, "}"); + }; + const auto input = [&] { return absl::StrCat(start, end); }; + + // Every nesting depth below the limit must parse and build without errors. + for (int i = 1; i < internal::cpp::MaxMessageDeclarationNestingDepth(); ++i) { + add(); + const std::string str = input(); + SetupParser(str); + FileDescriptorProto proto; + proto.set_name("foo.proto"); + EXPECT_TRUE(parser_->Parse(input_.get(), &proto)) << input(); + EXPECT_EQ(io::Tokenizer::TYPE_END, input_->current().type); + ASSERT_EQ("", error_collector_.text_); + DescriptorPool pool; + ASSERT_TRUE(pool.BuildFile(proto)); + } + // The rest have parsing errors but they don't crash no matter how deep we + // make them. 
+ const auto error = testing::HasSubstr( + "Reached maximum recursion limit for nested messages."); + add(); + ExpectHasErrors(input(), error); + for (int i = 0; i < 100000; ++i) { + add(); + } + ExpectHasErrors(input(), error); +} + TEST_F(ParseErrorTest, MissingFieldNumber) { ExpectHasErrors( "message TestMessage {\n" diff --git a/src/google/protobuf/descriptor.cc b/src/google/protobuf/descriptor.cc index a477920e66..2ad232f4f4 100644 --- a/src/google/protobuf/descriptor.cc +++ b/src/google/protobuf/descriptor.cc @@ -4051,7 +4051,7 @@ class DescriptorBuilder { // Counts down to 0 when there is no depth remaining. // // Maximum recursion depth corresponds to 32 nested message declarations. - int recursion_depth_ = 32; + int recursion_depth_ = internal::cpp::MaxMessageDeclarationNestingDepth(); // Note: Both AddError and AddWarning functions are extremely sensitive to // the *caller* stack space used. We call these functions many times in diff --git a/src/google/protobuf/descriptor.h b/src/google/protobuf/descriptor.h index c380edcf07..13446c7d91 100644 --- a/src/google/protobuf/descriptor.h +++ b/src/google/protobuf/descriptor.h @@ -2845,6 +2845,11 @@ bool ParseNoReflection(absl::string_view from, google::protobuf::MessageLite& to // In particular, questions like "does this field have a has bit?" have a // different answer depending on the language. namespace cpp { + +// The maximum allowed nesting for message declarations. +// Going over this limit will make the proto definition invalid. +constexpr int MaxMessageDeclarationNestingDepth() { return 32; } + // Returns true if 'enum' semantics are such that unknown values are preserved // in the enum field itself, rather than going to the UnknownFieldSet. PROTOBUF_EXPORT bool HasPreservingUnknownEnumSemantics(