diff --git a/upb/mini_table/decode.c b/upb/mini_table/decode.c index 10988f9568..cc020abbab 100644 --- a/upb/mini_table/decode.c +++ b/upb/mini_table/decode.c @@ -136,51 +136,47 @@ static const char* upb_MiniTable_DecodeBase92Varint(upb_MtDecoder* d, } } -static bool upb_MiniTable_HasSub(upb_MiniTableField* field, - uint64_t msg_modifiers) { - switch (field->UPB_PRIVATE(descriptortype)) { - case kUpb_FieldType_Message: - case kUpb_FieldType_Group: - case kUpb_FieldType_Enum: - return true; - case kUpb_FieldType_String: - if (!(msg_modifiers & kUpb_MessageModifier_ValidateUtf8)) { - field->UPB_PRIVATE(descriptortype) = kUpb_FieldType_Bytes; - field->mode |= kUpb_LabelFlags_IsAlternate; - } - return false; - default: - return false; - } -} - static bool upb_MtDecoder_FieldIsPackable(upb_MiniTableField* field) { return (field->mode & kUpb_FieldMode_Array) && upb_FieldType_IsPackable(field->UPB_PRIVATE(descriptortype)); } +typedef struct { + uint16_t submsg_count; + uint16_t subenum_count; +} upb_SubCounts; + static void upb_MiniTable_SetTypeAndSub(upb_MiniTableField* field, - upb_FieldType type, uint32_t* sub_count, + upb_FieldType type, + upb_SubCounts* sub_counts, uint64_t msg_modifiers, bool is_proto3_enum) { - field->UPB_PRIVATE(descriptortype) = type; - if (is_proto3_enum) { - UPB_ASSERT(field->UPB_PRIVATE(descriptortype) == kUpb_FieldType_Enum); - field->UPB_PRIVATE(descriptortype) = kUpb_FieldType_Int32; + UPB_ASSERT(type == kUpb_FieldType_Enum); + type = kUpb_FieldType_Int32; + field->mode |= kUpb_LabelFlags_IsAlternate; + } else if (type == kUpb_FieldType_String && + !(msg_modifiers & kUpb_MessageModifier_ValidateUtf8)) { + type = kUpb_FieldType_Bytes; field->mode |= kUpb_LabelFlags_IsAlternate; } - if (upb_MiniTable_HasSub(field, msg_modifiers)) { - field->UPB_PRIVATE(submsg_index) = sub_count ? (*sub_count)++ : 0; - } else { - field->UPB_PRIVATE(submsg_index) = kUpb_NoSub; - } + field->UPB_PRIVATE(descriptortype) = type; if (upb_MtDecoder_FieldIsPackable(field) && (msg_modifiers & kUpb_MessageModifier_DefaultIsPacked)) { field->mode |= kUpb_LabelFlags_IsPacked; } + + if (type == kUpb_FieldType_Message || type == kUpb_FieldType_Group) { + field->UPB_PRIVATE(submsg_index) = sub_counts->submsg_count++; + } else if (type == kUpb_FieldType_Enum) { + // We will need to update this later once we know the total number of + // submsg fields. + field->UPB_PRIVATE(submsg_index) = sub_counts->subenum_count++; + } else { + field->UPB_PRIVATE(submsg_index) = kUpb_NoSub; + } } static const char kUpb_EncodedToType[] = { @@ -208,7 +204,7 @@ static const char kUpb_EncodedToType[] = { static void upb_MiniTable_SetField(upb_MtDecoder* d, uint8_t ch, upb_MiniTableField* field, uint64_t msg_modifiers, - uint32_t* sub_count) { + upb_SubCounts* sub_counts) { static const char kUpb_EncodedToFieldRep[] = { [kUpb_EncodedType_Double] = kUpb_FieldRep_8Byte, [kUpb_EncodedType_Float] = kUpb_FieldRep_4Byte, @@ -255,7 +251,7 @@ static void upb_MiniTable_SetField(upb_MtDecoder* d, uint8_t ch, upb_MtDecoder_ErrorFormat(d, "Invalid field type: %d", (int)type); UPB_UNREACHABLE(); } - upb_MiniTable_SetTypeAndSub(field, kUpb_EncodedToType[type], sub_count, + upb_MiniTable_SetTypeAndSub(field, kUpb_EncodedToType[type], sub_counts, msg_modifiers, type == kUpb_EncodedType_OpenEnum); } @@ -443,18 +439,35 @@ static const char* upb_MtDecoder_ParseModifier(upb_MtDecoder* d, return ptr; } -static void upb_MtDecoder_AllocateSubs(upb_MtDecoder* d, uint32_t sub_count) { - size_t subs_bytes = sizeof(*d->table->subs) * sub_count; - void* subs = upb_Arena_Malloc(d->arena, subs_bytes); - memset(subs, 0, subs_bytes); +static void upb_MtDecoder_AllocateSubs(upb_MtDecoder* d, + upb_SubCounts sub_counts) { + uint32_t total_count = sub_counts.submsg_count + sub_counts.subenum_count; + size_t subs_bytes = sizeof(*d->table->subs) * total_count; + upb_MiniTableSub* subs = upb_Arena_Malloc(d->arena, subs_bytes); + upb_MtDecoder_CheckOutOfMemory(d, subs); + uint32_t i = 0; + for (; i < sub_counts.submsg_count; i++) { + subs[i].submsg = NULL; // &kUpb_MiniTable_Empty; + } + if (sub_counts.subenum_count) { + upb_MiniTableField* f = d->fields; + upb_MiniTableField* end_f = f + d->table->field_count; + for (; f < end_f; f++) { + if (f->UPB_PRIVATE(descriptortype) == kUpb_FieldType_Enum) { + f->UPB_PRIVATE(submsg_index) += sub_counts.submsg_count; + } + } + for (; i < sub_counts.submsg_count + sub_counts.subenum_count; i++) { + subs[i].subenum = NULL; + } + } d->table->subs = subs; - upb_MtDecoder_CheckOutOfMemory(d, d->table->subs); } static const char* upb_MtDecoder_Parse(upb_MtDecoder* d, const char* ptr, size_t len, void* fields, size_t field_size, uint16_t* field_count, - uint32_t* sub_count) { + upb_SubCounts* sub_counts) { uint64_t msg_modifiers = 0; uint32_t last_field_number = 0; upb_MiniTableField* last_field = NULL; @@ -474,7 +487,7 @@ static const char* upb_MtDecoder_Parse(upb_MtDecoder* d, const char* ptr, fields = (char*)fields + field_size; field->number = ++last_field_number; last_field = field; - upb_MiniTable_SetField(d, ch, field, msg_modifiers, sub_count); + upb_MiniTable_SetField(d, ch, field, msg_modifiers, sub_counts); } else if (kUpb_EncodedValue_MinModifier <= ch && ch <= kUpb_EncodedValue_MaxModifier) { ptr = upb_MtDecoder_ParseModifier(d, ptr, ch, last_field, &msg_modifiers); @@ -519,16 +532,16 @@ static void upb_MtDecoder_ParseMessage(upb_MtDecoder* d, const char* data, d->fields = upb_Arena_Malloc(d->arena, sizeof(*d->fields) * len); upb_MtDecoder_CheckOutOfMemory(d, d->fields); - uint32_t sub_count = 0; + upb_SubCounts sub_counts = {0, 0}; d->table->field_count = 0; d->table->fields = d->fields; upb_MtDecoder_Parse(d, data, len, d->fields, sizeof(*d->fields), - &d->table->field_count, &sub_count); + &d->table->field_count, &sub_counts); upb_Arena_ShrinkLast(d->arena, d->fields, sizeof(*d->fields) * len, sizeof(*d->fields) * d->table->field_count); d->table->fields = d->fields; - upb_MtDecoder_AllocateSubs(d, sub_count); + upb_MtDecoder_AllocateSubs(d, sub_counts); } int upb_MtDecoder_CompareFields(const void* _a, const void* _b) { @@ -942,8 +955,9 @@ static const char* upb_MtDecoder_DoBuildMiniTableExtension( } uint16_t count = 0; - const char* ret = - upb_MtDecoder_Parse(decoder, data, len, ext, sizeof(*ext), &count, NULL); + upb_SubCounts sub_counts = {0, 0}; + const char* ret = upb_MtDecoder_Parse(decoder, data, len, ext, sizeof(*ext), + &count, &sub_counts); if (!ret || count != 1) return NULL; upb_MiniTableField* f = &ext->field; diff --git a/upbc/protoc-gen-upb.cc b/upbc/protoc-gen-upb.cc index 25f9827b8a..3c2072dc47 100644 --- a/upbc/protoc-gen-upb.cc +++ b/upbc/protoc-gen-upb.cc @@ -1350,12 +1350,15 @@ void WriteMessage(upb::MessageDefPtr message, const DefPoolPair& pools, std::string subenums_array_ref = "NULL"; const upb_MiniTable* mt_32 = pools.GetMiniTable32(message); const upb_MiniTable* mt_64 = pools.GetMiniTable64(message); - std::vector subs; + std::map subs; for (int i = 0; i < mt_64->field_count; i++) { const upb_MiniTableField* f = &mt_64->fields[i]; - if (f->UPB_PRIVATE(submsg_index) != kUpb_NoSub) { - subs.push_back(GetSub(message.FindFieldByNumber(f->number))); + uint32_t index = f->UPB_PRIVATE(submsg_index); + if (index != kUpb_NoSub) { + auto pair = + subs.emplace(index, GetSub(message.FindFieldByNumber(f->number))); + ABSL_CHECK(pair.second); } } @@ -1365,8 +1368,10 @@ void WriteMessage(upb::MessageDefPtr message, const DefPoolPair& pools, output("static const upb_MiniTableSub $0[$1] = {\n", submsgs_array_name, subs.size()); - for (const auto& sub : subs) { - output(" $0,\n", sub); + int i = 0; + for (const auto& pair : subs) { + ABSL_CHECK(pair.first == i++); + output(" $0,\n", pair.second); } output("};\n\n");