From b72eb3f2337599b0e91cadbeeec4e17bdcfcb4b6 Mon Sep 17 00:00:00 2001 From: Mike Kruskal Date: Tue, 28 Mar 2023 15:58:26 -0700 Subject: [PATCH] Add optimizations to extension declaration validation. PiperOrigin-RevId: 520159132 --- src/google/protobuf/descriptor.cc | 165 +++++++++++++++--------------- 1 file changed, 85 insertions(+), 80 deletions(-) diff --git a/src/google/protobuf/descriptor.cc b/src/google/protobuf/descriptor.cc index 048397bd77..a2e0fda2e6 100644 --- a/src/google/protobuf/descriptor.cc +++ b/src/google/protobuf/descriptor.cc @@ -986,34 +986,39 @@ struct SymbolByParentEq { using SymbolsByParentSet = absl::flat_hash_set; -struct FilesByNameHash { +template +struct DescriptorsByNameHash { using is_transparent = void; size_t operator()(absl::string_view name) const { return absl::HashOf(name); } - size_t operator()(const FileDescriptor* file) const { + size_t operator()(const DescriptorT* file) const { return absl::HashOf(file->name()); } }; -struct FilesByNameEq { +template +struct DescriptorsByNameEq { using is_transparent = void; bool operator()(absl::string_view lhs, absl::string_view rhs) const { return lhs == rhs; } - bool operator()(absl::string_view lhs, const FileDescriptor* rhs) const { + bool operator()(absl::string_view lhs, const DescriptorT* rhs) const { return lhs == rhs->name(); } - bool operator()(const FileDescriptor* lhs, absl::string_view rhs) const { + bool operator()(const DescriptorT* lhs, absl::string_view rhs) const { return lhs->name() == rhs; } - bool operator()(const FileDescriptor* lhs, const FileDescriptor* rhs) const { + bool operator()(const DescriptorT* lhs, const DescriptorT* rhs) const { return lhs == rhs || lhs->name() == rhs->name(); } }; -using FilesByNameSet = - absl::flat_hash_set; + +template +using DescriptorsByNameSet = + absl::flat_hash_set, + DescriptorsByNameEq>; using FieldsByNameMap = absl::flat_hash_map, @@ -1364,7 +1369,7 @@ class DescriptorPool::Tables { flat_allocs_; SymbolsByNameSet symbols_by_name_; - FilesByNameSet files_by_name_; + DescriptorsByNameSet files_by_name_; ExtensionsGroupedByDescriptorMap extensions_; struct CheckPoint { @@ -2970,9 +2975,8 @@ class SourceLocationCommentPrinter { std::string FormatComment(const std::string& comment_text) { std::string stripped_comment = comment_text; absl::StripAsciiWhitespace(&stripped_comment); - std::vector lines = absl::StrSplit(stripped_comment, "\n"); std::string output; - for (const std::string& line : lines) { + for (absl::string_view line : absl::StrSplit(stripped_comment, '\n')) { absl::SubstituteAndAppend(&output, "$0// $1\n", prefix_, line); } return output; @@ -3791,12 +3795,9 @@ class DescriptorBuilder { // Maximum recursion depth corresponds to 32 nested message declarations. int recursion_depth_ = 32; - void AddError(const std::string& element_name, const Message& descriptor, - DescriptorPool::ErrorCollector::ErrorLocation location, - const std::string& error); - void AddError(const std::string& element_name, const Message& descriptor, + void AddError(absl::string_view element_name, const Message& descriptor, DescriptorPool::ErrorCollector::ErrorLocation location, - const char* error); + absl::string_view error); void AddRecursiveImportError(const FileDescriptorProto& proto, int from_here); void AddTwiceListedError(const FileDescriptorProto& proto, int index); void AddImportError(const FileDescriptorProto& proto, int index); @@ -3804,13 +3805,13 @@ class DescriptorBuilder { // Adds an error indicating that undefined_symbol was not defined. Must // only be called after LookupSymbol() fails. void AddNotDefinedError( - const std::string& element_name, const Message& descriptor, + absl::string_view element_name, const Message& descriptor, DescriptorPool::ErrorCollector::ErrorLocation location, - const std::string& undefined_symbol); + absl::string_view undefined_symbol); - void AddWarning(const std::string& element_name, const Message& descriptor, + void AddWarning(absl::string_view element_name, const Message& descriptor, DescriptorPool::ErrorCollector::ErrorLocation location, - const std::string& error); + absl::string_view error); // Silly helper which determines if the given file is in the given package. // I.e., either file->package() == package_name or file->package() is a @@ -4172,7 +4173,7 @@ class DescriptorBuilder { const std::string& full_name, const RepeatedPtrField& declarations, const DescriptorProto_ExtensionRange& proto, - absl::flat_hash_set& full_name_set); + absl::flat_hash_set& full_name_set); void ValidateServiceOptions(ServiceDescriptor* service, const ServiceDescriptorProto& proto); void ValidateMethodOptions(MethodDescriptor* method, @@ -4251,9 +4252,9 @@ DescriptorBuilder::DescriptorBuilder( DescriptorBuilder::~DescriptorBuilder() {} PROTOBUF_NOINLINE void DescriptorBuilder::AddError( - const std::string& element_name, const Message& descriptor, + absl::string_view element_name, const Message& descriptor, DescriptorPool::ErrorCollector::ErrorLocation location, - const std::string& error) { + absl::string_view error) { if (error_collector_ == nullptr) { if (!had_errors_) { ABSL_LOG(ERROR) << "Invalid proto descriptor for file \"" << filename_ @@ -4267,16 +4268,10 @@ PROTOBUF_NOINLINE void DescriptorBuilder::AddError( had_errors_ = true; } -PROTOBUF_NOINLINE void DescriptorBuilder::AddError( - const std::string& element_name, const Message& descriptor, - DescriptorPool::ErrorCollector::ErrorLocation location, const char* error) { - AddError(element_name, descriptor, location, std::string(error)); -} - PROTOBUF_NOINLINE void DescriptorBuilder::AddNotDefinedError( - const std::string& element_name, const Message& descriptor, + absl::string_view element_name, const Message& descriptor, DescriptorPool::ErrorCollector::ErrorLocation location, - const std::string& undefined_symbol) { + absl::string_view undefined_symbol) { if (possible_undeclared_dependency_ == nullptr && undefine_resolved_name_.empty()) { AddError(element_name, descriptor, location, @@ -4307,9 +4302,9 @@ PROTOBUF_NOINLINE void DescriptorBuilder::AddNotDefinedError( } PROTOBUF_NOINLINE void DescriptorBuilder::AddWarning( - const std::string& element_name, const Message& descriptor, + absl::string_view element_name, const Message& descriptor, DescriptorPool::ErrorCollector::ErrorLocation location, - const std::string& error) { + absl::string_view error) { if (error_collector_ == nullptr) { ABSL_LOG(WARNING) << filename_ << " " << element_name << ": " << error; } else { @@ -5439,10 +5434,14 @@ struct IncrementWhenDestroyed { } // namespace namespace { -static constexpr auto kNonMessageTypes = { - "double", "float", "int64", "uint64", "int32", "fixed32", - "fixed64", "bool", "string", "bytes", "uint32", "enum", - "sfixed32", "sfixed64", "sint32", "sint64"}; +bool IsNonMessageType(absl::string_view type) { + static const auto* non_message_types = + new absl::flat_hash_set( + {"double", "float", "int64", "uint64", "int32", "fixed32", "fixed64", + "bool", "string", "bytes", "uint32", "enum", "sfixed32", "sfixed64", + "sint32", "sint64"}); + return non_message_types->contains(type); +} } // namespace @@ -5534,9 +5533,8 @@ void DescriptorBuilder::BuildMessage(const DescriptorProto& proto, } } - absl::flat_hash_set reserved_name_set; - for (int i = 0; i < proto.reserved_name_size(); i++) { - const std::string& name = proto.reserved_name(i); + absl::flat_hash_set reserved_name_set; + for (absl::string_view name : proto.reserved_name()) { if (!reserved_name_set.insert(name).second) { AddError(name, proto, DescriptorPool::ErrorCollector::NAME, absl::Substitute("Field name \"$0\" is reserved multiple times.", @@ -5571,7 +5569,7 @@ void DescriptorBuilder::BuildMessage(const DescriptorProto& proto, field->name(), field->number())); } } - if (reserved_name_set.find(field->name()) != reserved_name_set.end()) { + if (reserved_name_set.contains(field->name())) { AddError( field->full_name(), proto.field(i), DescriptorPool::ErrorCollector::NAME, @@ -6217,12 +6215,9 @@ void DescriptorBuilder::BuildEnum(const EnumDescriptorProto& proto, } } - absl::flat_hash_set reserved_name_set; - for (int i = 0; i < proto.reserved_name_size(); i++) { - const std::string& name = proto.reserved_name(i); - if (reserved_name_set.find(name) == reserved_name_set.end()) { - reserved_name_set.insert(name); - } else { + absl::flat_hash_set reserved_name_set; + for (absl::string_view name : proto.reserved_name()) { + if (!reserved_name_set.insert(name).second) { AddError(name, proto, DescriptorPool::ErrorCollector::NAME, absl::Substitute("Enum value \"$0\" is reserved multiple times.", name)); @@ -6240,7 +6235,7 @@ void DescriptorBuilder::BuildEnum(const EnumDescriptorProto& proto, value->name(), value->number())); } } - if (reserved_name_set.find(value->name()) != reserved_name_set.end()) { + if (reserved_name_set.contains(value->name())) { AddError( value->full_name(), proto.value(i), DescriptorPool::ErrorCollector::NAME, @@ -7253,17 +7248,15 @@ void DescriptorBuilder::ValidateEnumValueOptions( namespace { // Validates that a fully-qualified symbol for extension declaration must // have a leading dot and valid identifiers. -absl::optional ValidateSymbolsForDeclaration( - absl::flat_hash_set symbols) { - for (absl::string_view symbol : symbols) { - if (!absl::StartsWith(symbol, ".")) { - return absl::StrCat("\"", symbol, - "\" must have a leading dot to indicate the " - "fully-qualified scope."); - } - if (!ValidateQualifiedName(symbol)) { - return absl::StrCat("\"", symbol, "\" contains invalid identifiers."); - } +absl::optional ValidateSymbolForDeclaration( + absl::string_view symbol) { + if (!absl::StartsWith(symbol, ".")) { + return absl::StrCat("\"", symbol, + "\" must have a leading dot to indicate the " + "fully-qualified scope."); + } + if (!ValidateQualifiedName(symbol)) { + return absl::StrCat("\"", symbol, "\" contains invalid identifiers."); } return absl::nullopt; } @@ -7274,8 +7267,7 @@ void DescriptorBuilder::ValidateExtensionDeclaration( const std::string& full_name, const RepeatedPtrField& declarations, const DescriptorProto_ExtensionRange& proto, - absl::flat_hash_set& full_name_set) { - absl::flat_hash_set symbols; + absl::flat_hash_set& full_name_set) { for (const auto& declaration : declarations) { if (declaration.number() < proto.start() || declaration.number() >= proto.end()) { @@ -7299,15 +7291,18 @@ void DescriptorBuilder::ValidateExtensionDeclaration( declaration.full_name())); return; } - symbols.insert(declaration.full_name()); - if (std::find(std::begin(kNonMessageTypes), std::end(kNonMessageTypes), - declaration.type()) == std::end(kNonMessageTypes)) { - symbols.insert(declaration.type()); + absl::optional err = + ValidateSymbolForDeclaration(declaration.full_name()); + if (err.has_value()) { + AddError(full_name, proto, DescriptorPool::ErrorCollector::NAME, *err); + } + if (!IsNonMessageType(declaration.type())) { + err = ValidateSymbolForDeclaration(declaration.type()); + if (err.has_value()) { + AddError(full_name, proto, DescriptorPool::ErrorCollector::NAME, + *err); + } } - } - absl::optional err = ValidateSymbolsForDeclaration(symbols); - if (err.has_value()) { - AddError(full_name, proto, DescriptorPool::ErrorCollector::NAME, *err); } } } @@ -7318,8 +7313,18 @@ void DescriptorBuilder::ValidateExtensionRangeOptions( static_cast(message.options().message_set_wire_format() ? std::numeric_limits::max() : FieldDescriptor::kMaxNumber); - // Contains the full names of all declarations. - absl::flat_hash_set declaration_full_name_set; + + size_t num_declarations = 0; + for (int i = 0; i < message.extension_range_count(); i++) { + if (message.extension_range(i)->options_ == nullptr) continue; + num_declarations += + message.extension_range(i)->options_->declaration_size(); + } + + // Contains the full names from both "declaration" and "metadata". + absl::flat_hash_set declaration_full_name_set; + declaration_full_name_set.reserve(num_declarations); + for (int i = 0; i < message.extension_range_count(); i++) { const auto& range = *message.extension_range(i); if (range.end > max_extension_range + 1) { @@ -7433,13 +7438,13 @@ bool DescriptorBuilder::ValidateMapEntry(FieldDescriptor* field, void DescriptorBuilder::DetectMapConflicts(const Descriptor* message, const DescriptorProto& proto) { - absl::flat_hash_map seen_types; + DescriptorsByNameSet seen_types; for (int i = 0; i < message->nested_type_count(); ++i) { const Descriptor* nested = message->nested_type(i); - auto insert_result = seen_types.emplace(nested->name(), nested); + auto insert_result = seen_types.insert(nested); bool inserted = insert_result.second; if (!inserted) { - if (insert_result.first->second->options().map_entry() || + if ((*insert_result.first)->options().map_entry() || nested->options().map_entry()) { AddError( message->full_name(), proto, DescriptorPool::ErrorCollector::NAME, @@ -7455,10 +7460,10 @@ void DescriptorBuilder::DetectMapConflicts(const Descriptor* message, for (int i = 0; i < message->field_count(); ++i) { const FieldDescriptor* field = message->field(i); auto iter = seen_types.find(field->name()); - if (iter != seen_types.end() && iter->second->options().map_entry()) { + if (iter != seen_types.end() && (*iter)->options().map_entry()) { AddError(message->full_name(), proto, DescriptorPool::ErrorCollector::NAME, - absl::StrCat("Expanded map entry type ", iter->second->name(), + absl::StrCat("Expanded map entry type ", (*iter)->name(), " conflicts with an existing field.")); } } @@ -7466,10 +7471,10 @@ void DescriptorBuilder::DetectMapConflicts(const Descriptor* message, for (int i = 0; i < message->enum_type_count(); ++i) { const EnumDescriptor* enum_desc = message->enum_type(i); auto iter = seen_types.find(enum_desc->name()); - if (iter != seen_types.end() && iter->second->options().map_entry()) { + if (iter != seen_types.end() && (*iter)->options().map_entry()) { AddError(message->full_name(), proto, DescriptorPool::ErrorCollector::NAME, - absl::StrCat("Expanded map entry type ", iter->second->name(), + absl::StrCat("Expanded map entry type ", (*iter)->name(), " conflicts with an existing enum type.")); } } @@ -7477,10 +7482,10 @@ void DescriptorBuilder::DetectMapConflicts(const Descriptor* message, for (int i = 0; i < message->oneof_decl_count(); ++i) { const OneofDescriptor* oneof_desc = message->oneof_decl(i); auto iter = seen_types.find(oneof_desc->name()); - if (iter != seen_types.end() && iter->second->options().map_entry()) { + if (iter != seen_types.end() && (*iter)->options().map_entry()) { AddError(message->full_name(), proto, DescriptorPool::ErrorCollector::NAME, - absl::StrCat("Expanded map entry type ", iter->second->name(), + absl::StrCat("Expanded map entry type ", (*iter)->name(), " conflicts with an existing oneof type.")); } }