diff --git a/protos/protos.cc b/protos/protos.cc index fde6120a5a..db814f3603 100644 --- a/protos/protos.cc +++ b/protos/protos.cc @@ -20,6 +20,7 @@ #include "upb/message/internal/extension.h" #include "upb/message/message.h" #include "upb/message/promote.h" +#include "upb/message/value.h" #include "upb/mini_table/extension.h" #include "upb/mini_table/extension_registry.h" #include "upb/mini_table/message.h" @@ -122,19 +123,12 @@ bool HasExtensionOrUnknown(const upb_Message* msg, .status == kUpb_FindUnknown_Ok; } -const upb_Extension* GetOrPromoteExtension(upb_Message* msg, - const upb_MiniTableExtension* eid, - upb_Arena* arena) { +bool GetOrPromoteExtension(upb_Message* msg, const upb_MiniTableExtension* eid, + upb_Arena* arena, upb_MessageValue* value) { MessageLock msg_lock(msg); - const upb_Extension* ext = UPB_PRIVATE(_upb_Message_Getext)(msg, eid); - if (ext == nullptr) { - upb_GetExtension_Status ext_status = upb_MiniTable_GetOrPromoteExtension( - (upb_Message*)msg, eid, 0, arena, &ext); - if (ext_status != kUpb_GetExtension_Ok) { - ext = nullptr; - } - } - return ext; + upb_GetExtension_Status ext_status = upb_Message_GetOrPromoteExtension( + (upb_Message*)msg, eid, 0, arena, value); + return ext_status == kUpb_GetExtension_Ok; } absl::StatusOr Serialize(const upb_Message* message, diff --git a/protos/protos.h b/protos/protos.h index db06de80fd..208d7c17fb 100644 --- a/protos/protos.h +++ b/protos/protos.h @@ -223,9 +223,8 @@ absl::StatusOr Serialize(const upb_Message* message, bool HasExtensionOrUnknown(const upb_Message* msg, const upb_MiniTableExtension* eid); -const upb_Extension* GetOrPromoteExtension(upb_Message* msg, - const upb_MiniTableExtension* eid, - upb_Arena* arena); +bool GetOrPromoteExtension(upb_Message* msg, const upb_MiniTableExtension* eid, + upb_Arena* arena, upb_MessageValue* value); void DeepCopy(upb_Message* target, const upb_Message* source, const upb_MiniTable* mini_table, upb_Arena* arena); @@ -410,15 +409,16 @@ absl::StatusOr> GetExtension( Ptr message, const ::protos::internal::ExtensionIdentifier& id) { // TODO: Fix const correctness issues. - const upb_Extension* ext = ::protos::internal::GetOrPromoteExtension( + upb_MessageValue value; + const bool ok = ::protos::internal::GetOrPromoteExtension( const_cast(internal::GetInternalMsg(message)), - id.mini_table_ext(), ::protos::internal::GetArena(message)); - if (!ext) { + id.mini_table_ext(), ::protos::internal::GetArena(message), &value); + if (!ok) { return ExtensionNotFoundError( upb_MiniTableExtension_Number(id.mini_table_ext())); } return Ptr(::protos::internal::CreateMessage( - (upb_Message*)ext->data.ptr, ::protos::internal::GetArena(message))); + (upb_Message*)value.msg_val, ::protos::internal::GetArena(message))); } template data, sizeof(upb_MessageValue)); return kUpb_GetExtension_Ok; } @@ -104,8 +106,8 @@ upb_GetExtension_Status upb_MiniTable_GetOrPromoteExtension( if (!ext) { return kUpb_GetExtension_OutOfMemory; } - memcpy(&ext->data, &extension_msg, sizeof(extension_msg)); - *extension = ext; + ext->data.ptr = extension_msg; + value->msg_val = extension_msg; const char* delete_ptr = upb_Message_GetUnknown(msg, &len) + ofs; upb_Message_DeleteUnknown(msg, delete_ptr, result.len); return kUpb_GetExtension_Ok; diff --git a/upb/message/promote.h b/upb/message/promote.h index 0a43f6bf70..c280c10268 100644 --- a/upb/message/promote.h +++ b/upb/message/promote.h @@ -33,14 +33,13 @@ typedef enum { kUpb_GetExtensionAsBytes_EncodeError, } upb_GetExtensionAsBytes_Status; -// Returns a message extension or promotes an unknown field to -// an extension. +// Returns a message value or promotes an unknown field to an extension. // // TODO: Only supports extension fields that are messages, // expand support to include non-message types. -upb_GetExtension_Status upb_MiniTable_GetOrPromoteExtension( +upb_GetExtension_Status upb_Message_GetOrPromoteExtension( upb_Message* msg, const upb_MiniTableExtension* ext_table, - int decode_options, upb_Arena* arena, const upb_Extension** extension); + int decode_options, upb_Arena* arena, upb_MessageValue* value); typedef enum { kUpb_FindUnknown_Ok, diff --git a/upb/message/promote_test.cc b/upb/message/promote_test.cc index 3b79cb8b55..1dc8932e87 100644 --- a/upb/message/promote_test.cc +++ b/upb/message/promote_test.cc @@ -123,57 +123,57 @@ TEST(GeneratedCode, Extensions) { char* serialized = upb_test_ModelWithExtensions_serialize(msg, arena, &serialized_size); - const upb_Extension* upb_ext2; upb_test_ModelExtension1* ext1; upb_test_ModelExtension2* ext2; upb_GetExtension_Status promote_status; + upb_MessageValue value; // Test known GetExtension 1 - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(msg), &upb_test_ModelExtension1_model_ext_ext, 0, arena, - &upb_ext2); - ext1 = (upb_test_ModelExtension1*)upb_ext2->data.ptr; + &value); + ext1 = (upb_test_ModelExtension1*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_TRUE(upb_StringView_IsEqual(upb_StringView_FromString("World"), upb_test_ModelExtension1_str(ext1))); // Test known GetExtension 2 - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(msg), &upb_test_ModelExtension2_model_ext_ext, 0, arena, - &upb_ext2); - ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr; + &value); + ext2 = (upb_test_ModelExtension2*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_EQ(5, upb_test_ModelExtension2_i(ext2)); // Test known GetExtension 3 - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(msg), &upb_test_ModelExtension2_model_ext_2_ext, 0, arena, - &upb_ext2); - ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr; + &value); + ext2 = (upb_test_ModelExtension2*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_EQ(6, upb_test_ModelExtension2_i(ext2)); // Test known GetExtension 4 - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(msg), &upb_test_ModelExtension2_model_ext_3_ext, 0, arena, - &upb_ext2); - ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr; + &value); + ext2 = (upb_test_ModelExtension2*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_EQ(7, upb_test_ModelExtension2_i(ext2)); // Test known GetExtension 5 - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(msg), &upb_test_ModelExtension2_model_ext_4_ext, 0, arena, - &upb_ext2); - ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr; + &value); + ext2 = (upb_test_ModelExtension2*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_EQ(8, upb_test_ModelExtension2_i(ext2)); // Test known GetExtension 6 - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(msg), &upb_test_ModelExtension2_model_ext_5_ext, 0, arena, - &upb_ext2); - ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr; + &value); + ext2 = (upb_test_ModelExtension2*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_EQ(9, upb_test_ModelExtension2_i(ext2)); @@ -188,51 +188,51 @@ TEST(GeneratedCode, Extensions) { EXPECT_EQ(0, upb_Message_ExtensionCount(UPB_UPCAST(base_msg))); // Test unknown GetExtension. - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(base_msg), &upb_test_ModelExtension1_model_ext_ext, 0, arena, - &upb_ext2); - ext1 = (upb_test_ModelExtension1*)upb_ext2->data.ptr; + &value); + ext1 = (upb_test_ModelExtension1*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_TRUE(upb_StringView_IsEqual(upb_StringView_FromString("World"), upb_test_ModelExtension1_str(ext1))); // Test unknown GetExtension. - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(base_msg), &upb_test_ModelExtension2_model_ext_ext, 0, arena, - &upb_ext2); - ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr; + &value); + ext2 = (upb_test_ModelExtension2*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_EQ(5, upb_test_ModelExtension2_i(ext2)); // Test unknown GetExtension. - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(base_msg), &upb_test_ModelExtension2_model_ext_2_ext, 0, arena, - &upb_ext2); - ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr; + &value); + ext2 = (upb_test_ModelExtension2*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_EQ(6, upb_test_ModelExtension2_i(ext2)); // Test unknown GetExtension. - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(base_msg), &upb_test_ModelExtension2_model_ext_3_ext, 0, arena, - &upb_ext2); - ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr; + &value); + ext2 = (upb_test_ModelExtension2*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_EQ(7, upb_test_ModelExtension2_i(ext2)); // Test unknown GetExtension. - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(base_msg), &upb_test_ModelExtension2_model_ext_4_ext, 0, arena, - &upb_ext2); - ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr; + &value); + ext2 = (upb_test_ModelExtension2*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_EQ(8, upb_test_ModelExtension2_i(ext2)); // Test unknown GetExtension. - promote_status = upb_MiniTable_GetOrPromoteExtension( + promote_status = upb_Message_GetOrPromoteExtension( UPB_UPCAST(base_msg), &upb_test_ModelExtension2_model_ext_5_ext, 0, arena, - &upb_ext2); - ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr; + &value); + ext2 = (upb_test_ModelExtension2*)value.msg_val; EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); EXPECT_EQ(9, upb_test_ModelExtension2_i(ext2));