From 923e8b756faa1ef7fdd5d2a8363ad1d35e5aed38 Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Fri, 23 Jun 2023 10:26:51 -0700 Subject: [PATCH] Fix Serialize failing for Proxy/CProxy. Refactor minitable access. PiperOrigin-RevId: 542895074 --- protos/protos.h | 106 ++++++++--------------- protos/repeated_field.h | 6 +- protos_generator/gen_messages.cc | 29 ++++--- protos_generator/tests/BUILD | 2 + protos_generator/tests/test_generated.cc | 32 +++++-- 5 files changed, 86 insertions(+), 89 deletions(-) diff --git a/protos/protos.h b/protos/protos.h index 809e964e96..2e2fae915c 100644 --- a/protos/protos.h +++ b/protos/protos.h @@ -199,12 +199,12 @@ class ExtensionIdentifier : public ExtensionMiniTableProvider { }; template -void* GetInternalMsg(const T& message) { - return message.msg(); +void* GetInternalMsg(const T* message) { + return message->msg(); } template -void* GetInternalMsg(const Ptr& message) { +void* GetInternalMsg(Ptr message) { return message->msg(); } @@ -214,10 +214,20 @@ upb_Arena* GetArena(const T& message) { } template -upb_Arena* GetArena(const Ptr& message) { +upb_Arena* GetArena(Ptr message) { return static_cast(message->GetInternalArena()); } +template +const upb_MiniTable* GetMiniTable(const T*) { + return T::minitable(); +} + +template +const upb_MiniTable* GetMiniTable(Ptr) { + return T::minitable(); +} + upb_ExtensionRegistry* GetUpbExtensions( const ExtensionRegistry& extension_registry); @@ -328,7 +338,7 @@ absl::Status SetExtension( if (message_arena != extension_arena) { upb_Arena_Fuse(message_arena, extension_arena); } - msg_ext->data.ptr = ::protos::internal::GetInternalMsg(value); + msg_ext->data.ptr = ::protos::internal::GetInternalMsg(&value); return absl::OkStatus(); } @@ -362,7 +372,7 @@ absl::StatusOr> GetExtension( template bool Parse(T& message, absl::string_view bytes) { - upb_Message_Clear(message.msg(), T::minitable()); + upb_Message_Clear(message.msg(), ::protos::internal::GetMiniTable(&message)); auto* arena = static_cast(message.GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(), /* extreg= */ nullptr, /* options= */ 0, @@ -372,9 +382,10 @@ bool Parse(T& message, absl::string_view bytes) { template bool Parse(T& message, absl::string_view bytes, const ::protos::ExtensionRegistry& extension_registry) { - upb_Message_Clear(message.msg(), T::minitable()); + upb_Message_Clear(message.msg(), ::protos::internal::GetMiniTable(message)); auto* arena = static_cast(message.GetInternalArena()); - return upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(), + return upb_Decode(bytes.data(), bytes.size(), message.msg(), + ::protos::internal::GetMiniTable(message), /* extreg= */ ::protos::internal::GetUpbExtensions(extension_registry), /* options= */ 0, arena) == kUpb_DecodeStatus_Ok; @@ -382,9 +393,10 @@ bool Parse(T& message, absl::string_view bytes, template bool Parse(Ptr& message, absl::string_view bytes) { - upb_Message_Clear(message->msg(), T::minitable()); + upb_Message_Clear(message->msg(), ::protos::internal::GetMiniTable(message)); auto* arena = static_cast(message->GetInternalArena()); - return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), + return upb_Decode(bytes.data(), bytes.size(), message->msg(), + ::protos::internal::GetMiniTable(message), /* extreg= */ nullptr, /* options= */ 0, arena) == kUpb_DecodeStatus_Ok; } @@ -392,60 +404,32 @@ bool Parse(Ptr& message, absl::string_view bytes) { template bool Parse(Ptr& message, absl::string_view bytes, const ::protos::ExtensionRegistry& extension_registry) { - upb_Message_Clear(message->msg(), T::minitable()); - auto* arena = static_cast(message->GetInternalArena()); - return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), - /* extreg= */ - ::protos::internal::GetUpbExtensions(extension_registry), - /* options= */ 0, arena) == kUpb_DecodeStatus_Ok; -} - -template -bool Parse(std::unique_ptr& message, absl::string_view bytes) { - upb_Message_Clear(message->msg(), T::minitable()); - auto* arena = static_cast(message->GetInternalArena()); - return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), - /* extreg= */ nullptr, /* options= */ 0, - arena) == kUpb_DecodeStatus_Ok; -} - -template -bool Parse(std::unique_ptr& message, absl::string_view bytes, - const ::protos::ExtensionRegistry& extension_registry) { - upb_Message_Clear(message->msg(), T::minitable()); + upb_Message_Clear(message->msg(), ::protos::internal::GetMiniTable(message)); auto* arena = static_cast(message->GetInternalArena()); - return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), + return upb_Decode(bytes.data(), bytes.size(), message->msg(), + ::protos::internal::GetMiniTable(message), /* extreg= */ ::protos::internal::GetUpbExtensions(extension_registry), /* options= */ 0, arena) == kUpb_DecodeStatus_Ok; } template -bool Parse(std::shared_ptr& message, absl::string_view bytes) { - upb_Message_Clear(message->msg(), T::minitable()); +bool Parse(T* message, absl::string_view bytes) { + upb_Message_Clear(message->msg(), ::protos::internal::GetMiniTable(message)); auto* arena = static_cast(message->GetInternalArena()); - return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), + return upb_Decode(bytes.data(), bytes.size(), message->msg(), + ::protos::internal::GetMiniTable(message), /* extreg= */ nullptr, /* options= */ 0, arena) == kUpb_DecodeStatus_Ok; } -template -bool Parse(std::shared_ptr& message, absl::string_view bytes, - const ::protos::ExtensionRegistry& extension_registry) { - upb_Message_Clear(message->msg(), T::minitable()); - auto* arena = static_cast(message->GetInternalArena()); - return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), - /* extreg= */ - ::protos::internal::GetUpbExtensions(extension_registry), - /* options= */ 0, arena) == kUpb_DecodeStatus_Ok; -} - template absl::StatusOr Parse(absl::string_view bytes, int options = 0) { T message; auto* arena = static_cast(message.GetInternalArena()); upb_DecodeStatus status = - upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(), + upb_Decode(bytes.data(), bytes.size(), message.msg(), + ::protos::internal::GetMiniTable(&message), /* extreg= */ nullptr, /* options= */ 0, arena); if (status == kUpb_DecodeStatus_Ok) { return message; @@ -460,7 +444,8 @@ absl::StatusOr Parse(absl::string_view bytes, T message; auto* arena = static_cast(message.GetInternalArena()); upb_DecodeStatus status = - upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(), + upb_Decode(bytes.data(), bytes.size(), message.msg(), + ::protos::internal::GetMiniTable(&message), ::protos::internal::GetUpbExtensions(extension_registry), /* options= */ 0, arena); if (status == kUpb_DecodeStatus_Ok) { @@ -470,34 +455,19 @@ absl::StatusOr Parse(absl::string_view bytes, } template -absl::StatusOr Serialize(const T& message, upb::Arena& arena, +absl::StatusOr Serialize(const T* message, upb::Arena& arena, int options = 0) { return ::protos::internal::Serialize( - ::protos::internal::GetInternalMsg(message), T::minitable(), arena.ptr(), + message->msg(), ::protos::internal::GetMiniTable(message), arena.ptr(), options); } -template -absl::StatusOr Serialize(std::unique_ptr& message, - upb::Arena& arena, - int options = 0) { - return ::protos::internal::Serialize(message->msg(), T::minitable(), - arena.ptr(), options); -} - -template -absl::StatusOr Serialize(std::shared_ptr& message, - upb::Arena& arena, - int options = 0) { - return ::protos::internal::Serialize(message->msg(), T::minitable(), - arena.ptr(), options); -} - template absl::StatusOr Serialize(Ptr message, upb::Arena& arena, int options = 0) { - return ::protos::internal::Serialize(message->msg(), T::minitable(), - arena.ptr(), options); + return ::protos::internal::Serialize( + message->msg(), ::protos::internal::GetMiniTable(message), arena.ptr(), + options); } } // namespace protos diff --git a/protos/repeated_field.h b/protos/repeated_field.h index a040e5059e..3bdb20d68e 100644 --- a/protos/repeated_field.h +++ b/protos/repeated_field.h @@ -145,8 +145,8 @@ class RepeatedFieldProxy typename = std::enable_if_t> void push_back(const T& t) { upb_MessageValue message_value; - message_value.msg_val = - upb_Message_DeepClone(GetInternalMsg(t), T::minitable(), this->arena_); + message_value.msg_val = upb_Message_DeepClone( + GetInternalMsg(&t), ::protos::internal::GetMiniTable(&t), this->arena_); upb_Array_Append(this->arr_, message_value, this->arena_); } @@ -155,7 +155,7 @@ class RepeatedFieldProxy typename = std::enable_if_t> void push_back(T&& msg) { upb_MessageValue message_value; - message_value.msg_val = GetInternalMsg(msg); + message_value.msg_val = GetInternalMsg(&msg); upb_Arena_Fuse(GetArena(msg), this->arena_); upb_Array_Append(this->arr_, message_value, this->arena_); T moved_msg = std::move(msg); diff --git a/protos_generator/gen_messages.cc b/protos_generator/gen_messages.cc index 3fb5d4dbbd..6396c07ab0 100644 --- a/protos_generator/gen_messages.cc +++ b/protos_generator/gen_messages.cc @@ -31,6 +31,7 @@ #include #include "google/protobuf/descriptor.pb.h" +#include "absl/strings/str_cat.h" #include "google/protobuf/descriptor.h" #include "protos_generator/gen_accessors.h" #include "protos_generator/gen_enums.h" @@ -127,9 +128,8 @@ void WriteModelAccessDeclaration(const protobuf::Descriptor* descriptor, friend class $2; friend class $0Proxy; friend class $0CProxy; - friend void* ::protos::internal::GetInternalMsg<$2>(const $2& message); - friend void* ::protos::internal::GetInternalMsg<$2>( - const ::protos::Ptr<$2>& message); + friend void* ::protos::internal::GetInternalMsg<$2>(const $2* message); + friend void* ::protos::internal::GetInternalMsg<$2>(::protos::Ptr<$2> message); $1* msg_; upb_Arena* arena_; )cc", @@ -166,7 +166,7 @@ void WriteModelPublicDeclaration( inline $0& operator=(const CProxy& from) { arena_ = owned_arena_.ptr(); msg_ = ($2*)upb_Message_DeepClone( - ::protos::internal::GetInternalMsg(from), &$1, arena_); + ::protos::internal::GetInternalMsg(&from), &$1, arena_); return *this; } $0($0&& m) @@ -223,8 +223,7 @@ void WriteModelPublicDeclaration( const ::protos::ExtensionRegistry& extension_registry, int options)); friend upb_Arena* ::protos::internal::GetArena<$0>(const $0& message); - friend upb_Arena* ::protos::internal::GetArena<$0>( - const ::protos::Ptr<$0>& message); + friend upb_Arena* ::protos::internal::GetArena<$0>(::protos::Ptr<$0> message); friend $0(::protos::internal::MoveMessage<$0>(upb_Message* msg, upb_Arena* arena)); )cc", @@ -271,9 +270,13 @@ void WriteModelProxyDeclaration(const protobuf::Descriptor* descriptor, friend class $0Access; friend class ::protos::Ptr<$0>; friend class ::protos::Ptr; + static const upb_MiniTable* minitable() { return $0::minitable(); } + friend const upb_MiniTable* ::protos::internal::GetMiniTable<$0Proxy>( + const $0Proxy* message); + friend const upb_MiniTable* ::protos::internal::GetMiniTable<$0Proxy>( + ::protos::Ptr<$0Proxy> message); friend upb_Arena* ::protos::internal::GetArena<$2>(const $2& message); - friend upb_Arena* ::protos::internal::GetArena<$2>( - const ::protos::Ptr<$2>& message); + friend upb_Arena* ::protos::internal::GetArena<$2>(::protos::Ptr<$2> message); friend $0Proxy(::protos::CloneMessage(::protos::Ptr<$2> message, ::upb::Arena& arena)); static void Rebind($0Proxy& lhs, const $0Proxy& rhs) { @@ -313,6 +316,12 @@ void WriteModelCProxyDeclaration(const protobuf::Descriptor* descriptor, friend class RepeatedFieldProxy; friend class ::protos::Ptr<$0>; friend class ::protos::Ptr; + static const upb_MiniTable* minitable() { return $0::minitable(); } + friend const upb_MiniTable* ::protos::internal::GetMiniTable<$0CProxy>( + const $0CProxy* message); + friend const upb_MiniTable* ::protos::internal::GetMiniTable<$0CProxy>( + ::protos::Ptr<$0CProxy> message); + static void Rebind($0CProxy& lhs, const $0CProxy& rhs) { lhs.msg_ = rhs.msg_; lhs.arena_ = rhs.arena_; @@ -349,12 +358,12 @@ void WriteMessageImplementation( $0::$0(const CProxy& from) : $0Access() { arena_ = owned_arena_.ptr(); msg_ = ($1*)upb_Message_DeepClone( - ::protos::internal::GetInternalMsg(from), &$2, arena_); + ::protos::internal::GetInternalMsg(&from), &$2, arena_); } $0::$0(const Proxy& from) : $0(static_cast(from)) {} internal::$0CProxy::$0CProxy($0Proxy m) : $0Access() { arena_ = m.arena_; - msg_ = ($1*)::protos::internal::GetInternalMsg(m); + msg_ = ($1*)::protos::internal::GetInternalMsg(&m); } )cc", ClassName(descriptor), MessageName(descriptor), diff --git a/protos_generator/tests/BUILD b/protos_generator/tests/BUILD index 82cbf5596e..5e66184c04 100644 --- a/protos_generator/tests/BUILD +++ b/protos_generator/tests/BUILD @@ -147,6 +147,8 @@ cc_test( ":test_model_upb_proto", ":naming_conflict_upb_cc_proto", "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "//:upb", "//protos", "//protos:repeated_field", diff --git a/protos_generator/tests/test_generated.cc b/protos_generator/tests/test_generated.cc index b2866b66c3..fd117a1fe0 100644 --- a/protos_generator/tests/test_generated.cc +++ b/protos_generator/tests/test_generated.cc @@ -28,12 +28,15 @@ #include #include "gtest/gtest.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "protos/protos.h" #include "protos/repeated_field.h" #include "protos/repeated_field_iterator.h" #include "protos_generator/tests/child_model.upb.proto.h" #include "protos_generator/tests/no_package.upb.proto.h" #include "protos_generator/tests/test_model.upb.proto.h" +#include "upb/upb.hpp" using ::protos_generator::test::protos::ChildModel1; using ::protos_generator::test::protos::other_ext; @@ -712,7 +715,20 @@ TEST(CppGeneratedCode, SerializeUsingArena) { TestModel model; model.set_str1("Hello World"); ::upb::Arena arena; - absl::StatusOr bytes = ::protos::Serialize(model, arena); + absl::StatusOr bytes = ::protos::Serialize(&model, arena); + EXPECT_EQ(true, bytes.ok()); + TestModel parsed_model = ::protos::Parse(bytes.value()).value(); + EXPECT_EQ("Hello World", parsed_model.str1()); +} + +TEST(CppGeneratedCode, SerializeProxyUsingArena) { + ::upb::Arena message_arena; + TestModel::Proxy model_proxy = + ::protos::CreateMessage(message_arena); + model_proxy.set_str1("Hello World"); + ::upb::Arena arena; + absl::StatusOr bytes = + ::protos::Serialize(&model_proxy, arena); EXPECT_EQ(true, bytes.ok()); TestModel parsed_model = ::protos::Parse(bytes.value()).value(); EXPECT_EQ("Hello World", parsed_model.str1()); @@ -736,7 +752,7 @@ TEST(CppGeneratedCode, Parse) { extension1.set_ext_name("Hello World"); EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok()); ::upb::Arena arena; - auto bytes = ::protos::Serialize(model, arena); + auto bytes = ::protos::Serialize(&model, arena); EXPECT_EQ(true, bytes.ok()); TestModel parsed_model = ::protos::Parse(bytes.value()).value(); EXPECT_EQ("Test123", parsed_model.str1()); @@ -751,7 +767,7 @@ TEST(CppGeneratedCode, ParseIntoPtrToModel) { extension1.set_ext_name("Hello World"); EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok()); ::upb::Arena arena; - auto bytes = ::protos::Serialize(model, arena); + auto bytes = ::protos::Serialize(&model, arena); EXPECT_EQ(true, bytes.ok()); ::protos::Ptr parsed_model = ::protos::CreateMessage(arena); @@ -771,7 +787,7 @@ TEST(CppGeneratedCode, ParseWithExtensionRegistry) { extension1) .ok()); ::upb::Arena arena; - auto bytes = ::protos::Serialize(model, arena); + auto bytes = ::protos::Serialize(&model, arena); EXPECT_EQ(true, bytes.ok()); ::protos::ExtensionRegistry extensions( {&theme, &other_ext, &ThemeExtension::theme_extension}, arena); @@ -799,15 +815,15 @@ TEST(CppGeneratedCode, NameCollisions) { TEST(CppGeneratedCode, SharedPointer) { std::shared_ptr model = std::make_shared(); ::upb::Arena arena; - auto bytes = protos::Serialize(model, arena); - EXPECT_TRUE(protos::Parse(model, bytes.value())); + auto bytes = protos::Serialize(model.get(), arena); + EXPECT_TRUE(protos::Parse(model.get(), bytes.value())); } TEST(CppGeneratedCode, UniquePointer) { auto model = std::make_unique(); ::upb::Arena arena; - auto bytes = protos::Serialize(model, arena); - EXPECT_TRUE(protos::Parse(model, bytes.value())); + auto bytes = protos::Serialize(model.get(), arena); + EXPECT_TRUE(protos::Parse(model.get(), bytes.value())); } TEST(CppGeneratedCode, Assignment) {