diff --git a/upb/message/BUILD b/upb/message/BUILD index a60c2e5db5..a4ad2b7afe 100644 --- a/upb/message/BUILD +++ b/upb/message/BUILD @@ -359,6 +359,7 @@ cc_test( "//upb/test:test_proto_upb_minitable", "//upb/test:test_upb_proto", "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/strings:string_view", "@googletest//:gtest", "@googletest//:gtest_main", ], diff --git a/upb/message/promote.c b/upb/message/promote.c index 2ed22a23db..bea38d1a4f 100644 --- a/upb/message/promote.c +++ b/upb/message/promote.c @@ -26,7 +26,6 @@ #include "upb/mini_table/extension.h" #include "upb/mini_table/field.h" #include "upb/mini_table/message.h" -#include "upb/mini_table/sub.h" #include "upb/wire/decode.h" #include "upb/wire/eps_copy_input_stream.h" #include "upb/wire/reader.h" @@ -78,36 +77,74 @@ upb_GetExtension_Status upb_Message_GetOrPromoteExtension( } // Check unknown fields, if available promote. - int field_number = upb_MiniTableExtension_Number(ext_table); - upb_FindUnknownRet result = upb_Message_FindUnknown(msg, field_number, 0); - if (result.status != kUpb_FindUnknown_Ok) { - return kUpb_GetExtension_NotPresent; - } - // Decode and promote from unknown. + int found_count = 0; + uint32_t field_number = upb_MiniTableExtension_Number(ext_table); const upb_MiniTable* extension_table = upb_MiniTableExtension_GetSubMessage(ext_table); - upb_UnknownToMessageRet parse_result = upb_MiniTable_ParseUnknownMessage( - result.ptr, result.len, extension_table, - /* base_message= */ NULL, decode_options, arena); - switch (parse_result.status) { - case kUpb_UnknownToMessage_OutOfMemory: - return kUpb_GetExtension_OutOfMemory; - case kUpb_UnknownToMessage_ParseError: - return kUpb_GetExtension_ParseError; - case kUpb_UnknownToMessage_NotFound: - return kUpb_GetExtension_NotPresent; - case kUpb_UnknownToMessage_Ok: - break; + // Will be populated on first parse and then reused + upb_Message* extension_msg = NULL; + int depth_limit = 100; + uintptr_t iter = kUpb_Message_UnknownBegin; + upb_StringView data; + uintptr_t last_found_iter; + while (upb_Message_NextUnknown(msg, &data, &iter)) { + const char* ptr = data.data; + upb_EpsCopyInputStream stream; + upb_EpsCopyInputStream_Init(&stream, &ptr, data.size, true); + while (!upb_EpsCopyInputStream_IsDone(&stream, &ptr)) { + uint32_t tag; + const char* unknown_begin = ptr; + ptr = upb_WireReader_ReadTag(ptr, &tag); + if (!ptr) return kUpb_GetExtension_ParseError; + if (field_number == upb_WireReader_GetFieldNumber(tag)) { + last_found_iter = iter; + found_count++; + const char* start = + upb_EpsCopyInputStream_GetAliasedPtr(&stream, unknown_begin); + ptr = _upb_WireReader_SkipValue(ptr, tag, depth_limit, &stream); + if (!ptr) return kUpb_GetExtension_ParseError; + // Because we know that the input is a flat buffer, it is safe to + // perform pointer arithmetic on aliased pointers. + size_t len = upb_EpsCopyInputStream_GetAliasedPtr(&stream, ptr) - start; + upb_UnknownToMessageRet parse_result = + upb_MiniTable_ParseUnknownMessage(start, len, extension_table, + /* base_message= */ extension_msg, + decode_options, arena); + switch (parse_result.status) { + case kUpb_UnknownToMessage_OutOfMemory: + return kUpb_GetExtension_OutOfMemory; + case kUpb_UnknownToMessage_ParseError: + return kUpb_GetExtension_ParseError; + case kUpb_UnknownToMessage_NotFound: + return kUpb_GetExtension_NotPresent; + case kUpb_UnknownToMessage_Ok: + extension_msg = parse_result.message; + } + } else { + ptr = _upb_WireReader_SkipValue(ptr, tag, depth_limit, &stream); + if (!ptr) return kUpb_GetExtension_ParseError; + } + } + } + if (!extension_msg) { + return kUpb_GetExtension_NotPresent; } - upb_Message* extension_msg = parse_result.message; - // Add to extensions. + upb_Extension* ext = upb_Arena_Malloc(arena, sizeof(upb_Extension)); if (!ext) { return kUpb_GetExtension_OutOfMemory; } ext->ext = ext_table; ext->data.msg_val = extension_msg; - upb_Message_ReplaceUnknownWithExtension(msg, result.iter, ext); + + upb_Message_ReplaceUnknownWithExtension(msg, last_found_iter, ext); + while (found_count > 1) { + upb_FindUnknownRet found = upb_Message_FindUnknown(msg, field_number, 0); + UPB_ASSERT(found.status == kUpb_FindUnknown_Ok); + upb_StringView view = {.data = found.ptr, .size = found.len}; + upb_Message_DeleteUnknown(msg, &view, &found.iter); + found_count--; + } value->msg_val = extension_msg; return kUpb_GetExtension_Ok; } diff --git a/upb/message/promote_test.cc b/upb/message/promote_test.cc index a6aab32179..cb028affd8 100644 --- a/upb/message/promote_test.cc +++ b/upb/message/promote_test.cc @@ -20,6 +20,7 @@ #include #include +#include "absl/strings/string_view.h" #include "upb/base/descriptor_constants.h" #include "upb/base/status.h" #include "upb/base/string_view.h" @@ -29,7 +30,6 @@ #include "upb/message/accessors.h" #include "upb/message/array.h" #include "upb/message/copy.h" -#include "upb/message/internal/extension.h" #include "upb/message/internal/message.h" #include "upb/message/map.h" #include "upb/message/message.h" @@ -39,6 +39,7 @@ #include "upb/mini_descriptor/internal/modifiers.h" #include "upb/mini_descriptor/link.h" #include "upb/mini_table/extension.h" +#include "upb/mini_table/extension_registry.h" #include "upb/mini_table/field.h" #include "upb/mini_table/message.h" #include "upb/test/test.upb.h" @@ -94,6 +95,54 @@ TEST(GeneratedCode, FindUnknown) { upb_Arena_Free(arena); } +TEST(GeneratedCode, PromoteFromMultiple) { + int options = kUpb_DecodeOption_AliasString; + upb_Arena* arena = upb_Arena_New(); + upb_test_ModelWithExtensions* msg = upb_test_ModelWithExtensions_new(arena); + + upb_test_ModelExtension1* extension1 = upb_test_ModelExtension1_new(arena); + upb_test_ModelExtension1_set_str(extension1, + upb_StringView_FromString("World")); + + upb_test_ModelExtension1_set_model_ext(msg, extension1, arena); + + size_t serialized_size; + char* serialized1 = + upb_test_ModelWithExtensions_serialize(msg, arena, &serialized_size); + + upb_test_ModelExtension1_set_str(extension1, + upb_StringView_FromString("Everyone")); + size_t serialized_size2; + char* serialized2 = + upb_test_ModelWithExtensions_serialize(msg, arena, &serialized_size2); + char* concat = + (char*)upb_Arena_Malloc(arena, serialized_size + serialized_size2); + memcpy(concat, serialized1, serialized_size); + memcpy(concat + serialized_size, serialized2, serialized_size2); + + upb_test_ModelWithExtensions* parsed = upb_test_ModelWithExtensions_parse_ex( + concat, serialized_size + serialized_size2, + upb_ExtensionRegistry_New(arena), options, arena); + + upb_MessageValue value; + upb_GetExtension_Status result = upb_Message_GetOrPromoteExtension( + UPB_UPCAST(parsed), &upb_test_ModelExtension1_model_ext_ext, options, + arena, &value); + ASSERT_EQ(result, kUpb_GetExtension_Ok); + upb_test_ModelExtension1* parsed_ex = + (upb_test_ModelExtension1*)value.msg_val; + upb_StringView field = upb_test_ModelExtension1_str(parsed_ex); + EXPECT_EQ(absl::string_view(field.data, field.size), "Everyone"); + + upb_FindUnknownRet found = upb_Message_FindUnknown( + UPB_UPCAST(parsed), + upb_MiniTableExtension_Number(&upb_test_ModelExtension1_model_ext_ext), + 0); + EXPECT_EQ(kUpb_FindUnknown_NotPresent, found.status); + + upb_Arena_Free(arena); +} + TEST(GeneratedCode, Extensions) { upb_Arena* arena = upb_Arena_New(); upb_test_ModelWithExtensions* msg = upb_test_ModelWithExtensions_new(arena);