diff --git a/src/google/protobuf/compiler/cpp/field_generators/message_field.cc b/src/google/protobuf/compiler/cpp/field_generators/message_field.cc index 0d56447e23..30496c9522 100644 --- a/src/google/protobuf/compiler/cpp/field_generators/message_field.cc +++ b/src/google/protobuf/compiler/cpp/field_generators/message_field.cc @@ -37,47 +37,38 @@ using ::google::protobuf::io::AnnotationCollector; using Sub = ::google::protobuf::io::Printer::Sub; std::vector Vars(const FieldDescriptor* field, const Options& opts, - bool weak) { + bool is_weak, bool use_base_class) { bool split = ShouldSplit(field, opts); bool is_foreign = IsCrossFileMessage(field); std::string field_name = FieldMemberName(field, split); std::string qualified_type = FieldMessageTypeName(field, opts); std::string default_ref = QualifiedDefaultInstanceName(field->message_type(), opts); - std::string default_ptr = - QualifiedDefaultInstancePtr(field->message_type(), opts); - std::string base = - absl::StrCat("::", ProtobufNamespace(opts), "::", "MessageLite"); + std::string base = absl::StrCat( + "::", ProtobufNamespace(opts), "::", + HasDescriptorMethods(field->file(), opts) ? "Message" : "MessageLite"); return { {"Submsg", qualified_type}, - {"MemberType", !weak ? qualified_type : base}, - {"CompleteType", !is_foreign ? qualified_type : base}, + {"MemberType", use_base_class ? base : qualified_type}, {"kDefault", default_ref}, - {"kDefaultPtr", !weak - ? default_ptr - : absl::Substitute("reinterpret_cast($1)", - base, default_ptr)}, - Sub{"base_cast", !is_foreign && !weak - ? "" - : absl::Substitute("reinterpret_cast<$0*>", base)} + Sub{"cast_to_field", + use_base_class ? absl::Substitute("reinterpret_cast<$0*>", base) : ""} .ConditionalFunctionCall(), - Sub{"weak_cast", - !weak ? "" : absl::Substitute("reinterpret_cast<$0*>", base)} - .ConditionalFunctionCall(), - Sub{"foreign_cast", + Sub{"arena_cast", !is_foreign ? "" : absl::Substitute("reinterpret_cast<$0*>", base)} .ConditionalFunctionCall(), - {"cast_field_", !weak ? field_name - : absl::Substitute("reinterpret_cast<$0*>($1)", - qualified_type, field_name)}, - {"Weak", weak ? "Weak" : ""}, - {".weak", weak ? ".weak" : ""}, - {"_weak", weak ? "_weak" : ""}, - Sub("StrongRef", - !weak ? "" - : absl::StrCat( - StrongReferenceToType(field->message_type(), opts), ";")) + {"cast_field_", use_base_class + ? absl::Substitute("reinterpret_cast<$0*>($1)", + qualified_type, field_name) + : field_name}, + {"Weak", is_weak ? "Weak" : ""}, + {".weak", is_weak ? ".weak" : ""}, + {"_weak", is_weak ? "_weak" : ""}, + Sub("StrongRef", !is_weak ? "" + : absl::StrCat(StrongReferenceToType( + field->message_type(), opts), + ";")) .WithSuffix(";"), }; } @@ -94,7 +85,7 @@ class SingularMessage : public FieldGeneratorBase { ~SingularMessage() override = default; std::vector MakeVars() const override { - return Vars(field_, *opts_, is_weak()); + return Vars(field_, *opts_, is_weak(), is_weak()); } void GeneratePrivateMembers(io::Printer* p) const override { @@ -289,14 +280,11 @@ void SingularMessage::GenerateInlineAccessorDefinitions(io::Printer* p) const { $TsanDetectConcurrentMutation$; $PrepareSplitMessageForWrite$; if (message_arena == nullptr) { - delete $base_cast$($field_$); + delete reinterpret_cast<$pb$::MessageLite*>($field_$); } if (value != nullptr) { - //~ When $Submsg$ is a cross-file type, have to read the arena - //~ through the virtual method, because the type isn't defined in - //~ this file, only forward-declared. - $pb$::Arena* submessage_arena = $base_cast$(value)->GetArena(); + $pb$::Arena* submessage_arena = $arena_cast$(value)->GetArena(); if (message_arena != submessage_arena) { value = $pbi$::GetOwnedMessage(message_arena, value, submessage_arena); } @@ -468,10 +456,44 @@ class OneofMessage : public SingularMessage { public: OneofMessage(const FieldDescriptor* descriptor, const Options& options, MessageSCCAnalyzer* scc_analyzer) - : SingularMessage(descriptor, options, scc_analyzer) {} + : SingularMessage(descriptor, options, scc_analyzer) { + auto* oneof = descriptor->containing_oneof(); + num_message_fields_in_oneof_ = 0; + for (int i = 0; i < oneof->field_count(); ++i) { + num_message_fields_in_oneof_ += + oneof->field(i)->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE; + } + } ~OneofMessage() override = default; + bool use_base_class() const { + if (is_weak()) return true; + + // For non-weak oneof fields, we choose to use a base class pointer when the + // oneof has many message fields in it. Using a base class here is not + // about correctness, but about performance and binary size. + // + // This allows the compiler to merge all the different switch cases (since + // the code is identical for all message alternatives) reducing binary size. + // The runtime dispatch is effectively changed from a switch statement to a + // virtual function call. For many oneofs, it completely elides the switch + // dispatch. + // + // This constant is a tradeoff. We want to allow optimizations (like + // inlining) on small oneofs. For small oneofs the compiler can use faster + // alternatives to table-based jumps. Also, the technique used here has less + // of a binary size win for small oneofs. + static constexpr int kMaxStaticTypeCount = 3; + return num_message_fields_in_oneof_ >= kMaxStaticTypeCount && + // Hot alternatives are kept as their static type for performance.. + !IsLikelyPresent(field_, *opts_); + } + + std::vector MakeVars() const override { + return Vars(field_, *opts_, is_weak(), use_base_class()); + } + void GenerateInlineAccessorDefinitions(io::Printer* p) const override; void GenerateNonInlineAccessorDefinitions(io::Printer* p) const override; void GenerateClearingCode(io::Printer* p) const override; @@ -484,6 +506,9 @@ class OneofMessage : public SingularMessage { bool NeedsIsInitialized() const override; void GenerateMergingCode(io::Printer* p) const override; bool RequiresArena(GeneratorFunction func) const override; + + private: + int num_message_fields_in_oneof_; }; void OneofMessage::GenerateNonInlineAccessorDefinitions(io::Printer* p) const { @@ -492,7 +517,7 @@ void OneofMessage::GenerateNonInlineAccessorDefinitions(io::Printer* p) const { $pb$::Arena* message_arena = GetArena(); clear_$oneof_name$(); if ($name$) { - $pb$::Arena* submessage_arena = $foreign_cast$($name$)->GetArena(); + $pb$::Arena* submessage_arena = $arena_cast$($name$)->GetArena(); if (message_arena != submessage_arena) { $name$ = $pbi$::GetOwnedMessage(message_arena, $name$, submessage_arena); } @@ -569,7 +594,7 @@ void OneofMessage::GenerateInlineAccessorDefinitions(io::Printer* p) const { clear_$oneof_name$(); if (value) { set_has_$name_internal$(); - $field_$ = $weak_cast$(value); + $field_$ = $cast_to_field$(value); } $annotate_set$; // @@protoc_insertion_point(field_unsafe_arena_set_allocated:$pkg.Msg.field$) @@ -581,8 +606,8 @@ void OneofMessage::GenerateInlineAccessorDefinitions(io::Printer* p) const { if ($not_has_field$) { clear_$oneof_name$(); set_has_$name_internal$(); - $field_$ = - $weak_cast$($superclass$::DefaultConstruct<$Submsg$>(GetArena())); + $field_$ = $cast_to_field$( + $superclass$::DefaultConstruct<$Submsg$>(GetArena())); } return $cast_field_$; } @@ -662,22 +687,17 @@ void OneofMessage::GenerateIsInitialized(io::Printer* p) const { bool OneofMessage::NeedsIsInitialized() const { return has_required_; } void OneofMessage::GenerateMergingCode(io::Printer* p) const { - if (is_weak()) { - p->Emit(R"cc( - if (oneof_needs_init) { - _this->$field_$ = from.$field_$->New(arena); - } - _this->$field_$->CheckTypeAndMergeFrom(*from.$field_$); - )cc"); - } else { - p->Emit(R"cc( - if (oneof_needs_init) { - _this->$field_$ = $superclass$::CopyConstruct(arena, *from.$field_$); - } else { - _this->$field_$->MergeFrom(from._internal_$name$()); - } - )cc"); - } + p->Emit({{"merge", + use_base_class() && !HasDescriptorMethods(field_->file(), options_) + ? "CheckTypeAndMergeFrom" + : "MergeFrom"}}, + R"cc( + if (oneof_needs_init) { + _this->$field_$ = $superclass$::CopyConstruct(arena, *from.$field_$); + } else { + _this->$field_$->$merge$(*from.$field_$); + } + )cc"); } bool OneofMessage::RequiresArena(GeneratorFunction func) const { @@ -699,7 +719,7 @@ class RepeatedMessage : public FieldGeneratorBase { ~RepeatedMessage() override = default; std::vector MakeVars() const override { - return Vars(field_, *opts_, is_weak()); + return Vars(field_, *opts_, is_weak(), is_weak()); } void GeneratePrivateMembers(io::Printer* p) const override; diff --git a/src/google/protobuf/compiler/plugin.pb.h b/src/google/protobuf/compiler/plugin.pb.h index 4c8a9f7859..6af5172d13 100644 --- a/src/google/protobuf/compiler/plugin.pb.h +++ b/src/google/protobuf/compiler/plugin.pb.h @@ -1652,7 +1652,7 @@ inline void CodeGeneratorRequest::set_allocated_compiler_version(::google::proto ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.compiler_version_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.compiler_version_); } if (value != nullptr) { @@ -1960,7 +1960,7 @@ inline void CodeGeneratorResponse_File::set_allocated_generated_code_info(::goog } if (value != nullptr) { - ::google::protobuf::Arena* submessage_arena = reinterpret_cast<::google::protobuf::MessageLite*>(value)->GetArena(); + ::google::protobuf::Arena* submessage_arena = reinterpret_cast<::google::protobuf::Message*>(value)->GetArena(); if (message_arena != submessage_arena) { value = ::google::protobuf::internal::GetOwnedMessage(message_arena, value, submessage_arena); } diff --git a/src/google/protobuf/descriptor.pb.h b/src/google/protobuf/descriptor.pb.h index ca33551a67..388b863164 100644 --- a/src/google/protobuf/descriptor.pb.h +++ b/src/google/protobuf/descriptor.pb.h @@ -12880,7 +12880,7 @@ inline void FileDescriptorProto::set_allocated_options(::google::protobuf::FileO ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.options_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.options_); } if (value != nullptr) { @@ -12978,7 +12978,7 @@ inline void FileDescriptorProto::set_allocated_source_code_info(::google::protob ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.source_code_info_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.source_code_info_); } if (value != nullptr) { @@ -13236,7 +13236,7 @@ inline void DescriptorProto_ExtensionRange::set_allocated_options(::google::prot ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.options_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.options_); } if (value != nullptr) { @@ -13767,7 +13767,7 @@ inline void DescriptorProto::set_allocated_options(::google::protobuf::MessageOp ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.options_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.options_); } if (value != nullptr) { @@ -14309,7 +14309,7 @@ inline void ExtensionRangeOptions::set_allocated_features(::google::protobuf::Fe ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.features_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.features_); } if (value != nullptr) { @@ -14905,7 +14905,7 @@ inline void FieldDescriptorProto::set_allocated_options(::google::protobuf::Fiel ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.options_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.options_); } if (value != nullptr) { @@ -15104,7 +15104,7 @@ inline void OneofDescriptorProto::set_allocated_options(::google::protobuf::Oneo ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.options_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.options_); } if (value != nullptr) { @@ -15385,7 +15385,7 @@ inline void EnumDescriptorProto::set_allocated_options(::google::protobuf::EnumO ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.options_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.options_); } if (value != nullptr) { @@ -15698,7 +15698,7 @@ inline void EnumValueDescriptorProto::set_allocated_options(::google::protobuf:: ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.options_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.options_); } if (value != nullptr) { @@ -15919,7 +15919,7 @@ inline void ServiceDescriptorProto::set_allocated_options(::google::protobuf::Se ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.options_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.options_); } if (value != nullptr) { @@ -16228,7 +16228,7 @@ inline void MethodDescriptorProto::set_allocated_options(::google::protobuf::Met ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.options_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.options_); } if (value != nullptr) { @@ -17331,7 +17331,7 @@ inline void FileOptions::set_allocated_features(::google::protobuf::FeatureSet* ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.features_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.features_); } if (value != nullptr) { @@ -17623,7 +17623,7 @@ inline void MessageOptions::set_allocated_features(::google::protobuf::FeatureSe ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.features_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.features_); } if (value != nullptr) { @@ -18408,7 +18408,7 @@ inline void FieldOptions::set_allocated_features(::google::protobuf::FeatureSet* ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.features_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.features_); } if (value != nullptr) { @@ -18506,7 +18506,7 @@ inline void FieldOptions::set_allocated_feature_support(::google::protobuf::Fiel ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.feature_support_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.feature_support_); } if (value != nullptr) { @@ -18658,7 +18658,7 @@ inline void OneofOptions::set_allocated_features(::google::protobuf::FeatureSet* ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.features_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.features_); } if (value != nullptr) { @@ -18894,7 +18894,7 @@ inline void EnumOptions::set_allocated_features(::google::protobuf::FeatureSet* ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.features_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.features_); } if (value != nullptr) { @@ -19074,7 +19074,7 @@ inline void EnumValueOptions::set_allocated_features(::google::protobuf::Feature ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.features_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.features_); } if (value != nullptr) { @@ -19200,7 +19200,7 @@ inline void EnumValueOptions::set_allocated_feature_support(::google::protobuf:: ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.feature_support_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.feature_support_); } if (value != nullptr) { @@ -19352,7 +19352,7 @@ inline void ServiceOptions::set_allocated_features(::google::protobuf::FeatureSe ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.features_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.features_); } if (value != nullptr) { @@ -19591,7 +19591,7 @@ inline void MethodOptions::set_allocated_features(::google::protobuf::FeatureSet ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.features_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.features_); } if (value != nullptr) { @@ -20441,7 +20441,7 @@ inline void FeatureSetDefaults_FeatureSetEditionDefault::set_allocated_overridab ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.overridable_features_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.overridable_features_); } if (value != nullptr) { @@ -20539,7 +20539,7 @@ inline void FeatureSetDefaults_FeatureSetEditionDefault::set_allocated_fixed_fea ::google::protobuf::Arena* message_arena = GetArena(); ::google::protobuf::internal::TSanWrite(&_impl_); if (message_arena == nullptr) { - delete _impl_.fixed_features_; + delete reinterpret_cast<::google::protobuf::MessageLite*>(_impl_.fixed_features_); } if (value != nullptr) { diff --git a/src/google/protobuf/message_lite.cc b/src/google/protobuf/message_lite.cc index 3b1d0b4c31..f0283aaa86 100644 --- a/src/google/protobuf/message_lite.cc +++ b/src/google/protobuf/message_lite.cc @@ -50,6 +50,13 @@ namespace google { namespace protobuf { +MessageLite* MessageLite::CopyConstruct(Arena* arena, const MessageLite& from) { + auto* data = from.GetClassData(); + auto* res = data->New(arena); + data->merge_to_from(*res, from); + return res; +} + void MessageLite::DestroyInstance() { #if defined(PROTOBUF_CUSTOM_VTABLE) _class_data_->destroy_message(*this); diff --git a/src/google/protobuf/message_lite.h b/src/google/protobuf/message_lite.h index 18eabd9aca..a95db71990 100644 --- a/src/google/protobuf/message_lite.h +++ b/src/google/protobuf/message_lite.h @@ -920,6 +920,15 @@ class PROTOBUF_EXPORT MessageLite { return static_cast(Arena::CopyConstruct(arena, &from)); } + // As above, but for fields that use base class type. Eg foreign weak fields. + static MessageLite* CopyConstruct(Arena* arena, const MessageLite& from); + + PROTOBUF_ALWAYS_INLINE static Message* CopyConstruct(Arena* arena, + const Message& from) { + return reinterpret_cast( + CopyConstruct(arena, reinterpret_cast(from))); + } + const internal::TcParseTableBase* GetTcParseTable() const { auto* data = GetClassData(); ABSL_DCHECK(data != nullptr);