diff --git a/BUILD b/BUILD index 97545bf591..79e3e51600 100644 --- a/BUILD +++ b/BUILD @@ -191,6 +191,7 @@ cc_test( ":mini_table_internal", ":test_messages_proto2_proto_upb", ":test_messages_proto3_proto_upb", + ":test_upb_proto", ":upb", "@com_google_absl//absl/container:flat_hash_set", "@com_google_googletest//:gtest_main", diff --git a/upb/mini_table_accessors.c b/upb/mini_table_accessors.c index 8a1b891325..7ba8aed4f9 100644 --- a/upb/mini_table_accessors.c +++ b/upb/mini_table_accessors.c @@ -83,3 +83,281 @@ void upb_MiniTable_ClearField(upb_Message* msg, } memset(mem, 0, upb_MiniTable_Field_GetSize(field)); } + +typedef struct { + const char* ptr; + uint64_t val; +} decode_vret; + +UPB_NOINLINE +static decode_vret decode_longvarint64(const char* ptr, uint64_t val) { + decode_vret ret = {NULL, 0}; + uint64_t byte; + int i; + for (i = 1; i < 10; i++) { + byte = (uint8_t)ptr[i]; + val += (byte - 1) << (i * 7); + if (!(byte & 0x80)) { + ret.ptr = ptr + i + 1; + ret.val = val; + return ret; + } + } + return ret; +} + +UPB_FORCEINLINE +static const char* decode_varint64(const char* ptr, uint64_t* val) { + uint64_t byte = (uint8_t)*ptr; + if (UPB_LIKELY((byte & 0x80) == 0)) { + *val = byte; + return ptr + 1; + } else { + decode_vret res = decode_longvarint64(ptr, byte); + if (!res.ptr) return NULL; + *val = res.val; + return res.ptr; + } +} + +UPB_FORCEINLINE +static const char* decode_tag(const char* ptr, uint32_t* val) { + uint64_t byte = (uint8_t)*ptr; + if (UPB_LIKELY((byte & 0x80) == 0)) { + *val = (uint32_t)byte; + return ptr + 1; + } else { + const char* start = ptr; + decode_vret res = decode_longvarint64(ptr, byte); + if (!res.ptr || res.ptr - start > 5 || res.val > UINT32_MAX) { + return NULL; // Malformed. + } + *val = (uint32_t)res.val; + return res.ptr; + } +} + +typedef enum { + kUpb_FindUnknown_Ok, + kUpb_FindUnknown_NotPresent, + kUpb_FindUnknown_ParseError, +} upb_FindUnknown_Status; + +typedef struct { + upb_FindUnknown_Status status; + const char* ptr; + size_t len; +} find_unknown_ret; + +static find_unknown_ret UnknownFieldSet_FindField(const upb_Message* msg, + int field_number); + +upb_GetExtension_Status upb_MiniTable_GetOrPromoteExtension( + upb_Message* msg, const upb_MiniTable_Extension* ext_table, + int decode_options, upb_Arena* arena, + const upb_Message_Extension** extension) { + UPB_ASSERT(ext_table->field.descriptortype == kUpb_FieldType_Message); + *extension = _upb_Message_Getext(msg, ext_table); + if (*extension) { + return kUpb_GetExtension_Ok; + } + + // Check unknown fields, if available promote. + int field_number = ext_table->field.number; + find_unknown_ret result = UnknownFieldSet_FindField(msg, field_number); + if (result.status != kUpb_FindUnknown_Ok) { + UPB_ASSERT(result.status != kUpb_GetExtension_ParseError); + return kUpb_GetExtension_NotPresent; + } + // Decode and promote from unknown. + const upb_MiniTable* extension_table = ext_table->sub.submsg; + upb_Message* extension_msg = _upb_Message_New(extension_table, arena); + if (!extension_msg) { + return kUpb_GetExtension_OutOfMemory; + } + const char* data = result.ptr; + uint32_t tag; + uint64_t message_len; + data = decode_tag(data, &tag); + data = decode_varint64(data, &message_len); + upb_DecodeStatus status = + upb_Decode(data, message_len, extension_msg, extension_table, NULL, + decode_options, arena); + if (status == kUpb_DecodeStatus_OutOfMemory) { + return kUpb_GetExtension_OutOfMemory; + } + if (status != kUpb_DecodeStatus_Ok) return kUpb_GetExtension_ParseError; + // Add to extensions. + upb_Message_Extension* ext = + _upb_Message_GetOrCreateExtension(msg, ext_table, arena); + if (!ext) { + return kUpb_GetExtension_OutOfMemory; + } + memcpy(&ext->data, &extension_msg, sizeof(extension_msg)); + *extension = ext; + // Remove unknown field. + upb_Message_Internal* in = upb_Message_Getinternal(msg); + const char* internal_unknown_end = + UPB_PTR_AT(in->internal, in->internal->unknown_end, char); + if ((result.ptr + result.len) != internal_unknown_end) { + memmove((char*)result.ptr, result.ptr + result.len, + internal_unknown_end - result.ptr - result.len); + } + in->internal->unknown_end -= result.len; + return kUpb_GetExtension_Ok; +} + +upb_GetExtensionAsBytes_Status upb_MiniTable_GetExtensionAsBytes( + const upb_Message* msg, const upb_MiniTable_Extension* ext_table, + int encode_options, upb_Arena* arena, const char** extension_data, + size_t* len) { + const upb_Message_Extension* msg_ext = _upb_Message_Getext(msg, ext_table); + UPB_ASSERT(ext_table->field.descriptortype == kUpb_FieldType_Message); + if (msg_ext) { + *extension_data = upb_Encode(msg_ext->data.ptr, msg_ext->ext->sub.submsg, + encode_options, arena, len); + if (extension_data) { + return kUpb_GetExtensionAsBytes_Ok; + } + return kUpb_GetExtensionAsBytes_EncodeError; + } + int field_number = ext_table->field.number; + find_unknown_ret result = UnknownFieldSet_FindField(msg, field_number); + if (result.status != kUpb_FindUnknown_Ok) { + UPB_ASSERT(result.status != kUpb_GetExtension_ParseError); + return kUpb_GetExtensionAsBytes_NotPresent; + } + const char* data = result.ptr; + uint32_t tag; + uint64_t message_len; + data = decode_tag(data, &tag); + data = decode_varint64(data, &message_len); + *extension_data = data; + *len = message_len; + return kUpb_GetExtensionAsBytes_Ok; +} + +static const char* UnknownFieldSet_SkipGroup(const char* ptr, const char* end, + int group_number); + +static const char* UnknownFieldSet_SkipField(const char* ptr, const char* end, + uint32_t tag) { + int field_number = tag >> 3; + int wire_type = tag & 7; + switch (wire_type) { + case kUpb_WireType_Varint: { + uint64_t val; + return decode_varint64(ptr, &val); + } + case kUpb_WireType_64Bit: + if (end - ptr < 8) return NULL; + return ptr + 8; + case kUpb_WireType_32Bit: + if (end - ptr < 4) return NULL; + return ptr + 4; + case kUpb_WireType_Delimited: { + uint64_t size; + ptr = decode_varint64(ptr, &size); + if (!ptr || end - ptr < size) return NULL; + return ptr + size; + } + case kUpb_WireType_StartGroup: + return UnknownFieldSet_SkipGroup(ptr, end, field_number); + case kUpb_WireType_EndGroup: + return NULL; + default: + assert(0); + return NULL; + } +} + +static const char* UnknownFieldSet_SkipGroup(const char* ptr, const char* end, + int group_number) { + uint32_t end_tag = (group_number << 3) | kUpb_WireType_EndGroup; + while (true) { + if (ptr == end) return NULL; + uint64_t tag; + ptr = decode_varint64(ptr, &tag); + if (!ptr) return NULL; + if (tag == end_tag) return ptr; + ptr = UnknownFieldSet_SkipField(ptr, end, (uint32_t)tag); + if (!ptr) return NULL; + } + return ptr; +} + +enum { + kUpb_MessageSet_StartItemTag = (1 << 3) | kUpb_WireType_StartGroup, + kUpb_MessageSet_EndItemTag = (1 << 3) | kUpb_WireType_EndGroup, + kUpb_MessageSet_TypeIdTag = (2 << 3) | kUpb_WireType_Varint, + kUpb_MessageSet_MessageTag = (3 << 3) | kUpb_WireType_Delimited, +}; + +static find_unknown_ret UnknownFieldSet_FindField(const upb_Message* msg, + int field_number) { + size_t size; + find_unknown_ret ret; + + const char* ptr = upb_Message_GetUnknown(msg, &size); + if (size == 0) { + ret.ptr = NULL; + return ret; + } + const char* end = ptr + size; + uint64_t uint64_val; + + while (ptr < end) { + uint32_t tag; + int field; + int wire_type; + const char* unknown_begin = ptr; + ptr = decode_tag(ptr, &tag); + field = tag >> 3; + wire_type = tag & 7; + switch (wire_type) { + case kUpb_WireType_EndGroup: + ret.status = kUpb_FindUnknown_ParseError; + return ret; + case kUpb_WireType_Varint: + ptr = decode_varint64(ptr, &uint64_val); + if (!ptr) { + ret.status = kUpb_FindUnknown_ParseError; + return ret; + } + break; + case kUpb_WireType_32Bit: + ptr += 4; + break; + case kUpb_WireType_64Bit: + ptr += 8; + break; + case kUpb_WireType_Delimited: + // Read size. + ptr = decode_varint64(ptr, &uint64_val); + if (uint64_val >= INT32_MAX || !ptr) { + ret.status = kUpb_FindUnknown_ParseError; + return ret; + } + ptr += uint64_val; + break; + case kUpb_WireType_StartGroup: + // tag >> 3 specifies the group number, recurse and skip + // until we see group end tag. + ptr = UnknownFieldSet_SkipGroup(ptr, end, field_number); + break; + default: + ret.status = kUpb_FindUnknown_ParseError; + return ret; + } + if (field_number == field) { + ret.status = kUpb_FindUnknown_Ok; + ret.ptr = unknown_begin; + ret.len = ptr - unknown_begin; + return ret; + } + } + ret.status = kUpb_FindUnknown_NotPresent; + ret.ptr = NULL; + ret.len = 0; + return ret; +} diff --git a/upb/mini_table_accessors.h b/upb/mini_table_accessors.h index dddaba63ad..6c430d6be6 100644 --- a/upb/mini_table_accessors.h +++ b/upb/mini_table_accessors.h @@ -221,7 +221,38 @@ UPB_INLINE upb_Array* upb_MiniTable_GetMutableArray( return (upb_Array*)*UPB_PTR_AT(msg, field->offset, upb_Array*); } -// TODO(ferhat): Add support for extensions. +typedef enum { + kUpb_GetExtension_Ok, + kUpb_GetExtension_NotPresent, + kUpb_GetExtension_ParseError, + kUpb_GetExtension_OutOfMemory, +} upb_GetExtension_Status; + +typedef enum { + kUpb_GetExtensionAsBytes_Ok, + kUpb_GetExtensionAsBytes_NotPresent, + kUpb_GetExtensionAsBytes_EncodeError, +} upb_GetExtensionAsBytes_Status; + +// Returns a message extension or promotes an unknown field to +// an extension. +// +// TODO(ferhat): Only supports extension fields that are messages, +// expand support to include non-message types. +upb_GetExtension_Status upb_MiniTable_GetOrPromoteExtension( + upb_Message* msg, const upb_MiniTable_Extension* ext_table, + int decode_options, upb_Arena* arena, + const upb_Message_Extension** extension); + +// Returns a message extension or unknown field matching the extension +// data as bytes. +// +// If an extension has already been decoded it will be re-encoded +// to bytes. +upb_GetExtensionAsBytes_Status upb_MiniTable_GetExtensionAsBytes( + const upb_Message* msg, const upb_MiniTable_Extension* ext_table, + int encode_options, upb_Arena* arena, const char** extension_data, + size_t* len); #ifdef __cplusplus } /* extern "C" */ diff --git a/upb/mini_table_accessors_test.cc b/upb/mini_table_accessors_test.cc index 04ac5c34b2..498471692f 100644 --- a/upb/mini_table_accessors_test.cc +++ b/upb/mini_table_accessors_test.cc @@ -38,6 +38,7 @@ #include "src/google/protobuf/test_messages_proto3.upb.h" #include "upb/collections.h" #include "upb/mini_table.h" +#include "upb/test.upb.h" namespace { @@ -356,4 +357,69 @@ TEST(GeneratedCode, RepeatedScalar) { upb_Arena_Free(arena); } +TEST(GeneratedCode, Extensions) { + upb_Arena* arena = upb_Arena_New(); + upb_test_ModelWithExtensions* msg = upb_test_ModelWithExtensions_new(arena); + upb_test_ModelWithExtensions_set_random_int32(msg, 10); + upb_test_ModelWithExtensions_set_random_name( + msg, upb_StringView_FromString("Hello")); + + upb_test_ModelExtension1* extension1 = upb_test_ModelExtension1_new(arena); + upb_test_ModelExtension1_set_str(extension1, + upb_StringView_FromString("World")); + + upb_test_ModelExtension2* extension2 = upb_test_ModelExtension2_new(arena); + upb_test_ModelExtension2_set_i(extension2, 5); + + upb_test_ModelExtension1_set_model_ext(msg, extension1, arena); + upb_test_ModelExtension2_set_model_ext(msg, extension2, arena); + + size_t serialized_size; + char* serialized = + upb_test_ModelWithExtensions_serialize(msg, arena, &serialized_size); + + // Test known GetExtension + const upb_Message_Extension* upb_ext2; + upb_GetExtension_Status promote_status = upb_MiniTable_GetOrPromoteExtension( + msg, &upb_test_ModelExtension2_model_ext_ext, 0, arena, &upb_ext2); + + upb_test_ModelExtension2* ext2 = + (upb_test_ModelExtension2*)upb_ext2->data.ptr; + EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); + EXPECT_EQ(5, upb_test_ModelExtension2_i(ext2)); + + upb_test_EmptyMessageWithExtensions* base_msg = + upb_test_EmptyMessageWithExtensions_parse(serialized, serialized_size, + arena); + + // Get unknown extension bytes before promotion. + const char* extension_data; + size_t len; + upb_GetExtensionAsBytes_Status status = status = + upb_MiniTable_GetExtensionAsBytes(base_msg, + &upb_test_ModelExtension2_model_ext_ext, + 0, arena, &extension_data, &len); + EXPECT_EQ(kUpb_GetExtensionAsBytes_Ok, status); + EXPECT_EQ(0x48, extension_data[0]); + EXPECT_EQ(5, extension_data[1]); + + // Test unknown GetExtension. + promote_status = upb_MiniTable_GetOrPromoteExtension( + base_msg, &upb_test_ModelExtension2_model_ext_ext, 0, arena, &upb_ext2); + + ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr; + EXPECT_EQ(kUpb_GetExtension_Ok, promote_status); + EXPECT_EQ(5, upb_test_ModelExtension2_i(ext2)); + + // Get unknown extension bytes after promotion. + status = upb_MiniTable_GetExtensionAsBytes( + base_msg, &upb_test_ModelExtension2_model_ext_ext, 0, arena, + &extension_data, &len); + EXPECT_EQ(kUpb_GetExtensionAsBytes_Ok, status); + EXPECT_EQ(0x48, extension_data[0]); + EXPECT_EQ(5, extension_data[1]); + + upb_Arena_Free(arena); +} + } // namespace diff --git a/upb/test.proto b/upb/test.proto index f4b94b4032..d4e6b84488 100644 --- a/upb/test.proto +++ b/upb/test.proto @@ -46,3 +46,29 @@ message HelloRequest { optional uint32 random_name_c9 = 31; optional string version = 32; } + +message EmptyMessageWithExtensions { + // Reserved for unknown fields/extensions test. + reserved 1000 to max; +} + +message ModelWithExtensions { + optional int32 random_int32 = 3; + optional string random_name = 4; + // Reserved for unknown fields/extensions test. + extensions 1000 to max; +} + +message ModelExtension1 { + extend ModelWithExtensions { + optional ModelExtension1 model_ext = 1547; + } + optional string str = 25; +} + +message ModelExtension2 { + extend ModelWithExtensions { + optional ModelExtension2 model_ext = 4135; + } + optional int32 i = 9; +}