From 07f7b2086ec9537bfcbb657fa8070db42f73e029 Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Mon, 26 Jun 2023 09:18:40 -0700 Subject: [PATCH] Add support for HasExtension/GetExtension without ExtensionRegistry. Update signatures to use T* Ptr consistently. Cross language blocks utility updated to use GetArena(T*) Fixes arena_ for Proxy/CProxy ::Access class now consistently fills in arena_ (where message was created in). PiperOrigin-RevId: 543457599 --- protos/BUILD | 1 + protos/protos.cc | 23 +++++ protos/protos.h | 109 +++++++++-------------- protos/repeated_field.h | 27 +++--- protos_generator/gen_accessors.cc | 5 +- protos_generator/gen_messages.cc | 31 ++++--- protos_generator/gen_repeated_fields.cc | 9 +- protos_generator/tests/test_generated.cc | 79 ++++++++++------ 8 files changed, 161 insertions(+), 123 deletions(-) diff --git a/protos/BUILD b/protos/BUILD index be62cb73ce..e6b1a7fab3 100644 --- a/protos/BUILD +++ b/protos/BUILD @@ -71,6 +71,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//:message_copy", + "//:message_promote", "//:mini_table", "//:upb", "@com_google_absl//absl/status", diff --git a/protos/protos.cc b/protos/protos.cc index 7a651c497f..229ce4cf75 100644 --- a/protos/protos.cc +++ b/protos/protos.cc @@ -28,6 +28,8 @@ #include "protos/protos.h" #include "absl/strings/str_format.h" +#include "upb/message/promote.h" +#include "upb/wire/common.h" namespace protos { @@ -90,6 +92,27 @@ upb_ExtensionRegistry* GetUpbExtensions( return extension_registry.registry_; } +bool HasExtensionOrUnknown(const upb_Message* msg, + const upb_MiniTableExtension* eid) { + return _upb_Message_Getext(msg, eid) != nullptr || + upb_MiniTable_FindUnknown(msg, eid->field.number, + kUpb_WireFormat_DefaultDepthLimit) + .status == kUpb_FindUnknown_Ok; +} + +const upb_Message_Extension* GetOrPromoteExtension( + upb_Message* msg, const upb_MiniTableExtension* eid, upb_Arena* arena) { + const upb_Message_Extension* ext = _upb_Message_Getext(msg, eid); + if (ext == nullptr) { + upb_GetExtension_Status ext_status = upb_MiniTable_GetOrPromoteExtension( + (upb_Message*)msg, eid, kUpb_WireFormat_DefaultDepthLimit, arena, &ext); + if (ext_status != kUpb_GetExtension_Ok) { + return nullptr; + } + } + return ext; +} + absl::StatusOr Serialize(const upb_Message* message, const upb_MiniTable* mini_table, upb_Arena* arena, int options) { diff --git a/protos/protos.h b/protos/protos.h index abeb427501..43b9885da4 100644 --- a/protos/protos.h +++ b/protos/protos.h @@ -33,6 +33,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "upb/mem/arena.h" #include "upb/message/copy.h" #include "upb/message/extension_internal.h" #include "upb/mini_table/types.h" @@ -195,8 +196,8 @@ typename T::Proxy CreateMessageProxy(void* msg, upb_Arena* arena) { } template -typename T::CProxy CreateMessage(upb_Message* msg) { - return typename T::CProxy(msg); +typename T::CProxy CreateMessage(upb_Message* msg, upb_Arena* arena) { + return typename T::CProxy(msg, arena); } class ExtensionMiniTableProvider { @@ -242,12 +243,12 @@ void* GetInternalMsg(Ptr message) { } template -upb_Arena* GetArena(const T& message) { - return static_cast(message.GetInternalArena()); +upb_Arena* GetArena(Ptr message) { + return static_cast(message->GetInternalArena()); } template -upb_Arena* GetArena(Ptr message) { +upb_Arena* GetArena(T* message) { return static_cast(message->GetInternalArena()); } @@ -268,6 +269,12 @@ absl::StatusOr Serialize(const upb_Message* message, const upb_MiniTable* mini_table, upb_Arena* arena, int options); +bool HasExtensionOrUnknown(const upb_Message* msg, + const upb_MiniTableExtension* eid); + +const upb_Message_Extension* GetOrPromoteExtension( + upb_Message* msg, const upb_MiniTableExtension* eid, upb_Arena* arena); + } // namespace internal class ExtensionRegistry { @@ -306,17 +313,18 @@ using EnableIfMutableProto = std::enable_if_t::value>; template > bool HasExtension( - const T& message, + const Ptr& message, const ::protos::internal::ExtensionIdentifier& id) { - return _upb_Message_Getext(message.msg(), id.mini_table_ext()) != nullptr; + return ::protos::internal::HasExtensionOrUnknown( + ::protos::internal::GetInternalMsg(message), id.mini_table_ext()); } template > bool HasExtension( - const Ptr& message, + const T* message, const ::protos::internal::ExtensionIdentifier& id) { - return _upb_Message_Getext(message->msg(), id.mini_table_ext()) != nullptr; + return HasExtension(protos::Ptr(message), id); } template & message, const ::protos::internal::ExtensionIdentifier& id) { - _upb_Message_ClearExtensionField(message->msg(), id.mini_table_ext()); + static_assert(!std::is_const_v, ""); + _upb_Message_ClearExtensionField(::protos::internal::GetInternalMsg(message), + id.mini_table_ext()); } template > void ClearExtension( - const T& message, + T* message, const ::protos::internal::ExtensionIdentifier& id) { - _upb_Message_ClearExtensionField(message.msg(), id.mini_table_ext()); -} - -template > -absl::Status SetExtension( - const T& message, - const ::protos::internal::ExtensionIdentifier& id, - Extension& value) { - auto* message_arena = static_cast(message.GetInternalArena()); - upb_Message_Extension* msg_ext = _upb_Message_GetOrCreateExtension( - message.msg(), id.mini_table_ext(), message_arena); - if (!msg_ext) { - return MessageAllocationError(); - } - auto* extension_arena = static_cast(value.GetInternalArena()); - if (message_arena != extension_arena) { - upb_Arena_Fuse(message_arena, extension_arena); - } - msg_ext->data.ptr = value.msg(); - return absl::OkStatus(); + ClearExtension(::protos::Ptr(message), id); } template > -absl::StatusOr> GetExtension( - const T& message, - const ::protos::internal::ExtensionIdentifier& id) { - const upb_Message_Extension* ext = - _upb_Message_Getext(message.msg(), id.mini_table_ext()); - if (!ext) { - return ExtensionNotFoundError(id.mini_table_ext()->field.number); - } - return Ptr( - ::protos::internal::CreateMessage(ext->data.ptr)); +absl::Status SetExtension( + T* message, + const ::protos::internal::ExtensionIdentifier& id, + Extension& value) { + return ::protos::SetExtension(::protos::Ptr(message), id, value); } template > GetExtension( const Ptr& message, const ::protos::internal::ExtensionIdentifier& id) { - const upb_Message_Extension* ext = - _upb_Message_Getext(message->msg(), id.mini_table_ext()); + const upb_Message_Extension* ext = ::protos::internal::GetOrPromoteExtension( + ::protos::internal::GetInternalMsg(message), id.mini_table_ext(), + ::protos::internal::GetArena(message)); if (!ext) { return ExtensionNotFoundError(id.mini_table_ext()->field.number); } - return Ptr( - ::protos::internal::CreateMessage(ext->data.ptr)); + return Ptr(::protos::internal::CreateMessage( + ext->data.ptr, ::protos::internal::GetArena(message))); } -template -bool Parse(T& message, absl::string_view bytes) { - upb_Message_Clear(message.msg(), ::protos::internal::GetMiniTable(&message)); - auto* arena = static_cast(message.GetInternalArena()); - return upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(), - /* extreg= */ nullptr, /* options= */ 0, - arena) == kUpb_DecodeStatus_Ok; -} - -template -bool Parse(T& message, absl::string_view bytes, - const ::protos::ExtensionRegistry& extension_registry) { - upb_Message_Clear(message.msg(), ::protos::internal::GetMiniTable(message)); - auto* arena = static_cast(message.GetInternalArena()); - return upb_Decode(bytes.data(), bytes.size(), message.msg(), - ::protos::internal::GetMiniTable(message), - /* extreg= */ - ::protos::internal::GetUpbExtensions(extension_registry), - /* options= */ 0, arena) == kUpb_DecodeStatus_Ok; +template > +absl::StatusOr> GetExtension( + const T* message, + const ::protos::internal::ExtensionIdentifier& id) { + return GetExtension(protos::Ptr(message), id); } template @@ -446,6 +419,12 @@ bool Parse(Ptr& message, absl::string_view bytes, /* options= */ 0, arena) == kUpb_DecodeStatus_Ok; } +template +bool Parse(T* message, absl::string_view bytes, + const ::protos::ExtensionRegistry& extension_registry) { + return Parse(protos::Ptr(message, bytes, extension_registry)); +} + template bool Parse(T* message, absl::string_view bytes) { upb_Message_Clear(message->msg(), ::protos::internal::GetMiniTable(message)); diff --git a/protos/repeated_field.h b/protos/repeated_field.h index 3bdb20d68e..6233109d8a 100644 --- a/protos/repeated_field.h +++ b/protos/repeated_field.h @@ -64,7 +64,8 @@ class RepeatedFieldProxyBase { using Array = add_const_if_T_is_const; public: - explicit RepeatedFieldProxyBase(Array* arr) : arr_(arr) {} + explicit RepeatedFieldProxyBase(Array* arr, upb_Arena* arena) + : arr_(arr), arena_(arena) {} size_t size() const { return arr_ != nullptr ? upb_Array_Size(arr_) : 0; } @@ -78,6 +79,7 @@ class RepeatedFieldProxyBase { inline upb_Message* GetMessage(size_t n) const; Array* arr_; + upb_Arena* arena_; }; template @@ -98,12 +100,9 @@ template class RepeatedFieldProxyMutableBase : public RepeatedFieldProxyBase { public: RepeatedFieldProxyMutableBase(upb_Array* arr, upb_Arena* arena) - : RepeatedFieldProxyBase(arr), arena_(arena) {} - - void clear() { upb_Array_Resize(this->arr_, 0, arena_); } + : RepeatedFieldProxyBase(arr, arena) {} - protected: - upb_Arena* arena_; + void clear() { upb_Array_Resize(this->arr_, 0, this->arena_); } }; // RepeatedField proxy for repeated messages. @@ -117,8 +116,8 @@ class RepeatedFieldProxy static constexpr bool kIsConst = std::is_const_v; public: - explicit RepeatedFieldProxy(const upb_Array* arr) - : RepeatedFieldProxyBase(arr) {} + explicit RepeatedFieldProxy(const upb_Array* arr, upb_Arena* arena) + : RepeatedFieldProxyBase(arr, arena) {} RepeatedFieldProxy(upb_Array* arr, upb_Arena* arena) : RepeatedFieldProxyMutableBase(arr, arena) {} // Constructor used by ::protos::Ptr. @@ -128,7 +127,7 @@ class RepeatedFieldProxy typename T::CProxy operator[](size_t n) const { upb_MessageValue message_value = upb_Array_Get(this->arr_, n); return ::protos::internal::CreateMessage>( - (upb_Message*)message_value.msg_val); + (upb_Message*)message_value.msg_val, this->arena_); } // TODO(b:/280069986) : Audit/Finalize based on Iterator Design. @@ -156,7 +155,7 @@ class RepeatedFieldProxy void push_back(T&& msg) { upb_MessageValue message_value; message_value.msg_val = GetInternalMsg(&msg); - upb_Arena_Fuse(GetArena(msg), this->arena_); + upb_Arena_Fuse(GetArena(&msg), this->arena_); upb_Array_Append(this->arr_, message_value, this->arena_); T moved_msg = std::move(msg); } @@ -177,8 +176,8 @@ class RepeatedFieldStringProxy public: // Immutable constructor. - explicit RepeatedFieldStringProxy(const upb_Array* arr) - : RepeatedFieldProxyBase(arr) {} + explicit RepeatedFieldStringProxy(const upb_Array* arr, upb_Arena* arena) + : RepeatedFieldProxyBase(arr, arena) {} // Mutable constructor. RepeatedFieldStringProxy(upb_Array* arr, upb_Arena* arena) : RepeatedFieldProxyMutableBase(arr, arena) {} @@ -210,8 +209,8 @@ class RepeatedFieldScalarProxy static constexpr bool kIsConst = std::is_const_v; public: - explicit RepeatedFieldScalarProxy(const upb_Array* arr) - : RepeatedFieldProxyBase(arr) {} + explicit RepeatedFieldScalarProxy(const upb_Array* arr, upb_Arena* arena) + : RepeatedFieldProxyBase(arr, arena) {} RepeatedFieldScalarProxy(upb_Array* arr, upb_Arena* arena) : RepeatedFieldProxyMutableBase(arr, arena) {} // Constructor used by ::protos::Ptr. diff --git a/protos_generator/gen_accessors.cc b/protos_generator/gen_accessors.cc index f7324e8188..1b4f9edaca 100644 --- a/protos_generator/gen_accessors.cc +++ b/protos_generator/gen_accessors.cc @@ -262,7 +262,8 @@ void WriteAccessorsInSource(const protobuf::Descriptor* desc, Output& output) { if (!has_$2()) { return $4::default_instance(); } - return ::protos::internal::CreateMessage<$4>((upb_Message*)($3_$5(msg_))); + return ::protos::internal::CreateMessage<$4>( + (upb_Message*)($3_$5(msg_)), arena_); } )cc", class_name, MessagePtrConstType(field, /* is_const */ true), @@ -338,7 +339,7 @@ void WriteMapAccessorDefinitions(const protobuf::Descriptor* message, $5* msg_value; $7bool success = $4_$9_get(msg_, $8, &msg_value); if (success) { - return ::protos::internal::CreateMessage<$6>(msg_value); + return ::protos::internal::CreateMessage<$6>(msg_value, arena_); } return absl::NotFoundError(""); } diff --git a/protos_generator/gen_messages.cc b/protos_generator/gen_messages.cc index d8cc569f21..669d12226b 100644 --- a/protos_generator/gen_messages.cc +++ b/protos_generator/gen_messages.cc @@ -111,9 +111,13 @@ void WriteModelAccessDeclaration(const protobuf::Descriptor* descriptor, class $0Access { public: $0Access() {} - $0Access($1* msg, upb_Arena* arena) : msg_(msg), arena_(arena) {} // NOLINT + $0Access($1* msg, upb_Arena* arena) : msg_(msg), arena_(arena) { + assert(arena != nullptr); + } // NOLINT $0Access(const $1* msg, upb_Arena* arena) - : msg_(const_cast<$1*>(msg)), arena_(arena) {} // NOLINT + : msg_(const_cast<$1*>(msg)), arena_(arena) { + assert(arena != nullptr); + } // NOLINT void* GetInternalArena() const { return arena_; } )cc", ClassName(descriptor), MessageName(descriptor)); @@ -222,7 +226,7 @@ void WriteModelPublicDeclaration( absl::string_view bytes, const ::protos::ExtensionRegistry& extension_registry, int options)); - friend upb_Arena* ::protos::internal::GetArena<$0>(const $0& message); + friend upb_Arena* ::protos::internal::GetArena<$0>($0* message); friend upb_Arena* ::protos::internal::GetArena<$0>(::protos::Ptr<$0> message); friend $0(::protos::internal::MoveMessage<$0>(upb_Message* msg, upb_Arena* arena)); @@ -279,7 +283,7 @@ void WriteModelProxyDeclaration(const protobuf::Descriptor* descriptor, const $0Proxy* message); friend const upb_MiniTable* ::protos::internal::GetMiniTable<$0Proxy>( ::protos::Ptr<$0Proxy> message); - friend upb_Arena* ::protos::internal::GetArena<$2>(const $2& message); + friend upb_Arena* ::protos::internal::GetArena<$2>($2* message); friend upb_Arena* ::protos::internal::GetArena<$2>(::protos::Ptr<$2> message); friend $0Proxy(::protos::CloneMessage(::protos::Ptr<$2> message, ::upb::Arena& arena)); @@ -302,7 +306,8 @@ void WriteModelCProxyDeclaration(const protobuf::Descriptor* descriptor, class $0CProxy final : private internal::$0Access { public: $0CProxy() = delete; - $0CProxy(const $0* m) : internal::$0Access(m->msg_, nullptr) {} + $0CProxy(const $0* m) + : internal::$0Access(m->msg_, ::protos::internal::GetArena(m)) {} $0CProxy($0Proxy m); using $0Access::GetInternalArena; )cc", @@ -315,8 +320,9 @@ void WriteModelCProxyDeclaration(const protobuf::Descriptor* descriptor, output( R"cc( private: - $0CProxy(void* msg) : internal::$0Access(($1*)msg, nullptr){}; - friend $0::CProxy(::protos::internal::CreateMessage<$0>(upb_Message* msg)); + $0CProxy(void* msg, upb_Arena* arena) : internal::$0Access(($1*)msg, arena){}; + friend $0::CProxy(::protos::internal::CreateMessage<$0>( + upb_Message* msg, upb_Arena* arena)); friend class RepeatedFieldProxy; friend class ::protos::Ptr<$0>; friend class ::protos::Ptr; @@ -390,9 +396,13 @@ void WriteMessageImplementation( R"cc( struct $0DefaultTypeInternal { $1* msg; + upb_Arena* arena; }; - $0DefaultTypeInternal _$0_default_instance_ = - $0DefaultTypeInternal{$1_new(upb_Arena_New())}; + static $0DefaultTypeInternal _$0DefaultTypeBuilder() { + upb_Arena* arena = upb_Arena_New(); + return $0DefaultTypeInternal{$1_new(arena), arena}; + } + $0DefaultTypeInternal _$0_default_instance_ = _$0DefaultTypeBuilder(); )cc", ClassName(descriptor), MessageName(descriptor)); @@ -400,7 +410,8 @@ void WriteMessageImplementation( R"cc( ::protos::Ptr $0::default_instance() { return ::protos::internal::CreateMessage<$0>( - (upb_Message *)_$0_default_instance_.msg); + (upb_Message *)_$0_default_instance_.msg, + _$0_default_instance_.arena); } )cc", ClassName(descriptor)); diff --git a/protos_generator/gen_repeated_fields.cc b/protos_generator/gen_repeated_fields.cc index 48d77c1986..15219f8444 100644 --- a/protos_generator/gen_repeated_fields.cc +++ b/protos_generator/gen_repeated_fields.cc @@ -153,7 +153,8 @@ void WriteRepeatedMessageAccessor(const protobuf::Descriptor* message, size_t len; auto* ptr = $3_$5(msg_, &len); assert(index < len); - return ::protos::internal::CreateMessage<$4>((upb_Message*)*(ptr + index)); + return ::protos::internal::CreateMessage<$4>( + (upb_Message*)*(ptr + index), arena_); } )cc", class_name, MessagePtrConstType(field, /* is_const */ true), @@ -192,7 +193,7 @@ void WriteRepeatedMessageAccessor(const protobuf::Descriptor* message, const ::protos::RepeatedField::CProxy $0::$2() const { size_t size; const upb_Array* arr = _$3_$4_$5(msg_, &size); - return ::protos::RepeatedField::CProxy(arr); + return ::protos::RepeatedField::CProxy(arr, arena_); }; ::protos::Ptr<::protos::RepeatedField<$1>> $0::mutable_$2() { size_t size; @@ -258,7 +259,7 @@ void WriteRepeatedStringAccessor(const protobuf::Descriptor* message, const ::protos::RepeatedField<$1>::CProxy $0::$2() const { size_t size; const upb_Array* arr = _$3_$4_$5(msg_, &size); - return ::protos::RepeatedField<$1>::CProxy(arr); + return ::protos::RepeatedField<$1>::CProxy(arr, arena_); }; ::protos::Ptr<::protos::RepeatedField<$1>> $0::mutable_$2() { size_t size; @@ -322,7 +323,7 @@ void WriteRepeatedScalarAccessor(const protobuf::Descriptor* message, const ::protos::RepeatedField<$1>::CProxy $0::$2() const { size_t size; const upb_Array* arr = _$3_$4_$5(msg_, &size); - return ::protos::RepeatedField<$1>::CProxy(arr); + return ::protos::RepeatedField<$1>::CProxy(arr, arena_); }; ::protos::Ptr<::protos::RepeatedField<$1>> $0::mutable_$2() { size_t size; diff --git a/protos_generator/tests/test_generated.cc b/protos_generator/tests/test_generated.cc index f0406adb45..bfeffffad7 100644 --- a/protos_generator/tests/test_generated.cc +++ b/protos_generator/tests/test_generated.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include "gtest/gtest.h" @@ -626,7 +627,7 @@ TEST(CppGeneratedCode, MessageMapStringKeyAndInt32Value) { TEST(CppGeneratedCode, HasExtension) { TestModel model; - EXPECT_EQ(false, ::protos::HasExtension(model, theme)); + EXPECT_EQ(false, ::protos::HasExtension(&model, theme)); } TEST(CppGeneratedCode, HasExtensionPtr) { @@ -636,9 +637,9 @@ TEST(CppGeneratedCode, HasExtensionPtr) { TEST(CppGeneratedCode, ClearExtensionWithEmptyExtension) { TestModel model; - EXPECT_EQ(false, ::protos::HasExtension(model, theme)); - ::protos::ClearExtension(model, theme); - EXPECT_EQ(false, ::protos::HasExtension(model, theme)); + EXPECT_EQ(false, ::protos::HasExtension(&model, theme)); + ::protos::ClearExtension(&model, theme); + EXPECT_EQ(false, ::protos::HasExtension(&model, theme)); } TEST(CppGeneratedCode, ClearExtensionWithEmptyExtensionPtr) { @@ -652,9 +653,9 @@ TEST(CppGeneratedCode, SetExtension) { TestModel model; ThemeExtension extension1; extension1.set_ext_name("Hello World"); - EXPECT_EQ(false, ::protos::HasExtension(model, theme)); - EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok()); - EXPECT_EQ(true, ::protos::HasExtension(model, theme)); + EXPECT_EQ(false, ::protos::HasExtension(&model, theme)); + EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok()); + EXPECT_EQ(true, ::protos::HasExtension(&model, theme)); } TEST(CppGeneratedCode, SetExtensionOnMutableChild) { @@ -674,10 +675,10 @@ TEST(CppGeneratedCode, GetExtension) { TestModel model; ThemeExtension extension1; extension1.set_ext_name("Hello World"); - EXPECT_EQ(false, ::protos::HasExtension(model, theme)); - EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok()); + EXPECT_EQ(false, ::protos::HasExtension(&model, theme)); + EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok()); EXPECT_EQ("Hello World", - ::protos::GetExtension(model, theme).value()->ext_name()); + ::protos::GetExtension(&model, theme).value()->ext_name()); } TEST(CppGeneratedCode, GetExtensionOnMutableChild) { @@ -750,14 +751,13 @@ TEST(CppGeneratedCode, Parse) { model.set_str1("Test123"); ThemeExtension extension1; extension1.set_ext_name("Hello World"); - EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok()); + EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok()); ::upb::Arena arena; auto bytes = ::protos::Serialize(&model, arena); EXPECT_EQ(true, bytes.ok()); TestModel parsed_model = ::protos::Parse(bytes.value()).value(); EXPECT_EQ("Test123", parsed_model.str1()); - // Should not return an extension since we did not pass ExtensionRegistry. - EXPECT_EQ(false, ::protos::GetExtension(parsed_model, theme).ok()); + EXPECT_EQ(true, ::protos::GetExtension(&parsed_model, theme).ok()); } TEST(CppGeneratedCode, ParseIntoPtrToModel) { @@ -765,7 +765,7 @@ TEST(CppGeneratedCode, ParseIntoPtrToModel) { model.set_str1("Test123"); ThemeExtension extension1; extension1.set_ext_name("Hello World"); - EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok()); + EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok()); ::upb::Arena arena; auto bytes = ::protos::Serialize(&model, arena); EXPECT_EQ(true, bytes.ok()); @@ -773,8 +773,9 @@ TEST(CppGeneratedCode, ParseIntoPtrToModel) { ::protos::CreateMessage(arena); EXPECT_TRUE(::protos::Parse(parsed_model, bytes.value())); EXPECT_EQ("Test123", parsed_model->str1()); - // Should not return an extension since we did not pass ExtensionRegistry. - EXPECT_EQ(false, ::protos::GetExtension(parsed_model, theme).ok()); + // Should return an extension even if we don't pass ExtensionRegistry + // by promoting unknown. + EXPECT_EQ(true, ::protos::GetExtension(parsed_model, theme).ok()); } TEST(CppGeneratedCode, ParseWithExtensionRegistry) { @@ -782,9 +783,9 @@ TEST(CppGeneratedCode, ParseWithExtensionRegistry) { model.set_str1("Test123"); ThemeExtension extension1; extension1.set_ext_name("Hello World"); - EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok()); - EXPECT_EQ(true, ::protos::SetExtension(model, ThemeExtension::theme_extension, - extension1) + EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok()); + EXPECT_EQ(true, ::protos::SetExtension( + &model, ThemeExtension::theme_extension, extension1) .ok()); ::upb::Arena arena; auto bytes = ::protos::Serialize(&model, arena); @@ -794,12 +795,12 @@ TEST(CppGeneratedCode, ParseWithExtensionRegistry) { TestModel parsed_model = ::protos::Parse(bytes.value(), extensions).value(); EXPECT_EQ("Test123", parsed_model.str1()); - EXPECT_EQ(true, ::protos::GetExtension(parsed_model, theme).ok()); - EXPECT_EQ(true, ::protos::GetExtension(parsed_model, + EXPECT_EQ(true, ::protos::GetExtension(&parsed_model, theme).ok()); + EXPECT_EQ(true, ::protos::GetExtension(&parsed_model, ThemeExtension::theme_extension) .ok()); EXPECT_EQ("Hello World", ::protos::GetExtension( - parsed_model, ThemeExtension::theme_extension) + &parsed_model, ThemeExtension::theme_extension) .value() ->ext_name()); } @@ -898,7 +899,7 @@ TEST(CppGeneratedCode, ClearSubMessage) { new_child->set_child_str1("text in child"); ThemeExtension extension1; extension1.set_ext_name("name in extension"); - EXPECT_TRUE(::protos::SetExtension(model, theme, extension1).ok()); + EXPECT_TRUE(::protos::SetExtension(&model, theme, extension1).ok()); EXPECT_TRUE(model.mutable_child_model_1()->has_child_str1()); // Clear using Ptr ::protos::ClearMessage(model.mutable_child_model_1()); @@ -915,14 +916,14 @@ TEST(CppGeneratedCode, ClearMessage) { new_child.value()->set_child_str1("text in child"); ThemeExtension extension1; extension1.set_ext_name("name in extension"); - EXPECT_TRUE(::protos::SetExtension(model, theme, extension1).ok()); + EXPECT_TRUE(::protos::SetExtension(&model, theme, extension1).ok()); // Clear using T* ::protos::ClearMessage(&model); // Verify that scalars, repeated fields and extensions are cleared. EXPECT_FALSE(model.has_int64()); EXPECT_FALSE(model.has_str2()); EXPECT_TRUE(model.child_models().empty()); - EXPECT_FALSE(::protos::HasExtension(model, theme)); + EXPECT_FALSE(::protos::HasExtension(&model, theme)); } TEST(CppGeneratedCode, DeepCopy) { @@ -935,13 +936,35 @@ TEST(CppGeneratedCode, DeepCopy) { new_child.value()->set_child_str1("text in child"); ThemeExtension extension1; extension1.set_ext_name("name in extension"); - EXPECT_TRUE(::protos::SetExtension(model, theme, extension1).ok()); + EXPECT_TRUE(::protos::SetExtension(&model, theme, extension1).ok()); TestModel target; target.set_b1(true); ::protos::DeepCopy(&model, &target); - EXPECT_FALSE(target.b1()) << "Target was not cleared before copying content"; + EXPECT_FALSE(target.b1()) << "Target was not cleared before copying content "; EXPECT_EQ(target.str2(), "Hello"); - EXPECT_TRUE(::protos::HasExtension(target, theme)); + EXPECT_TRUE(::protos::HasExtension(&target, theme)); +} + +TEST(CppGeneratedCode, HasExtensionAndRegistry) { + // Fill model. + TestModel source; + source.set_int64(5); + source.set_str2("Hello"); + auto new_child = source.add_child_models(); + ASSERT_TRUE(new_child.ok()); + new_child.value()->set_child_str1("text in child"); + ThemeExtension extension1; + extension1.set_ext_name("name in extension"); + ASSERT_TRUE(::protos::SetExtension(&source, theme, extension1).ok()); + + // Now that we have a source model with extension data, serialize. + ::protos::Arena arena; + std::string data = std::string(::protos::Serialize(&source, arena).value()); + + // Test with ExtensionRegistry + ::protos::ExtensionRegistry extensions({&theme}, arena); + TestModel parsed_model = ::protos::Parse(data, extensions).value(); + EXPECT_TRUE(::protos::HasExtension(&parsed_model, theme)); } // TODO(b/288491350) : Add BUILD rule to test failures below.