From e1e7435e70ac631f6f43ec616b0bf07e6bf39b3f Mon Sep 17 00:00:00 2001 From: Protobuf Team Date: Sun, 10 Apr 2022 22:37:30 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 440796832 --- upb/decode.c | 217 ++++++++++++----------------------------- upb/msg.c | 2 +- upb/msg_internal.h | 2 +- upb/msg_test.cc | 62 ------------ upb/msg_test.proto | 14 --- upb/reflection.c | 2 +- upbc/protoc-gen-upb.cc | 2 +- 7 files changed, 66 insertions(+), 235 deletions(-) diff --git a/upb/decode.c b/upb/decode.c index 0985abc965..2d5bb49525 100644 --- a/upb/decode.c +++ b/upb/decode.c @@ -93,11 +93,13 @@ static const unsigned FIXED64_OK_MASK = (1 << kUpb_FieldType_Double) | /* Three fake field types for MessageSet. */ #define TYPE_MSGSET_ITEM 19 -#define TYPE_COUNT 19 +#define TYPE_MSGSET_TYPE_ID 20 +#define TYPE_COUNT 20 /* Op: an action to be performed for a wire-type/field-type combination. */ #define OP_UNKNOWN -1 /* Unknown field. */ #define OP_MSGSET_ITEM -2 +#define OP_MSGSET_TYPEID -3 #define OP_SCALAR_LG2(n) (n) /* n in [0, 2, 3] => op in [0, 2, 3] */ #define OP_ENUM 1 #define OP_STRING 4 @@ -129,6 +131,7 @@ static const int8_t varint_ops[] = { OP_SCALAR_LG2(2), /* SINT32 */ OP_SCALAR_LG2(3), /* SINT64 */ OP_UNKNOWN, /* MSGSET_ITEM */ + OP_MSGSET_TYPEID, /* MSGSET TYPEID */ }; static const int8_t delim_ops[] = { @@ -153,6 +156,7 @@ static const int8_t delim_ops[] = { OP_UNKNOWN, /* SINT32 */ OP_UNKNOWN, /* SINT64 */ OP_UNKNOWN, /* MSGSET_ITEM */ + OP_UNKNOWN, /* MSGSET TYPEID */ /* For repeated field type. */ OP_FIXPCK_LG2(3), /* REPEATED DOUBLE */ OP_FIXPCK_LG2(2), /* REPEATED FLOAT */ @@ -262,18 +266,6 @@ static const char* decode_tag(upb_Decoder* d, const char* ptr, uint32_t* val) { } } -UPB_FORCEINLINE -static const char* upb_Decoder_DecodeSize(upb_Decoder* d, const char* ptr, - uint32_t* size) { - uint64_t size64; - ptr = decode_varint64(d, ptr, &size64); - if (size64 >= INT32_MAX || ptr - d->end + (int)size64 > d->limit) { - decode_err(d, kUpb_DecodeStatus_Malformed); - } - *size = size64; - return ptr; -} - static void decode_munge_int32(wireval* val) { if (!_upb_IsLittleEndian()) { /* The next stage will memcpy(dst, &val, 4) */ @@ -308,9 +300,7 @@ static upb_Message* decode_newsubmsg(upb_Decoder* d, const upb_MiniTable_Sub* subs, const upb_MiniTable_Field* field) { const upb_MiniTable* subl = subs[field->submsg_index].submsg; - upb_Message* msg = _upb_Message_New_inl(subl, &d->arena); - if (!msg) decode_err(d, kUpb_DecodeStatus_OutOfMemory); - return msg; + return _upb_Message_New_inl(subl, &d->arena); } UPB_NOINLINE @@ -385,7 +375,7 @@ static const char* decode_togroup(upb_Decoder* d, const char* ptr, return decode_group(d, ptr, submsg, subl, field->number); } -static char* upb_Decoder_EncodeVarint32(uint32_t val, char* ptr) { +static char* encode_varint32(uint32_t val, char* ptr) { do { uint8_t byte = val & 0x7fU; val >>= 7; @@ -399,8 +389,8 @@ static void upb_Decode_AddUnknownVarints(upb_Decoder* d, upb_Message* msg, uint32_t val1, uint32_t val2) { char buf[20]; char* end = buf; - end = upb_Decoder_EncodeVarint32(val1, end); - end = upb_Decoder_EncodeVarint32(val2, end); + end = encode_varint32(val1, end); + end = encode_varint32(val2, end); if (!_upb_Message_AddUnknown(msg, buf, end - buf, &d->arena)) { decode_err(d, kUpb_DecodeStatus_OutOfMemory); @@ -753,139 +743,25 @@ static bool decode_tryfastdispatch(upb_Decoder* d, const char** ptr, return false; } -static const char* upb_Decoder_SkipField(upb_Decoder* d, const char* ptr, - uint32_t tag) { - int field_number = tag >> 3; - int wire_type = tag & 7; - switch (wire_type) { - case kUpb_WireType_Varint: { - uint64_t val; - return decode_varint64(d, ptr, &val); - } - case kUpb_WireType_64Bit: - return ptr + 8; - case kUpb_WireType_32Bit: - return ptr + 4; - case kUpb_WireType_Delimited: { - uint32_t size; - ptr = upb_Decoder_DecodeSize(d, ptr, &size); - return ptr + size; - } - case kUpb_WireType_StartGroup: - return decode_group(d, ptr, NULL, NULL, field_number); - default: - decode_err(d, kUpb_DecodeStatus_Malformed); - } -} - -enum { - kStartItemTag = ((1 << 3) | kUpb_WireType_StartGroup), - kEndItemTag = ((1 << 3) | kUpb_WireType_EndGroup), - kTypeIdTag = ((2 << 3) | kUpb_WireType_Varint), - kMessageTag = ((3 << 3) | kUpb_WireType_Delimited), -}; - -static void upb_Decoder_AddKnownMessageSetItem( - upb_Decoder* d, upb_Message* msg, const upb_MiniTable_Extension* item_mt, - const char* data, uint32_t size) { - upb_Message_Extension* ext = - _upb_Message_GetOrCreateExtension(msg, item_mt, &d->arena); - if (UPB_UNLIKELY(!ext)) decode_err(d, kUpb_DecodeStatus_OutOfMemory); - upb_Message* submsg = decode_newsubmsg(d, &ext->ext->sub, &ext->ext->field); - upb_DecodeStatus status = upb_Decode(data, size, submsg, item_mt->sub.submsg, - d->extreg, d->options, &d->arena); - memcpy(&ext->data, &submsg, sizeof(submsg)); - if (status != kUpb_DecodeStatus_Ok) decode_err(d, status); -} - -static void upb_Decoder_AddUnknownMessageSetItem(upb_Decoder* d, - upb_Message* msg, - uint32_t type_id, - const char* message_data, - uint32_t message_size) { - char buf[60]; - char* ptr = buf; - ptr = upb_Decoder_EncodeVarint32(kStartItemTag, ptr); - ptr = upb_Decoder_EncodeVarint32(kTypeIdTag, ptr); - ptr = upb_Decoder_EncodeVarint32(type_id, ptr); - ptr = upb_Decoder_EncodeVarint32(kMessageTag, ptr); - ptr = upb_Decoder_EncodeVarint32(message_size, ptr); - char* split = ptr; - - ptr = upb_Decoder_EncodeVarint32(kEndItemTag, ptr); - char* end = ptr; - - if (!_upb_Message_AddUnknown(msg, buf, split - buf, &d->arena) || - !_upb_Message_AddUnknown(msg, message_data, message_size, &d->arena) || - !_upb_Message_AddUnknown(msg, split, end - split, &d->arena)) { - decode_err(d, kUpb_DecodeStatus_OutOfMemory); - } -} - -static void upb_Decoder_AddMessageSetItem(upb_Decoder* d, upb_Message* msg, - const upb_MiniTable* layout, - uint32_t type_id, const char* data, - uint32_t size) { - const upb_MiniTable_Extension* item_mt = - _upb_extreg_get(d->extreg, layout, type_id); - if (item_mt) { - upb_Decoder_AddKnownMessageSetItem(d, msg, item_mt, data, size); - } else { - upb_Decoder_AddUnknownMessageSetItem(d, msg, type_id, data, size); - } -} - -static const char* upb_Decoder_DecodeMessageSetItem( - upb_Decoder* d, const char* ptr, upb_Message* msg, - const upb_MiniTable* layout) { - uint32_t type_id = 0; - upb_StringView preserved = {NULL, 0}; - typedef enum { - kUpb_HaveId = 1 << 0, - kUpb_HavePayload = 1 << 1, - } StateMask; - StateMask state_mask = 0; - while (!decode_isdone(d, &ptr)) { - uint32_t tag; - ptr = decode_tag(d, ptr, &tag); - switch (tag) { - case kEndItemTag: - return ptr; - case kTypeIdTag: { - uint64_t tmp; - ptr = decode_varint64(d, ptr, &tmp); - if (state_mask & kUpb_HaveId) break; // Ignore dup. - state_mask |= kUpb_HaveId; - type_id = tmp; - if (state_mask & kUpb_HavePayload) { - upb_Decoder_AddMessageSetItem(d, msg, layout, type_id, preserved.data, - preserved.size); - } - break; - } - case kMessageTag: { - uint32_t size; - ptr = upb_Decoder_DecodeSize(d, ptr, &size); - const char* data = ptr; - ptr += size; - if (state_mask & kUpb_HavePayload) break; // Ignore dup. - state_mask |= kUpb_HavePayload; - if (state_mask & kUpb_HaveId) { - upb_Decoder_AddMessageSetItem(d, msg, layout, type_id, data, size); - } else { - // Out of order, we must preserve the payload. - preserved.data = data; - preserved.size = size; - } - break; - } - default: - // We do not preserve unexpected fields inside a message set item. - ptr = upb_Decoder_SkipField(d, ptr, tag); - break; - } - } - decode_err(d, kUpb_DecodeStatus_Malformed); +static const char* decode_msgset(upb_Decoder* d, const char* ptr, + upb_Message* msg, + const upb_MiniTable* layout) { + // We create a temporary upb_MiniTable here and abuse its fields as temporary + // storage, to avoid creating lots of MessageSet-specific parsing code-paths: + // 1. We store 'layout' in item_layout.subs. We will need this later as + // a key to look up extensions for this MessageSet. + // 2. We use item_layout.fields as temporary storage to store the extension + // we + // found when parsing the type id. + upb_MiniTable item_layout = { + .subs = (const upb_MiniTable_Sub[]){{.submsg = layout}}, + .fields = NULL, + .size = 0, + .field_count = 0, + .ext = kUpb_ExtMode_IsMessageSet_ITEM, + .dense_below = 0, + .table_mask = -1}; + return decode_group(d, ptr, msg, &item_layout, 1); } static const upb_MiniTable_Field* decode_findfield(upb_Decoder* d, @@ -932,6 +808,26 @@ static const upb_MiniTable_Field* decode_findfield(upb_Decoder* d, return &item; } break; + case kUpb_ExtMode_IsMessageSet_ITEM: + switch (field_number) { + case _UPB_MSGSET_TYPEID: { + static upb_MiniTable_Field type_id = { + 0, 0, 0, 0, TYPE_MSGSET_TYPE_ID, 0}; + return &type_id; + } + case _UPB_MSGSET_MESSAGE: + if (l->fields) { + // We saw type_id previously and succeeded in looking up msg. + return l->fields; + } else { + // TODO: out of order MessageSet. + // This is a very rare case: all serializers will emit in-order + // MessageSets. To hit this case there has to be some kind of + // re-ordering proxy. We should eventually handle this case, but + // not today. + } + break; + } } } @@ -971,9 +867,14 @@ static const char* decode_wireval(upb_Decoder* d, const char* ptr, return ptr + 8; case kUpb_WireType_Delimited: { int ndx = field->descriptortype; + uint64_t size; if (upb_FieldMode_Get(field) == kUpb_FieldMode_Array) ndx += TYPE_COUNT; - ptr = upb_Decoder_DecodeSize(d, ptr, &val->size); + ptr = decode_varint64(d, ptr, &size); + if (size >= INT32_MAX || ptr - d->end + (int32_t)size > d->limit) { + break; /* Length overflow. */ + } *op = delim_ops[ndx]; + val->size = size; return ptr; } case kUpb_WireType_StartGroup: @@ -1004,7 +905,7 @@ static const char* decode_known(upb_Decoder* d, const char* ptr, const upb_MiniTable_Extension* ext_layout = (const upb_MiniTable_Extension*)field; upb_Message_Extension* ext = - _upb_Message_GetOrCreateExtension(msg, ext_layout, &d->arena); + _upb_Message_Getorcreateext(msg, ext_layout, &d->arena); if (UPB_UNLIKELY(!ext)) return decode_err(d, kUpb_DecodeStatus_OutOfMemory); msg = &ext->data; subs = &ext->ext->sub; @@ -1137,8 +1038,14 @@ static const char* decode_msg(upb_Decoder* d, const char* ptr, upb_Message* msg, ptr = decode_unknown(d, ptr, msg, field_number, wire_type, val); break; case OP_MSGSET_ITEM: - ptr = upb_Decoder_DecodeMessageSetItem(d, ptr, msg, layout); + ptr = decode_msgset(d, ptr, msg, layout); break; + case OP_MSGSET_TYPEID: { + const upb_MiniTable_Extension* ext = _upb_extreg_get( + d->extreg, layout->subs[0].submsg, val.uint64_val); + if (ext) ((upb_MiniTable*)layout)->fields = &ext->field; + break; + } } } } diff --git a/upb/msg.c b/upb/msg.c index 46738ca1fb..b8a629467d 100644 --- a/upb/msg.c +++ b/upb/msg.c @@ -153,7 +153,7 @@ void _upb_Message_Clearext(upb_Message* msg, } } -upb_Message_Extension* _upb_Message_GetOrCreateExtension( +upb_Message_Extension* _upb_Message_Getorcreateext( upb_Message* msg, const upb_MiniTable_Extension* e, upb_Arena* arena) { upb_Message_Extension* ext = (upb_Message_Extension*)_upb_Message_Getext(msg, e); diff --git a/upb/msg_internal.h b/upb/msg_internal.h index 4c1321f842..39adfdb9d4 100644 --- a/upb/msg_internal.h +++ b/upb/msg_internal.h @@ -336,7 +336,7 @@ typedef struct { /* Adds the given extension data to the given message. |ext| is copied into the * message instance. This logically replaces any previously-added extension with * this number */ -upb_Message_Extension* _upb_Message_GetOrCreateExtension( +upb_Message_Extension* _upb_Message_Getorcreateext( upb_Message* msg, const upb_MiniTable_Extension* ext, upb_Arena* arena); /* Returns an array of extensions for this message. Note: the array is diff --git a/upb/msg_test.cc b/upb/msg_test.cc index 2f33b4dd88..2d1f8e991b 100644 --- a/upb/msg_test.cc +++ b/upb/msg_test.cc @@ -102,7 +102,6 @@ TEST(MessageTest, Extensions) { } void VerifyMessageSet(const upb_test_TestMessageSet* mset_msg) { - ASSERT_TRUE(mset_msg != nullptr); bool has = upb_test_MessageSetMember_has_message_set_extension(mset_msg); EXPECT_TRUE(has); if (!has) return; @@ -161,67 +160,6 @@ TEST(MessageTest, MessageSet) { VerifyMessageSet(ext_msg3); } -TEST(MessageTest, UnknownMessageSet) { - static const char data[] = "ABCDE"; - upb_StringView data_view = upb_StringView_FromString(data); - upb::Arena arena; - upb_test_FakeMessageSet* fake = upb_test_FakeMessageSet_new(arena.ptr()); - - // Add a MessageSet item that is unknown (there is no matching extension in - // the .proto file) - upb_test_FakeMessageSet_Item* item = - upb_test_FakeMessageSet_add_item(fake, arena.ptr()); - upb_test_FakeMessageSet_Item_set_type_id(item, 12345); - upb_test_FakeMessageSet_Item_set_message(item, data_view); - - // Set unknown fields inside the message set to test that we can skip them. - upb_test_FakeMessageSet_Item_set_unknown_varint(item, 12345678); - upb_test_FakeMessageSet_Item_set_unknown_fixed32(item, 12345678); - upb_test_FakeMessageSet_Item_set_unknown_fixed64(item, 12345678); - upb_test_FakeMessageSet_Item_set_unknown_bytes(item, data_view); - upb_test_FakeMessageSet_Item_mutable_unknowngroup(item, arena.ptr()); - - // Round trip through a true MessageSet where this item_id is unknown. - size_t size; - char* serialized = - upb_test_FakeMessageSet_serialize(fake, arena.ptr(), &size); - ASSERT_TRUE(serialized != nullptr); - ASSERT_GE(size, 0); - - upb::SymbolTable symtab; - upb::MessageDefPtr m(upb_test_TestMessageSet_getmsgdef(symtab.ptr())); - EXPECT_TRUE(m.ptr() != nullptr); - upb_test_TestMessageSet* message_set = upb_test_TestMessageSet_parse_ex( - serialized, size, upb_DefPool_ExtensionRegistry(symtab.ptr()), 0, - arena.ptr()); - ASSERT_TRUE(message_set != nullptr); - - char* serialized2 = - upb_test_TestMessageSet_serialize(message_set, arena.ptr(), &size); - ASSERT_TRUE(serialized2 != nullptr); - ASSERT_GE(size, 0); - - // Parse back into a fake MessageSet and verify that the unknown MessageSet - // item was preserved in full (both type_id and message). - upb_test_FakeMessageSet* fake2 = - upb_test_FakeMessageSet_parse(serialized2, size, arena.ptr()); - ASSERT_TRUE(fake2 != nullptr); - - const upb_test_FakeMessageSet_Item* const* items = - upb_test_FakeMessageSet_item(fake2, &size); - ASSERT_EQ(1, size); - EXPECT_EQ(12345, upb_test_FakeMessageSet_Item_type_id(items[0])); - EXPECT_TRUE(upb_StringView_IsEqual( - data_view, upb_test_FakeMessageSet_Item_message(items[0]))); - - // The non-MessageSet unknown fields should have been discarded. - EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_varint(items[0])); - EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_fixed32(items[0])); - EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_fixed64(items[0])); - EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknown_bytes(items[0])); - EXPECT_FALSE(upb_test_FakeMessageSet_Item_has_unknowngroup(items[0])); -} - TEST(MessageTest, Proto2Enum) { upb::Arena arena; upb_test_Proto2FakeEnumMessage* fake_msg = diff --git a/upb/msg_test.proto b/upb/msg_test.proto index 011fb82e6c..1cdd84a3c2 100644 --- a/upb/msg_test.proto +++ b/upb/msg_test.proto @@ -25,8 +25,6 @@ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -// LINT: ALLOW_GROUPS - syntax = "proto2"; package upb_test; @@ -63,18 +61,6 @@ message MessageSetMember { } } -message FakeMessageSet { - repeated group Item = 1 { - optional int32 type_id = 2; - optional bytes message = 3; - optional int32 unknown_varint = 4; - optional fixed32 unknown_fixed32 = 5; - optional fixed64 unknown_fixed64 = 6; - optional bytes unknown_bytes = 7; - optional group UnknownGroup = 8 {} - } -} - message Proto2EnumMessage { enum Proto2TestEnum { ZERO = 0; diff --git a/upb/reflection.c b/upb/reflection.c index 31c487b230..a3a64d2780 100644 --- a/upb/reflection.c +++ b/upb/reflection.c @@ -202,7 +202,7 @@ make: bool upb_Message_Set(upb_Message* msg, const upb_FieldDef* f, upb_MessageValue val, upb_Arena* a) { if (upb_FieldDef_IsExtension(f)) { - upb_Message_Extension* ext = _upb_Message_GetOrCreateExtension( + upb_Message_Extension* ext = _upb_Message_Getorcreateext( msg, _upb_FieldDef_ExtensionMiniTable(f), a); if (!ext) return false; memcpy(&ext->data, &val, sizeof(val)); diff --git a/upbc/protoc-gen-upb.cc b/upbc/protoc-gen-upb.cc index 5a176bb363..bd6196c711 100644 --- a/upbc/protoc-gen-upb.cc +++ b/upbc/protoc-gen-upb.cc @@ -734,7 +734,7 @@ void GenerateExtensionInHeader(const protobuf::FieldDescriptor* ext, R"cc( UPB_INLINE void $1_set_$2(struct $3* msg, $0 ext, upb_Arena* arena) { const upb_Message_Extension* msg_ext = - _upb_Message_GetOrCreateExtension(msg, &$4, arena); + _upb_Message_Getorcreateext(msg, &$4, arena); UPB_ASSERT(msg_ext); *UPB_PTR_AT(&msg_ext->data, 0, $0) = ext; }