Allow dynamic tree shaking for groups, and disallow groups for map entry sub-messages.

The initial motivation for this CL was to fix a bug found by fuzzing.  But the fuzz bug pointed out a few edge cases that this CL corrects:

1. The core bug is that we were allowing a map entry sub-message to be linked to a group field.  This is not allowed in protobuf schemas, but we did not check for this edge case in `upb_MiniTable_SetSubMessage()`, so we were de facto allowing it. This triggered some bad behavior in the parser whereby we pushed a limit without checking its validity first.

2. To defend against this, I added asserts in `upb_MiniTable_SetSubMessage()` to verify the type of the field we are linking, to ensure that a group field is not linked to a map entry sub-message.  But this should probably be changed to return an error instead of relying on asserts for this.

3. I changed the fuzz util code that builds the MiniTable so that it will never violate this new invariant.  The fuzz util code now can run into situations where a group field has no valid non-map-entry sub-message to select.  In those cases it will simply not register any sub-message for that field.

4. Previously group did not support leaving sub-messages unregistered.  Previously I added this feature for sub-messages but not for groups.  There is no reason why dynamic tree shaking should not work for group fields, so I extended the feature to support groups also.

PiperOrigin-RevId: 504913630
pull/13171/head
Joshua Haberman 2 years ago committed by Copybara-Service
parent 5cb1b41a80
commit a5e337f88c
  1. 16
      upb/message/test.cc
  2. 5
      upb/mini_table/decode.c
  3. 49
      upb/test/fuzz_util.cc
  4. 20
      upb/wire/decode.c
  5. 1
      upb/wire/eps_copy_input_stream.h

@ -603,4 +603,20 @@ TEST(MessageTest, MapField) {
// -1960166338, 16809991);
// }
//
// TEST(FuzzTest, GroupMap) {
// // Groups should not be allowed as maps, but we previously failed to prevent
// // this.
// DecodeEncodeArbitrarySchemaAndPayload(
// {.mini_descriptors = {"$$FF$", "%-C"},
// .enum_mini_descriptors = {},
// .extensions = "",
// .links = {1}},
// std::string(
// "\023\020\030\233\000\204\330\372#\000`"
// "a\000\000\001\000\000\000ccccccc\030s\273sssssssss\030\030\030\030"
// "\030\030\030\030\215\215\215\215\215\215\215\215\030\030\232\253\253"
// "\232*\334\227\273\231\207\373\t\0051\305\265\335\224\226"),
// 0, 0);
// }
//
// end:google_only

@ -979,7 +979,12 @@ void upb_MiniTable_SetSubMessage(upb_MiniTable* table,
UPB_ASSERT((uintptr_t)table->fields <= (uintptr_t)field &&
(uintptr_t)field <
(uintptr_t)(table->fields + table->field_count));
// TODO: check these type invariants at runtime and return error to the
// caller if they are violated, instead of using an assert.
UPB_ASSERT(field->descriptortype == kUpb_FieldType_Message ||
field->descriptortype == kUpb_FieldType_Group);
if (sub->ext & kUpb_ExtMode_IsMapEntry) {
UPB_ASSERT(field->descriptortype == kUpb_FieldType_Message);
field->mode = (field->mode & ~kUpb_FieldMode_Mask) | kUpb_FieldMode_Map;
}
upb_MiniTableSub* table_sub = (void*)&table->subs[field->submsg_index];

@ -64,6 +64,20 @@ class Builder {
return input_->links[link_++];
}
const upb_MiniTable* NextNonMapEntryMiniTable() {
if (mini_tables_.empty()) return nullptr;
size_t start = NextLink() % mini_tables_.size();
size_t next = start;
do {
const upb_MiniTable* mini_table = mini_tables_[next];
if ((mini_table->ext & kUpb_ExtMode_IsMapEntry) == 0) {
return mini_table;
}
next = NextLink() % mini_tables_.size();
} while (next != start);
return nullptr;
}
const upb_MiniTable* NextMiniTable() {
return mini_tables_.empty()
? nullptr
@ -154,18 +168,29 @@ void Builder::LinkMessages() {
upb_MiniTableField* field =
const_cast<upb_MiniTableField*>(&table->fields[i]);
if (link_ == input_->links.size()) link_ = 0;
if (field->descriptortype == kUpb_FieldType_Message ||
field->descriptortype == kUpb_FieldType_Group) {
upb_MiniTable_SetSubMessage(table, field, NextMiniTable());
}
if (field->descriptortype == kUpb_FieldType_Enum) {
auto* et = NextEnumTable();
if (et) {
upb_MiniTable_SetSubEnum(table, field, et);
} else {
// We don't have any sub-enums. Override the field type so that it is
// not needed.
field->descriptortype = kUpb_FieldType_Int32;
switch (field->descriptortype) {
case kUpb_FieldType_Message: {
const upb_MiniTable* sub = NextMiniTable();
// We should always have at least one message.
assert(sub);
upb_MiniTable_SetSubMessage(table, field, sub);
break;
}
case kUpb_FieldType_Group: {
const upb_MiniTable* sub = NextNonMapEntryMiniTable();
// sub will be nullptr if no non-map entry messages are available.
if (sub) upb_MiniTable_SetSubMessage(table, field, sub);
break;
}
case kUpb_FieldType_Enum: {
auto* et = NextEnumTable();
if (et) {
upb_MiniTable_SetSubEnum(table, field, et);
} else {
// We don't have any sub-enums. Override the field type so that it
// is not needed.
field->descriptortype = kUpb_FieldType_Int32;
}
}
}
}

@ -937,6 +937,16 @@ int _upb_Decoder_GetVarintOp(const upb_MiniTableField* field) {
return kVarintOps[field->descriptortype];
}
UPB_FORCEINLINE
static void _upb_Decoder_CheckUnlinked(const upb_MiniTable* mt,
const upb_MiniTableField* field,
int* op) {
// If sub-message is not linked, treat as unknown.
if (field->mode & kUpb_LabelFlags_IsExtension) return;
const upb_MiniTableSub* sub = &mt->subs[field->submsg_index];
if (!sub->submsg) *op = kUpb_DecodeOp_UnknownField;
}
int _upb_Decoder_GetDelimitedOp(const upb_MiniTable* mt,
const upb_MiniTableField* field) {
enum { kRepeatedBase = 19 };
@ -991,13 +1001,8 @@ int _upb_Decoder_GetDelimitedOp(const upb_MiniTable* mt,
if (upb_FieldMode_Get(field) == kUpb_FieldMode_Array) ndx += kRepeatedBase;
int op = kDelimitedOps[ndx];
// If sub-message is not linked, treat as unknown.
if (op == kUpb_DecodeOp_SubMessage &&
!(field->mode & kUpb_LabelFlags_IsExtension)) {
const upb_MiniTableSub* sub = &mt->subs[field->submsg_index];
if (!sub->submsg) {
op = kUpb_DecodeOp_UnknownField;
}
if (op == kUpb_DecodeOp_SubMessage) {
_upb_Decoder_CheckUnlinked(mt, field, &op);
}
return op;
@ -1043,6 +1048,7 @@ static const char* _upb_Decoder_DecodeWireValue(upb_Decoder* d, const char* ptr,
val->uint32_val = field->number;
if (field->descriptortype == kUpb_FieldType_Group) {
*op = kUpb_DecodeOp_SubMessage;
_upb_Decoder_CheckUnlinked(mt, field, op);
} else if (field->descriptortype == kUpb_FakeFieldType_MessageSetItem) {
*op = kUpb_DecodeOp_MessageSetItem;
} else {

@ -340,6 +340,7 @@ UPB_INLINE int upb_EpsCopyInputStream_PushLimit(upb_EpsCopyInputStream* e,
int limit = size + (int)(ptr - e->end);
int delta = e->limit - limit;
_upb_EpsCopyInputStream_CheckLimit(e);
UPB_ASSERT(limit <= e->limit);
e->limit = limit;
e->limit_ptr = e->end + UPB_MIN(0, limit);
_upb_EpsCopyInputStream_CheckLimit(e);

Loading…
Cancel
Save