Add Extensions to mini table based reflection apis.

PiperOrigin-RevId: 447561140
pull/13171/head
Protobuf Team 3 years ago committed by Copybara-Service
parent 2a5919deb3
commit 459059e301
  1. 1
      BUILD
  2. 278
      upb/mini_table_accessors.c
  3. 33
      upb/mini_table_accessors.h
  4. 66
      upb/mini_table_accessors_test.cc
  5. 26
      upb/test.proto

@ -191,6 +191,7 @@ cc_test(
":mini_table_internal",
":test_messages_proto2_proto_upb",
":test_messages_proto3_proto_upb",
":test_upb_proto",
":upb",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_googletest//:gtest_main",

@ -83,3 +83,281 @@ void upb_MiniTable_ClearField(upb_Message* msg,
}
memset(mem, 0, upb_MiniTable_Field_GetSize(field));
}
typedef struct {
const char* ptr;
uint64_t val;
} decode_vret;
UPB_NOINLINE
static decode_vret decode_longvarint64(const char* ptr, uint64_t val) {
decode_vret ret = {NULL, 0};
uint64_t byte;
int i;
for (i = 1; i < 10; i++) {
byte = (uint8_t)ptr[i];
val += (byte - 1) << (i * 7);
if (!(byte & 0x80)) {
ret.ptr = ptr + i + 1;
ret.val = val;
return ret;
}
}
return ret;
}
UPB_FORCEINLINE
static const char* decode_varint64(const char* ptr, uint64_t* val) {
uint64_t byte = (uint8_t)*ptr;
if (UPB_LIKELY((byte & 0x80) == 0)) {
*val = byte;
return ptr + 1;
} else {
decode_vret res = decode_longvarint64(ptr, byte);
if (!res.ptr) return NULL;
*val = res.val;
return res.ptr;
}
}
UPB_FORCEINLINE
static const char* decode_tag(const char* ptr, uint32_t* val) {
uint64_t byte = (uint8_t)*ptr;
if (UPB_LIKELY((byte & 0x80) == 0)) {
*val = (uint32_t)byte;
return ptr + 1;
} else {
const char* start = ptr;
decode_vret res = decode_longvarint64(ptr, byte);
if (!res.ptr || res.ptr - start > 5 || res.val > UINT32_MAX) {
return NULL; // Malformed.
}
*val = (uint32_t)res.val;
return res.ptr;
}
}
typedef enum {
kUpb_FindUnknown_Ok,
kUpb_FindUnknown_NotPresent,
kUpb_FindUnknown_ParseError,
} upb_FindUnknown_Status;
typedef struct {
upb_FindUnknown_Status status;
const char* ptr;
size_t len;
} find_unknown_ret;
static find_unknown_ret UnknownFieldSet_FindField(const upb_Message* msg,
int field_number);
upb_GetExtension_Status upb_MiniTable_GetOrPromoteExtension(
upb_Message* msg, const upb_MiniTable_Extension* ext_table,
int decode_options, upb_Arena* arena,
const upb_Message_Extension** extension) {
UPB_ASSERT(ext_table->field.descriptortype == kUpb_FieldType_Message);
*extension = _upb_Message_Getext(msg, ext_table);
if (*extension) {
return kUpb_GetExtension_Ok;
}
// Check unknown fields, if available promote.
int field_number = ext_table->field.number;
find_unknown_ret result = UnknownFieldSet_FindField(msg, field_number);
if (result.status != kUpb_FindUnknown_Ok) {
UPB_ASSERT(result.status != kUpb_GetExtension_ParseError);
return kUpb_GetExtension_NotPresent;
}
// Decode and promote from unknown.
const upb_MiniTable* extension_table = ext_table->sub.submsg;
upb_Message* extension_msg = _upb_Message_New(extension_table, arena);
if (!extension_msg) {
return kUpb_GetExtension_OutOfMemory;
}
const char* data = result.ptr;
uint32_t tag;
uint64_t message_len;
data = decode_tag(data, &tag);
data = decode_varint64(data, &message_len);
upb_DecodeStatus status =
upb_Decode(data, message_len, extension_msg, extension_table, NULL,
decode_options, arena);
if (status == kUpb_DecodeStatus_OutOfMemory) {
return kUpb_GetExtension_OutOfMemory;
}
if (status != kUpb_DecodeStatus_Ok) return kUpb_GetExtension_ParseError;
// Add to extensions.
upb_Message_Extension* ext =
_upb_Message_GetOrCreateExtension(msg, ext_table, arena);
if (!ext) {
return kUpb_GetExtension_OutOfMemory;
}
memcpy(&ext->data, &extension_msg, sizeof(extension_msg));
*extension = ext;
// Remove unknown field.
upb_Message_Internal* in = upb_Message_Getinternal(msg);
const char* internal_unknown_end =
UPB_PTR_AT(in->internal, in->internal->unknown_end, char);
if ((result.ptr + result.len) != internal_unknown_end) {
memmove((char*)result.ptr, result.ptr + result.len,
internal_unknown_end - result.ptr - result.len);
}
in->internal->unknown_end -= result.len;
return kUpb_GetExtension_Ok;
}
upb_GetExtensionAsBytes_Status upb_MiniTable_GetExtensionAsBytes(
const upb_Message* msg, const upb_MiniTable_Extension* ext_table,
int encode_options, upb_Arena* arena, const char** extension_data,
size_t* len) {
const upb_Message_Extension* msg_ext = _upb_Message_Getext(msg, ext_table);
UPB_ASSERT(ext_table->field.descriptortype == kUpb_FieldType_Message);
if (msg_ext) {
*extension_data = upb_Encode(msg_ext->data.ptr, msg_ext->ext->sub.submsg,
encode_options, arena, len);
if (extension_data) {
return kUpb_GetExtensionAsBytes_Ok;
}
return kUpb_GetExtensionAsBytes_EncodeError;
}
int field_number = ext_table->field.number;
find_unknown_ret result = UnknownFieldSet_FindField(msg, field_number);
if (result.status != kUpb_FindUnknown_Ok) {
UPB_ASSERT(result.status != kUpb_GetExtension_ParseError);
return kUpb_GetExtensionAsBytes_NotPresent;
}
const char* data = result.ptr;
uint32_t tag;
uint64_t message_len;
data = decode_tag(data, &tag);
data = decode_varint64(data, &message_len);
*extension_data = data;
*len = message_len;
return kUpb_GetExtensionAsBytes_Ok;
}
static const char* UnknownFieldSet_SkipGroup(const char* ptr, const char* end,
int group_number);
static const char* UnknownFieldSet_SkipField(const char* ptr, const char* end,
uint32_t tag) {
int field_number = tag >> 3;
int wire_type = tag & 7;
switch (wire_type) {
case kUpb_WireType_Varint: {
uint64_t val;
return decode_varint64(ptr, &val);
}
case kUpb_WireType_64Bit:
if (end - ptr < 8) return NULL;
return ptr + 8;
case kUpb_WireType_32Bit:
if (end - ptr < 4) return NULL;
return ptr + 4;
case kUpb_WireType_Delimited: {
uint64_t size;
ptr = decode_varint64(ptr, &size);
if (!ptr || end - ptr < size) return NULL;
return ptr + size;
}
case kUpb_WireType_StartGroup:
return UnknownFieldSet_SkipGroup(ptr, end, field_number);
case kUpb_WireType_EndGroup:
return NULL;
default:
assert(0);
return NULL;
}
}
static const char* UnknownFieldSet_SkipGroup(const char* ptr, const char* end,
int group_number) {
uint32_t end_tag = (group_number << 3) | kUpb_WireType_EndGroup;
while (true) {
if (ptr == end) return NULL;
uint64_t tag;
ptr = decode_varint64(ptr, &tag);
if (!ptr) return NULL;
if (tag == end_tag) return ptr;
ptr = UnknownFieldSet_SkipField(ptr, end, (uint32_t)tag);
if (!ptr) return NULL;
}
return ptr;
}
enum {
kUpb_MessageSet_StartItemTag = (1 << 3) | kUpb_WireType_StartGroup,
kUpb_MessageSet_EndItemTag = (1 << 3) | kUpb_WireType_EndGroup,
kUpb_MessageSet_TypeIdTag = (2 << 3) | kUpb_WireType_Varint,
kUpb_MessageSet_MessageTag = (3 << 3) | kUpb_WireType_Delimited,
};
static find_unknown_ret UnknownFieldSet_FindField(const upb_Message* msg,
int field_number) {
size_t size;
find_unknown_ret ret;
const char* ptr = upb_Message_GetUnknown(msg, &size);
if (size == 0) {
ret.ptr = NULL;
return ret;
}
const char* end = ptr + size;
uint64_t uint64_val;
while (ptr < end) {
uint32_t tag;
int field;
int wire_type;
const char* unknown_begin = ptr;
ptr = decode_tag(ptr, &tag);
field = tag >> 3;
wire_type = tag & 7;
switch (wire_type) {
case kUpb_WireType_EndGroup:
ret.status = kUpb_FindUnknown_ParseError;
return ret;
case kUpb_WireType_Varint:
ptr = decode_varint64(ptr, &uint64_val);
if (!ptr) {
ret.status = kUpb_FindUnknown_ParseError;
return ret;
}
break;
case kUpb_WireType_32Bit:
ptr += 4;
break;
case kUpb_WireType_64Bit:
ptr += 8;
break;
case kUpb_WireType_Delimited:
// Read size.
ptr = decode_varint64(ptr, &uint64_val);
if (uint64_val >= INT32_MAX || !ptr) {
ret.status = kUpb_FindUnknown_ParseError;
return ret;
}
ptr += uint64_val;
break;
case kUpb_WireType_StartGroup:
// tag >> 3 specifies the group number, recurse and skip
// until we see group end tag.
ptr = UnknownFieldSet_SkipGroup(ptr, end, field_number);
break;
default:
ret.status = kUpb_FindUnknown_ParseError;
return ret;
}
if (field_number == field) {
ret.status = kUpb_FindUnknown_Ok;
ret.ptr = unknown_begin;
ret.len = ptr - unknown_begin;
return ret;
}
}
ret.status = kUpb_FindUnknown_NotPresent;
ret.ptr = NULL;
ret.len = 0;
return ret;
}

@ -221,7 +221,38 @@ UPB_INLINE upb_Array* upb_MiniTable_GetMutableArray(
return (upb_Array*)*UPB_PTR_AT(msg, field->offset, upb_Array*);
}
// TODO(ferhat): Add support for extensions.
typedef enum {
kUpb_GetExtension_Ok,
kUpb_GetExtension_NotPresent,
kUpb_GetExtension_ParseError,
kUpb_GetExtension_OutOfMemory,
} upb_GetExtension_Status;
typedef enum {
kUpb_GetExtensionAsBytes_Ok,
kUpb_GetExtensionAsBytes_NotPresent,
kUpb_GetExtensionAsBytes_EncodeError,
} upb_GetExtensionAsBytes_Status;
// Returns a message extension or promotes an unknown field to
// an extension.
//
// TODO(ferhat): Only supports extension fields that are messages,
// expand support to include non-message types.
upb_GetExtension_Status upb_MiniTable_GetOrPromoteExtension(
upb_Message* msg, const upb_MiniTable_Extension* ext_table,
int decode_options, upb_Arena* arena,
const upb_Message_Extension** extension);
// Returns a message extension or unknown field matching the extension
// data as bytes.
//
// If an extension has already been decoded it will be re-encoded
// to bytes.
upb_GetExtensionAsBytes_Status upb_MiniTable_GetExtensionAsBytes(
const upb_Message* msg, const upb_MiniTable_Extension* ext_table,
int encode_options, upb_Arena* arena, const char** extension_data,
size_t* len);
#ifdef __cplusplus
} /* extern "C" */

@ -38,6 +38,7 @@
#include "src/google/protobuf/test_messages_proto3.upb.h"
#include "upb/collections.h"
#include "upb/mini_table.h"
#include "upb/test.upb.h"
namespace {
@ -356,4 +357,69 @@ TEST(GeneratedCode, RepeatedScalar) {
upb_Arena_Free(arena);
}
TEST(GeneratedCode, Extensions) {
upb_Arena* arena = upb_Arena_New();
upb_test_ModelWithExtensions* msg = upb_test_ModelWithExtensions_new(arena);
upb_test_ModelWithExtensions_set_random_int32(msg, 10);
upb_test_ModelWithExtensions_set_random_name(
msg, upb_StringView_FromString("Hello"));
upb_test_ModelExtension1* extension1 = upb_test_ModelExtension1_new(arena);
upb_test_ModelExtension1_set_str(extension1,
upb_StringView_FromString("World"));
upb_test_ModelExtension2* extension2 = upb_test_ModelExtension2_new(arena);
upb_test_ModelExtension2_set_i(extension2, 5);
upb_test_ModelExtension1_set_model_ext(msg, extension1, arena);
upb_test_ModelExtension2_set_model_ext(msg, extension2, arena);
size_t serialized_size;
char* serialized =
upb_test_ModelWithExtensions_serialize(msg, arena, &serialized_size);
// Test known GetExtension
const upb_Message_Extension* upb_ext2;
upb_GetExtension_Status promote_status = upb_MiniTable_GetOrPromoteExtension(
msg, &upb_test_ModelExtension2_model_ext_ext, 0, arena, &upb_ext2);
upb_test_ModelExtension2* ext2 =
(upb_test_ModelExtension2*)upb_ext2->data.ptr;
EXPECT_EQ(kUpb_GetExtension_Ok, promote_status);
EXPECT_EQ(5, upb_test_ModelExtension2_i(ext2));
upb_test_EmptyMessageWithExtensions* base_msg =
upb_test_EmptyMessageWithExtensions_parse(serialized, serialized_size,
arena);
// Get unknown extension bytes before promotion.
const char* extension_data;
size_t len;
upb_GetExtensionAsBytes_Status status = status =
upb_MiniTable_GetExtensionAsBytes(base_msg,
&upb_test_ModelExtension2_model_ext_ext,
0, arena, &extension_data, &len);
EXPECT_EQ(kUpb_GetExtensionAsBytes_Ok, status);
EXPECT_EQ(0x48, extension_data[0]);
EXPECT_EQ(5, extension_data[1]);
// Test unknown GetExtension.
promote_status = upb_MiniTable_GetOrPromoteExtension(
base_msg, &upb_test_ModelExtension2_model_ext_ext, 0, arena, &upb_ext2);
ext2 = (upb_test_ModelExtension2*)upb_ext2->data.ptr;
EXPECT_EQ(kUpb_GetExtension_Ok, promote_status);
EXPECT_EQ(5, upb_test_ModelExtension2_i(ext2));
// Get unknown extension bytes after promotion.
status = upb_MiniTable_GetExtensionAsBytes(
base_msg, &upb_test_ModelExtension2_model_ext_ext, 0, arena,
&extension_data, &len);
EXPECT_EQ(kUpb_GetExtensionAsBytes_Ok, status);
EXPECT_EQ(0x48, extension_data[0]);
EXPECT_EQ(5, extension_data[1]);
upb_Arena_Free(arena);
}
} // namespace

@ -46,3 +46,29 @@ message HelloRequest {
optional uint32 random_name_c9 = 31;
optional string version = 32;
}
message EmptyMessageWithExtensions {
// Reserved for unknown fields/extensions test.
reserved 1000 to max;
}
message ModelWithExtensions {
optional int32 random_int32 = 3;
optional string random_name = 4;
// Reserved for unknown fields/extensions test.
extensions 1000 to max;
}
message ModelExtension1 {
extend ModelWithExtensions {
optional ModelExtension1 model_ext = 1547;
}
optional string str = 25;
}
message ModelExtension2 {
extend ModelWithExtensions {
optional ModelExtension2 model_ext = 4135;
}
optional int32 i = 9;
}

Loading…
Cancel
Save