From 9cc02bb60d01aae03326c4a81b3ca921ba89e94c Mon Sep 17 00:00:00 2001 From: Joshua Haberman Date: Sun, 10 Apr 2022 21:22:33 -0700 Subject: [PATCH] Rewrote the MessageSet parsing code in the upb decoder to properly handle several edge cases. PiperOrigin-RevId: 440788402 --- upb/decode.c | 217 +++++++++++++++++++++++++++++------------ upb/msg.c | 2 +- upb/msg_internal.h | 2 +- upb/msg_test.cc | 62 ++++++++++++ upb/msg_test.proto | 14 +++ upb/reflection.c | 2 +- upbc/protoc-gen-upb.cc | 2 +- 7 files changed, 235 insertions(+), 66 deletions(-) diff --git a/upb/decode.c b/upb/decode.c index 2d5bb49525..0985abc965 100644 --- a/upb/decode.c +++ b/upb/decode.c @@ -93,13 +93,11 @@ static const unsigned FIXED64_OK_MASK = (1 << kUpb_FieldType_Double) | /* Three fake field types for MessageSet. */ #define TYPE_MSGSET_ITEM 19 -#define TYPE_MSGSET_TYPE_ID 20 -#define TYPE_COUNT 20 +#define TYPE_COUNT 19 /* Op: an action to be performed for a wire-type/field-type combination. */ #define OP_UNKNOWN -1 /* Unknown field. */ #define OP_MSGSET_ITEM -2 -#define OP_MSGSET_TYPEID -3 #define OP_SCALAR_LG2(n) (n) /* n in [0, 2, 3] => op in [0, 2, 3] */ #define OP_ENUM 1 #define OP_STRING 4 @@ -131,7 +129,6 @@ static const int8_t varint_ops[] = { OP_SCALAR_LG2(2), /* SINT32 */ OP_SCALAR_LG2(3), /* SINT64 */ OP_UNKNOWN, /* MSGSET_ITEM */ - OP_MSGSET_TYPEID, /* MSGSET TYPEID */ }; static const int8_t delim_ops[] = { @@ -156,7 +153,6 @@ static const int8_t delim_ops[] = { OP_UNKNOWN, /* SINT32 */ OP_UNKNOWN, /* SINT64 */ OP_UNKNOWN, /* MSGSET_ITEM */ - OP_UNKNOWN, /* MSGSET TYPEID */ /* For repeated field type. */ OP_FIXPCK_LG2(3), /* REPEATED DOUBLE */ OP_FIXPCK_LG2(2), /* REPEATED FLOAT */ @@ -266,6 +262,18 @@ static const char* decode_tag(upb_Decoder* d, const char* ptr, uint32_t* val) { } } +UPB_FORCEINLINE +static const char* upb_Decoder_DecodeSize(upb_Decoder* d, const char* ptr, + uint32_t* size) { + uint64_t size64; + ptr = decode_varint64(d, ptr, &size64); + if (size64 >= INT32_MAX || ptr - d->end + (int)size64 > d->limit) { + decode_err(d, kUpb_DecodeStatus_Malformed); + } + *size = size64; + return ptr; +} + static void decode_munge_int32(wireval* val) { if (!_upb_IsLittleEndian()) { /* The next stage will memcpy(dst, &val, 4) */ @@ -300,7 +308,9 @@ static upb_Message* decode_newsubmsg(upb_Decoder* d, const upb_MiniTable_Sub* subs, const upb_MiniTable_Field* field) { const upb_MiniTable* subl = subs[field->submsg_index].submsg; - return _upb_Message_New_inl(subl, &d->arena); + upb_Message* msg = _upb_Message_New_inl(subl, &d->arena); + if (!msg) decode_err(d, kUpb_DecodeStatus_OutOfMemory); + return msg; } UPB_NOINLINE @@ -375,7 +385,7 @@ static const char* decode_togroup(upb_Decoder* d, const char* ptr, return decode_group(d, ptr, submsg, subl, field->number); } -static char* encode_varint32(uint32_t val, char* ptr) { +static char* upb_Decoder_EncodeVarint32(uint32_t val, char* ptr) { do { uint8_t byte = val & 0x7fU; val >>= 7; @@ -389,8 +399,8 @@ static void upb_Decode_AddUnknownVarints(upb_Decoder* d, upb_Message* msg, uint32_t val1, uint32_t val2) { char buf[20]; char* end = buf; - end = encode_varint32(val1, end); - end = encode_varint32(val2, end); + end = upb_Decoder_EncodeVarint32(val1, end); + end = upb_Decoder_EncodeVarint32(val2, end); if (!_upb_Message_AddUnknown(msg, buf, end - buf, &d->arena)) { decode_err(d, kUpb_DecodeStatus_OutOfMemory); @@ -743,25 +753,139 @@ static bool decode_tryfastdispatch(upb_Decoder* d, const char** ptr, return false; } -static const char* decode_msgset(upb_Decoder* d, const char* ptr, - upb_Message* msg, - const upb_MiniTable* layout) { - // We create a temporary upb_MiniTable here and abuse its fields as temporary - // storage, to avoid creating lots of MessageSet-specific parsing code-paths: - // 1. We store 'layout' in item_layout.subs. We will need this later as - // a key to look up extensions for this MessageSet. - // 2. We use item_layout.fields as temporary storage to store the extension - // we - // found when parsing the type id. - upb_MiniTable item_layout = { - .subs = (const upb_MiniTable_Sub[]){{.submsg = layout}}, - .fields = NULL, - .size = 0, - .field_count = 0, - .ext = kUpb_ExtMode_IsMessageSet_ITEM, - .dense_below = 0, - .table_mask = -1}; - return decode_group(d, ptr, msg, &item_layout, 1); +static const char* upb_Decoder_SkipField(upb_Decoder* d, const char* ptr, + uint32_t tag) { + int field_number = tag >> 3; + int wire_type = tag & 7; + switch (wire_type) { + case kUpb_WireType_Varint: { + uint64_t val; + return decode_varint64(d, ptr, &val); + } + case kUpb_WireType_64Bit: + return ptr + 8; + case kUpb_WireType_32Bit: + return ptr + 4; + case kUpb_WireType_Delimited: { + uint32_t size; + ptr = upb_Decoder_DecodeSize(d, ptr, &size); + return ptr + size; + } + case kUpb_WireType_StartGroup: + return decode_group(d, ptr, NULL, NULL, field_number); + default: + decode_err(d, kUpb_DecodeStatus_Malformed); + } +} + +enum { + kStartItemTag = ((1 << 3) | kUpb_WireType_StartGroup), + kEndItemTag = ((1 << 3) | kUpb_WireType_EndGroup), + kTypeIdTag = ((2 << 3) | kUpb_WireType_Varint), + kMessageTag = ((3 << 3) | kUpb_WireType_Delimited), +}; + +static void upb_Decoder_AddKnownMessageSetItem( + upb_Decoder* d, upb_Message* msg, const upb_MiniTable_Extension* item_mt, + const char* data, uint32_t size) { + upb_Message_Extension* ext = + _upb_Message_GetOrCreateExtension(msg, item_mt, &d->arena); + if (UPB_UNLIKELY(!ext)) decode_err(d, kUpb_DecodeStatus_OutOfMemory); + upb_Message* submsg = decode_newsubmsg(d, &ext->ext->sub, &ext->ext->field); + upb_DecodeStatus status = upb_Decode(data, size, submsg, item_mt->sub.submsg, + d->extreg, d->options, &d->arena); + memcpy(&ext->data, &submsg, sizeof(submsg)); + if (status != kUpb_DecodeStatus_Ok) decode_err(d, status); +} + +static void upb_Decoder_AddUnknownMessageSetItem(upb_Decoder* d, + upb_Message* msg, + uint32_t type_id, + const char* message_data, + uint32_t message_size) { + char buf[60]; + char* ptr = buf; + ptr = upb_Decoder_EncodeVarint32(kStartItemTag, ptr); + ptr = upb_Decoder_EncodeVarint32(kTypeIdTag, ptr); + ptr = upb_Decoder_EncodeVarint32(type_id, ptr); + ptr = upb_Decoder_EncodeVarint32(kMessageTag, ptr); + ptr = upb_Decoder_EncodeVarint32(message_size, ptr); + char* split = ptr; + + ptr = upb_Decoder_EncodeVarint32(kEndItemTag, ptr); + char* end = ptr; + + if (!_upb_Message_AddUnknown(msg, buf, split - buf, &d->arena) || + !_upb_Message_AddUnknown(msg, message_data, message_size, &d->arena) || + !_upb_Message_AddUnknown(msg, split, end - split, &d->arena)) { + decode_err(d, kUpb_DecodeStatus_OutOfMemory); + } +} + +static void upb_Decoder_AddMessageSetItem(upb_Decoder* d, upb_Message* msg, + const upb_MiniTable* layout, + uint32_t type_id, const char* data, + uint32_t size) { + const upb_MiniTable_Extension* item_mt = + _upb_extreg_get(d->extreg, layout, type_id); + if (item_mt) { + upb_Decoder_AddKnownMessageSetItem(d, msg, item_mt, data, size); + } else { + upb_Decoder_AddUnknownMessageSetItem(d, msg, type_id, data, size); + } +} + +static const char* upb_Decoder_DecodeMessageSetItem( + upb_Decoder* d, const char* ptr, upb_Message* msg, + const upb_MiniTable* layout) { + uint32_t type_id = 0; + upb_StringView preserved = {NULL, 0}; + typedef enum { + kUpb_HaveId = 1 << 0, + kUpb_HavePayload = 1 << 1, + } StateMask; + StateMask state_mask = 0; + while (!decode_isdone(d, &ptr)) { + uint32_t tag; + ptr = decode_tag(d, ptr, &tag); + switch (tag) { + case kEndItemTag: + return ptr; + case kTypeIdTag: { + uint64_t tmp; + ptr = decode_varint64(d, ptr, &tmp); + if (state_mask & kUpb_HaveId) break; // Ignore dup. + state_mask |= kUpb_HaveId; + type_id = tmp; + if (state_mask & kUpb_HavePayload) { + upb_Decoder_AddMessageSetItem(d, msg, layout, type_id, preserved.data, + preserved.size); + } + break; + } + case kMessageTag: { + uint32_t size; + ptr = upb_Decoder_DecodeSize(d, ptr, &size); + const char* data = ptr; + ptr += size; + if (state_mask & kUpb_HavePayload) break; // Ignore dup. + state_mask |= kUpb_HavePayload; + if (state_mask & kUpb_HaveId) { + upb_Decoder_AddMessageSetItem(d, msg, layout, type_id, data, size); + } else { + // Out of order, we must preserve the payload. + preserved.data = data; + preserved.size = size; + } + break; + } + default: + // We do not preserve unexpected fields inside a message set item. + ptr = upb_Decoder_SkipField(d, ptr, tag); + break; + } + } + decode_err(d, kUpb_DecodeStatus_Malformed); } static const upb_MiniTable_Field* decode_findfield(upb_Decoder* d, @@ -808,26 +932,6 @@ static const upb_MiniTable_Field* decode_findfield(upb_Decoder* d, return &item; } break; - case kUpb_ExtMode_IsMessageSet_ITEM: - switch (field_number) { - case _UPB_MSGSET_TYPEID: { - static upb_MiniTable_Field type_id = { - 0, 0, 0, 0, TYPE_MSGSET_TYPE_ID, 0}; - return &type_id; - } - case _UPB_MSGSET_MESSAGE: - if (l->fields) { - // We saw type_id previously and succeeded in looking up msg. - return l->fields; - } else { - // TODO: out of order MessageSet. - // This is a very rare case: all serializers will emit in-order - // MessageSets. To hit this case there has to be some kind of - // re-ordering proxy. We should eventually handle this case, but - // not today. - } - break; - } } } @@ -867,14 +971,9 @@ static const char* decode_wireval(upb_Decoder* d, const char* ptr, return ptr + 8; case kUpb_WireType_Delimited: { int ndx = field->descriptortype; - uint64_t size; if (upb_FieldMode_Get(field) == kUpb_FieldMode_Array) ndx += TYPE_COUNT; - ptr = decode_varint64(d, ptr, &size); - if (size >= INT32_MAX || ptr - d->end + (int32_t)size > d->limit) { - break; /* Length overflow. */ - } + ptr = upb_Decoder_DecodeSize(d, ptr, &val->size); *op = delim_ops[ndx]; - val->size = size; return ptr; } case kUpb_WireType_StartGroup: @@ -905,7 +1004,7 @@ static const char* decode_known(upb_Decoder* d, const char* ptr, const upb_MiniTable_Extension* ext_layout = (const upb_MiniTable_Extension*)field; upb_Message_Extension* ext = - _upb_Message_Getorcreateext(msg, ext_layout, &d->arena); + _upb_Message_GetOrCreateExtension(msg, ext_layout, &d->arena); if (UPB_UNLIKELY(!ext)) return decode_err(d, kUpb_DecodeStatus_OutOfMemory); msg = &ext->data; subs = &ext->ext->sub; @@ -1038,14 +1137,8 @@ static const char* decode_msg(upb_Decoder* d, const char* ptr, upb_Message* msg, ptr = decode_unknown(d, ptr, msg, field_number, wire_type, val); break; case OP_MSGSET_ITEM: - ptr = decode_msgset(d, ptr, msg, layout); + ptr = upb_Decoder_DecodeMessageSetItem(d, ptr, msg, layout); break; - case OP_MSGSET_TYPEID: { - const upb_MiniTable_Extension* ext = _upb_extreg_get( - d->extreg, layout->subs[0].submsg, val.uint64_val); - if (ext) ((upb_MiniTable*)layout)->fields = &ext->field; - break; - } } } } diff --git a/upb/msg.c b/upb/msg.c index b8a629467d..46738ca1fb 100644 --- a/upb/msg.c +++ b/upb/msg.c @@ -153,7 +153,7 @@ void _upb_Message_Clearext(upb_Message* msg, } } -upb_Message_Extension* _upb_Message_Getorcreateext( +upb_Message_Extension* _upb_Message_GetOrCreateExtension( upb_Message* msg, const upb_MiniTable_Extension* e, upb_Arena* arena) { upb_Message_Extension* ext = (upb_Message_Extension*)_upb_Message_Getext(msg, e); diff --git a/upb/msg_internal.h b/upb/msg_internal.h index 39adfdb9d4..4c1321f842 100644 --- a/upb/msg_internal.h +++ b/upb/msg_internal.h @@ -336,7 +336,7 @@ typedef struct { /* Adds the given extension data to the given message. |ext| is copied into the * message instance. This logically replaces any previously-added extension with * this number */ -upb_Message_Extension* _upb_Message_Getorcreateext( +upb_Message_Extension* _upb_Message_GetOrCreateExtension( upb_Message* msg, const upb_MiniTable_Extension* ext, upb_Arena* arena); /* Returns an array of extensions for this message. Note: the array is diff --git a/upb/msg_test.cc b/upb/msg_test.cc index 2d1f8e991b..2f33b4dd88 100644 --- a/upb/msg_test.cc +++ b/upb/msg_test.cc @@ -102,6 +102,7 @@ TEST(MessageTest, Extensions) { } void VerifyMessageSet(const upb_test_TestMessageSet* mset_msg) { + ASSERT_TRUE(mset_msg != nullptr); bool has = upb_test_MessageSetMember_has_message_set_extension(mset_msg); EXPECT_TRUE(has); if (!has) return; @@ -160,6 +161,67 @@ TEST(MessageTest, MessageSet) { VerifyMessageSet(ext_msg3); } +TEST(MessageTest, UnknownMessageSet) { + static const char data[] = "ABCDE"; + upb_StringView data_view = upb_StringView_FromString(data); + upb::Arena arena; + upb_test_FakeMessageSet* fake = upb_test_FakeMessageSet_new(arena.ptr()); + + // Add a MessageSet item that is unknown (there is no matching extension in + // the .proto file) + upb_test_FakeMessageSet_Item* item = + upb_test_FakeMessageSet_add_item(fake, arena.ptr()); + upb_test_FakeMessageSet_Item_set_type_id(item, 12345); + upb_test_FakeMessageSet_Item_set_message(item, data_view); + + // Set unknown fields inside the message set to test that we can skip them. + upb_test_FakeMessageSet_Item_set_unknown_varint(item, 12345678); + upb_test_FakeMessageSet_Item_set_unknown_fixed32(item, 12345678); + upb_test_FakeMessageSet_Item_set_unknown_fixed64(item, 12345678); + upb_test_FakeMessageSet_Item_set_unknown_bytes(item, data_view); + upb_test_FakeMessageSet_Item_mutable_unknowngroup(item, arena.ptr()); + + // Round trip through a true MessageSet where this item_id is unknown. + size_t size; + char* serialized = + upb_test_FakeMessageSet_serialize(fake, arena.ptr(), &size); + ASSERT_TRUE(serialized != nullptr); + ASSERT_GE(size, 0); + + upb::SymbolTable symtab; + upb::MessageDefPtr m(upb_test_TestMessageSet_getmsgdef(symtab.ptr())); + EXPECT_TRUE(m.ptr() != nullptr); + upb_test_TestMessageSet* message_set = upb_test_TestMessageSet_parse_ex( + serialized, size, upb_DefPool_ExtensionRegistry(symtab.ptr()), 0, + arena.ptr()); + ASSERT_TRUE(message_set != nullptr); + + char* serialized2 = + upb_test_TestMessageSet_serialize(message_set, arena.ptr(), &size); + ASSERT_TRUE(serialized2 != nullptr); + ASSERT_GE(size, 0); + + // Parse back into a fake MessageSet and verify that the unknown MessageSet + // item was preserved in full (both type_id and message). + upb_test_FakeMessageSet* fake2 = + upb_test_FakeMessageSet_parse(serialized2, size, arena.ptr()); + ASSERT_TRUE(fake2 != nullptr); + + const upb_test_FakeMessageSet_Item* const* items = + upb_test_FakeMessageSet_item(fake2, &size); + ASSERT_EQ(1, size); + EXPECT_EQ(12345, upb_test_FakeMessageSet_Item_type_id(items[0])); + EXPECT_TRUE(upb_StringView_IsEqual( + data_view, upb_test_FakeMessageSet_Item_message(items[0]))); + + // The non-MessageSet unknown fields should have been discarded. + EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_varint(items[0])); + EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_fixed32(items[0])); + EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_fixed64(items[0])); + EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_bytes(items[0])); + EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknowngroup(items[0])); +} + TEST(MessageTest, Proto2Enum) { upb::Arena arena; upb_test_Proto2FakeEnumMessage* fake_msg = diff --git a/upb/msg_test.proto b/upb/msg_test.proto index 1cdd84a3c2..011fb82e6c 100644 --- a/upb/msg_test.proto +++ b/upb/msg_test.proto @@ -25,6 +25,8 @@ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ +// LINT: ALLOW_GROUPS + syntax = "proto2"; package upb_test; @@ -61,6 +63,18 @@ message MessageSetMember { } } +message FakeMessageSet { + repeated group Item = 1 { + optional int32 type_id = 2; + optional bytes message = 3; + optional int32 unknown_varint = 4; + optional fixed32 unknown_fixed32 = 5; + optional fixed64 unknown_fixed64 = 6; + optional bytes unknown_bytes = 7; + optional group UnknownGroup = 8 {} + } +} + message Proto2EnumMessage { enum Proto2TestEnum { ZERO = 0; diff --git a/upb/reflection.c b/upb/reflection.c index a3a64d2780..31c487b230 100644 --- a/upb/reflection.c +++ b/upb/reflection.c @@ -202,7 +202,7 @@ make: bool upb_Message_Set(upb_Message* msg, const upb_FieldDef* f, upb_MessageValue val, upb_Arena* a) { if (upb_FieldDef_IsExtension(f)) { - upb_Message_Extension* ext = _upb_Message_Getorcreateext( + upb_Message_Extension* ext = _upb_Message_GetOrCreateExtension( msg, _upb_FieldDef_ExtensionMiniTable(f), a); if (!ext) return false; memcpy(&ext->data, &val, sizeof(val)); diff --git a/upbc/protoc-gen-upb.cc b/upbc/protoc-gen-upb.cc index bd6196c711..5a176bb363 100644 --- a/upbc/protoc-gen-upb.cc +++ b/upbc/protoc-gen-upb.cc @@ -734,7 +734,7 @@ void GenerateExtensionInHeader(const protobuf::FieldDescriptor* ext, R"cc( UPB_INLINE void $1_set_$2(struct $3* msg, $0 ext, upb_Arena* arena) { const upb_Message_Extension* msg_ext = - _upb_Message_Getorcreateext(msg, &$4, arena); + _upb_Message_GetOrCreateExtension(msg, &$4, arena); UPB_ASSERT(msg_ext); *UPB_PTR_AT(&msg_ext->data, 0, $0) = ext; }