Rewrote the MessageSet parsing code in the upb decoder to properly handle several edge cases.

PiperOrigin-RevId: 440788402
pull/13171/head
Joshua Haberman 3 years ago committed by Copybara-Service
parent bef53686ec
commit 9cc02bb60d
  1. 217
      upb/decode.c
  2. 2
      upb/msg.c
  3. 2
      upb/msg_internal.h
  4. 62
      upb/msg_test.cc
  5. 14
      upb/msg_test.proto
  6. 2
      upb/reflection.c
  7. 2
      upbc/protoc-gen-upb.cc

@ -93,13 +93,11 @@ static const unsigned FIXED64_OK_MASK = (1 << kUpb_FieldType_Double) |
/* Three fake field types for MessageSet. */
#define TYPE_MSGSET_ITEM 19
#define TYPE_MSGSET_TYPE_ID 20
#define TYPE_COUNT 20
#define TYPE_COUNT 19
/* Op: an action to be performed for a wire-type/field-type combination. */
#define OP_UNKNOWN -1 /* Unknown field. */
#define OP_MSGSET_ITEM -2
#define OP_MSGSET_TYPEID -3
#define OP_SCALAR_LG2(n) (n) /* n in [0, 2, 3] => op in [0, 2, 3] */
#define OP_ENUM 1
#define OP_STRING 4
@ -131,7 +129,6 @@ static const int8_t varint_ops[] = {
OP_SCALAR_LG2(2), /* SINT32 */
OP_SCALAR_LG2(3), /* SINT64 */
OP_UNKNOWN, /* MSGSET_ITEM */
OP_MSGSET_TYPEID, /* MSGSET TYPEID */
};
static const int8_t delim_ops[] = {
@ -156,7 +153,6 @@ static const int8_t delim_ops[] = {
OP_UNKNOWN, /* SINT32 */
OP_UNKNOWN, /* SINT64 */
OP_UNKNOWN, /* MSGSET_ITEM */
OP_UNKNOWN, /* MSGSET TYPEID */
/* For repeated field type. */
OP_FIXPCK_LG2(3), /* REPEATED DOUBLE */
OP_FIXPCK_LG2(2), /* REPEATED FLOAT */
@ -266,6 +262,18 @@ static const char* decode_tag(upb_Decoder* d, const char* ptr, uint32_t* val) {
}
}
UPB_FORCEINLINE
static const char* upb_Decoder_DecodeSize(upb_Decoder* d, const char* ptr,
uint32_t* size) {
uint64_t size64;
ptr = decode_varint64(d, ptr, &size64);
if (size64 >= INT32_MAX || ptr - d->end + (int)size64 > d->limit) {
decode_err(d, kUpb_DecodeStatus_Malformed);
}
*size = size64;
return ptr;
}
static void decode_munge_int32(wireval* val) {
if (!_upb_IsLittleEndian()) {
/* The next stage will memcpy(dst, &val, 4) */
@ -300,7 +308,9 @@ static upb_Message* decode_newsubmsg(upb_Decoder* d,
const upb_MiniTable_Sub* subs,
const upb_MiniTable_Field* field) {
const upb_MiniTable* subl = subs[field->submsg_index].submsg;
return _upb_Message_New_inl(subl, &d->arena);
upb_Message* msg = _upb_Message_New_inl(subl, &d->arena);
if (!msg) decode_err(d, kUpb_DecodeStatus_OutOfMemory);
return msg;
}
UPB_NOINLINE
@ -375,7 +385,7 @@ static const char* decode_togroup(upb_Decoder* d, const char* ptr,
return decode_group(d, ptr, submsg, subl, field->number);
}
static char* encode_varint32(uint32_t val, char* ptr) {
static char* upb_Decoder_EncodeVarint32(uint32_t val, char* ptr) {
do {
uint8_t byte = val & 0x7fU;
val >>= 7;
@ -389,8 +399,8 @@ static void upb_Decode_AddUnknownVarints(upb_Decoder* d, upb_Message* msg,
uint32_t val1, uint32_t val2) {
char buf[20];
char* end = buf;
end = encode_varint32(val1, end);
end = encode_varint32(val2, end);
end = upb_Decoder_EncodeVarint32(val1, end);
end = upb_Decoder_EncodeVarint32(val2, end);
if (!_upb_Message_AddUnknown(msg, buf, end - buf, &d->arena)) {
decode_err(d, kUpb_DecodeStatus_OutOfMemory);
@ -743,25 +753,139 @@ static bool decode_tryfastdispatch(upb_Decoder* d, const char** ptr,
return false;
}
static const char* decode_msgset(upb_Decoder* d, const char* ptr,
upb_Message* msg,
const upb_MiniTable* layout) {
// We create a temporary upb_MiniTable here and abuse its fields as temporary
// storage, to avoid creating lots of MessageSet-specific parsing code-paths:
// 1. We store 'layout' in item_layout.subs. We will need this later as
// a key to look up extensions for this MessageSet.
// 2. We use item_layout.fields as temporary storage to store the extension
// we
// found when parsing the type id.
upb_MiniTable item_layout = {
.subs = (const upb_MiniTable_Sub[]){{.submsg = layout}},
.fields = NULL,
.size = 0,
.field_count = 0,
.ext = kUpb_ExtMode_IsMessageSet_ITEM,
.dense_below = 0,
.table_mask = -1};
return decode_group(d, ptr, msg, &item_layout, 1);
static const char* upb_Decoder_SkipField(upb_Decoder* d, const char* ptr,
uint32_t tag) {
int field_number = tag >> 3;
int wire_type = tag & 7;
switch (wire_type) {
case kUpb_WireType_Varint: {
uint64_t val;
return decode_varint64(d, ptr, &val);
}
case kUpb_WireType_64Bit:
return ptr + 8;
case kUpb_WireType_32Bit:
return ptr + 4;
case kUpb_WireType_Delimited: {
uint32_t size;
ptr = upb_Decoder_DecodeSize(d, ptr, &size);
return ptr + size;
}
case kUpb_WireType_StartGroup:
return decode_group(d, ptr, NULL, NULL, field_number);
default:
decode_err(d, kUpb_DecodeStatus_Malformed);
}
}
enum {
kStartItemTag = ((1 << 3) | kUpb_WireType_StartGroup),
kEndItemTag = ((1 << 3) | kUpb_WireType_EndGroup),
kTypeIdTag = ((2 << 3) | kUpb_WireType_Varint),
kMessageTag = ((3 << 3) | kUpb_WireType_Delimited),
};
static void upb_Decoder_AddKnownMessageSetItem(
upb_Decoder* d, upb_Message* msg, const upb_MiniTable_Extension* item_mt,
const char* data, uint32_t size) {
upb_Message_Extension* ext =
_upb_Message_GetOrCreateExtension(msg, item_mt, &d->arena);
if (UPB_UNLIKELY(!ext)) decode_err(d, kUpb_DecodeStatus_OutOfMemory);
upb_Message* submsg = decode_newsubmsg(d, &ext->ext->sub, &ext->ext->field);
upb_DecodeStatus status = upb_Decode(data, size, submsg, item_mt->sub.submsg,
d->extreg, d->options, &d->arena);
memcpy(&ext->data, &submsg, sizeof(submsg));
if (status != kUpb_DecodeStatus_Ok) decode_err(d, status);
}
static void upb_Decoder_AddUnknownMessageSetItem(upb_Decoder* d,
upb_Message* msg,
uint32_t type_id,
const char* message_data,
uint32_t message_size) {
char buf[60];
char* ptr = buf;
ptr = upb_Decoder_EncodeVarint32(kStartItemTag, ptr);
ptr = upb_Decoder_EncodeVarint32(kTypeIdTag, ptr);
ptr = upb_Decoder_EncodeVarint32(type_id, ptr);
ptr = upb_Decoder_EncodeVarint32(kMessageTag, ptr);
ptr = upb_Decoder_EncodeVarint32(message_size, ptr);
char* split = ptr;
ptr = upb_Decoder_EncodeVarint32(kEndItemTag, ptr);
char* end = ptr;
if (!_upb_Message_AddUnknown(msg, buf, split - buf, &d->arena) ||
!_upb_Message_AddUnknown(msg, message_data, message_size, &d->arena) ||
!_upb_Message_AddUnknown(msg, split, end - split, &d->arena)) {
decode_err(d, kUpb_DecodeStatus_OutOfMemory);
}
}
static void upb_Decoder_AddMessageSetItem(upb_Decoder* d, upb_Message* msg,
const upb_MiniTable* layout,
uint32_t type_id, const char* data,
uint32_t size) {
const upb_MiniTable_Extension* item_mt =
_upb_extreg_get(d->extreg, layout, type_id);
if (item_mt) {
upb_Decoder_AddKnownMessageSetItem(d, msg, item_mt, data, size);
} else {
upb_Decoder_AddUnknownMessageSetItem(d, msg, type_id, data, size);
}
}
static const char* upb_Decoder_DecodeMessageSetItem(
upb_Decoder* d, const char* ptr, upb_Message* msg,
const upb_MiniTable* layout) {
uint32_t type_id = 0;
upb_StringView preserved = {NULL, 0};
typedef enum {
kUpb_HaveId = 1 << 0,
kUpb_HavePayload = 1 << 1,
} StateMask;
StateMask state_mask = 0;
while (!decode_isdone(d, &ptr)) {
uint32_t tag;
ptr = decode_tag(d, ptr, &tag);
switch (tag) {
case kEndItemTag:
return ptr;
case kTypeIdTag: {
uint64_t tmp;
ptr = decode_varint64(d, ptr, &tmp);
if (state_mask & kUpb_HaveId) break; // Ignore dup.
state_mask |= kUpb_HaveId;
type_id = tmp;
if (state_mask & kUpb_HavePayload) {
upb_Decoder_AddMessageSetItem(d, msg, layout, type_id, preserved.data,
preserved.size);
}
break;
}
case kMessageTag: {
uint32_t size;
ptr = upb_Decoder_DecodeSize(d, ptr, &size);
const char* data = ptr;
ptr += size;
if (state_mask & kUpb_HavePayload) break; // Ignore dup.
state_mask |= kUpb_HavePayload;
if (state_mask & kUpb_HaveId) {
upb_Decoder_AddMessageSetItem(d, msg, layout, type_id, data, size);
} else {
// Out of order, we must preserve the payload.
preserved.data = data;
preserved.size = size;
}
break;
}
default:
// We do not preserve unexpected fields inside a message set item.
ptr = upb_Decoder_SkipField(d, ptr, tag);
break;
}
}
decode_err(d, kUpb_DecodeStatus_Malformed);
}
static const upb_MiniTable_Field* decode_findfield(upb_Decoder* d,
@ -808,26 +932,6 @@ static const upb_MiniTable_Field* decode_findfield(upb_Decoder* d,
return &item;
}
break;
case kUpb_ExtMode_IsMessageSet_ITEM:
switch (field_number) {
case _UPB_MSGSET_TYPEID: {
static upb_MiniTable_Field type_id = {
0, 0, 0, 0, TYPE_MSGSET_TYPE_ID, 0};
return &type_id;
}
case _UPB_MSGSET_MESSAGE:
if (l->fields) {
// We saw type_id previously and succeeded in looking up msg.
return l->fields;
} else {
// TODO: out of order MessageSet.
// This is a very rare case: all serializers will emit in-order
// MessageSets. To hit this case there has to be some kind of
// re-ordering proxy. We should eventually handle this case, but
// not today.
}
break;
}
}
}
@ -867,14 +971,9 @@ static const char* decode_wireval(upb_Decoder* d, const char* ptr,
return ptr + 8;
case kUpb_WireType_Delimited: {
int ndx = field->descriptortype;
uint64_t size;
if (upb_FieldMode_Get(field) == kUpb_FieldMode_Array) ndx += TYPE_COUNT;
ptr = decode_varint64(d, ptr, &size);
if (size >= INT32_MAX || ptr - d->end + (int32_t)size > d->limit) {
break; /* Length overflow. */
}
ptr = upb_Decoder_DecodeSize(d, ptr, &val->size);
*op = delim_ops[ndx];
val->size = size;
return ptr;
}
case kUpb_WireType_StartGroup:
@ -905,7 +1004,7 @@ static const char* decode_known(upb_Decoder* d, const char* ptr,
const upb_MiniTable_Extension* ext_layout =
(const upb_MiniTable_Extension*)field;
upb_Message_Extension* ext =
_upb_Message_Getorcreateext(msg, ext_layout, &d->arena);
_upb_Message_GetOrCreateExtension(msg, ext_layout, &d->arena);
if (UPB_UNLIKELY(!ext)) return decode_err(d, kUpb_DecodeStatus_OutOfMemory);
msg = &ext->data;
subs = &ext->ext->sub;
@ -1038,14 +1137,8 @@ static const char* decode_msg(upb_Decoder* d, const char* ptr, upb_Message* msg,
ptr = decode_unknown(d, ptr, msg, field_number, wire_type, val);
break;
case OP_MSGSET_ITEM:
ptr = decode_msgset(d, ptr, msg, layout);
ptr = upb_Decoder_DecodeMessageSetItem(d, ptr, msg, layout);
break;
case OP_MSGSET_TYPEID: {
const upb_MiniTable_Extension* ext = _upb_extreg_get(
d->extreg, layout->subs[0].submsg, val.uint64_val);
if (ext) ((upb_MiniTable*)layout)->fields = &ext->field;
break;
}
}
}
}

@ -153,7 +153,7 @@ void _upb_Message_Clearext(upb_Message* msg,
}
}
upb_Message_Extension* _upb_Message_Getorcreateext(
upb_Message_Extension* _upb_Message_GetOrCreateExtension(
upb_Message* msg, const upb_MiniTable_Extension* e, upb_Arena* arena) {
upb_Message_Extension* ext =
(upb_Message_Extension*)_upb_Message_Getext(msg, e);

@ -336,7 +336,7 @@ typedef struct {
/* Adds the given extension data to the given message. |ext| is copied into the
* message instance. This logically replaces any previously-added extension with
* this number */
upb_Message_Extension* _upb_Message_Getorcreateext(
upb_Message_Extension* _upb_Message_GetOrCreateExtension(
upb_Message* msg, const upb_MiniTable_Extension* ext, upb_Arena* arena);
/* Returns an array of extensions for this message. Note: the array is

@ -102,6 +102,7 @@ TEST(MessageTest, Extensions) {
}
void VerifyMessageSet(const upb_test_TestMessageSet* mset_msg) {
ASSERT_TRUE(mset_msg != nullptr);
bool has = upb_test_MessageSetMember_has_message_set_extension(mset_msg);
EXPECT_TRUE(has);
if (!has) return;
@ -160,6 +161,67 @@ TEST(MessageTest, MessageSet) {
VerifyMessageSet(ext_msg3);
}
TEST(MessageTest, UnknownMessageSet) {
static const char data[] = "ABCDE";
upb_StringView data_view = upb_StringView_FromString(data);
upb::Arena arena;
upb_test_FakeMessageSet* fake = upb_test_FakeMessageSet_new(arena.ptr());
// Add a MessageSet item that is unknown (there is no matching extension in
// the .proto file)
upb_test_FakeMessageSet_Item* item =
upb_test_FakeMessageSet_add_item(fake, arena.ptr());
upb_test_FakeMessageSet_Item_set_type_id(item, 12345);
upb_test_FakeMessageSet_Item_set_message(item, data_view);
// Set unknown fields inside the message set to test that we can skip them.
upb_test_FakeMessageSet_Item_set_unknown_varint(item, 12345678);
upb_test_FakeMessageSet_Item_set_unknown_fixed32(item, 12345678);
upb_test_FakeMessageSet_Item_set_unknown_fixed64(item, 12345678);
upb_test_FakeMessageSet_Item_set_unknown_bytes(item, data_view);
upb_test_FakeMessageSet_Item_mutable_unknowngroup(item, arena.ptr());
// Round trip through a true MessageSet where this item_id is unknown.
size_t size;
char* serialized =
upb_test_FakeMessageSet_serialize(fake, arena.ptr(), &size);
ASSERT_TRUE(serialized != nullptr);
ASSERT_GE(size, 0);
upb::SymbolTable symtab;
upb::MessageDefPtr m(upb_test_TestMessageSet_getmsgdef(symtab.ptr()));
EXPECT_TRUE(m.ptr() != nullptr);
upb_test_TestMessageSet* message_set = upb_test_TestMessageSet_parse_ex(
serialized, size, upb_DefPool_ExtensionRegistry(symtab.ptr()), 0,
arena.ptr());
ASSERT_TRUE(message_set != nullptr);
char* serialized2 =
upb_test_TestMessageSet_serialize(message_set, arena.ptr(), &size);
ASSERT_TRUE(serialized2 != nullptr);
ASSERT_GE(size, 0);
// Parse back into a fake MessageSet and verify that the unknown MessageSet
// item was preserved in full (both type_id and message).
upb_test_FakeMessageSet* fake2 =
upb_test_FakeMessageSet_parse(serialized2, size, arena.ptr());
ASSERT_TRUE(fake2 != nullptr);
const upb_test_FakeMessageSet_Item* const* items =
upb_test_FakeMessageSet_item(fake2, &size);
ASSERT_EQ(1, size);
EXPECT_EQ(12345, upb_test_FakeMessageSet_Item_type_id(items[0]));
EXPECT_TRUE(upb_StringView_IsEqual(
data_view, upb_test_FakeMessageSet_Item_message(items[0])));
// The non-MessageSet unknown fields should have been discarded.
EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_varint(items[0]));
EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_fixed32(items[0]));
EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_fixed64(items[0]));
EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_bytes(items[0]));
EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknowngroup(items[0]));
}
TEST(MessageTest, Proto2Enum) {
upb::Arena arena;
upb_test_Proto2FakeEnumMessage* fake_msg =

@ -25,6 +25,8 @@
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
// LINT: ALLOW_GROUPS
syntax = "proto2";
package upb_test;
@ -61,6 +63,18 @@ message MessageSetMember {
}
}
message FakeMessageSet {
repeated group Item = 1 {
optional int32 type_id = 2;
optional bytes message = 3;
optional int32 unknown_varint = 4;
optional fixed32 unknown_fixed32 = 5;
optional fixed64 unknown_fixed64 = 6;
optional bytes unknown_bytes = 7;
optional group UnknownGroup = 8 {}
}
}
message Proto2EnumMessage {
enum Proto2TestEnum {
ZERO = 0;

@ -202,7 +202,7 @@ make:
bool upb_Message_Set(upb_Message* msg, const upb_FieldDef* f,
upb_MessageValue val, upb_Arena* a) {
if (upb_FieldDef_IsExtension(f)) {
upb_Message_Extension* ext = _upb_Message_Getorcreateext(
upb_Message_Extension* ext = _upb_Message_GetOrCreateExtension(
msg, _upb_FieldDef_ExtensionMiniTable(f), a);
if (!ext) return false;
memcpy(&ext->data, &val, sizeof(val));

@ -734,7 +734,7 @@ void GenerateExtensionInHeader(const protobuf::FieldDescriptor* ext,
R"cc(
UPB_INLINE void $1_set_$2(struct $3* msg, $0 ext, upb_Arena* arena) {
const upb_Message_Extension* msg_ext =
_upb_Message_Getorcreateext(msg, &$4, arena);
_upb_Message_GetOrCreateExtension(msg, &$4, arena);
UPB_ASSERT(msg_ext);
*UPB_PTR_AT(&msg_ext->data, 0, $0) = ext;
}

Loading…
Cancel
Save