Merge pull request #324 from haberman/simplemomi

Eliminated bounds checks while parsing a field.
pull/13171/head
Joshua Haberman 4 years ago committed by GitHub
commit 1bd62e8218
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 228
      upb/decode.c

@ -139,10 +139,14 @@ static const int8_t delim_ops[37] = {
/* Data pertaining to the parse. */ /* Data pertaining to the parse. */
typedef struct { typedef struct {
const char *limit; /* End of delimited region or end of buffer. */ const char *end; /* Can read up to 16 bytes slop beyond this. */
upb_arena arena; const char *limit_ptr; /* = end + UPB_MIN(limit, 0) */
int limit; /* Submessage limit relative to end. */
int depth; int depth;
uint32_t end_group; /* Set to field number of END_GROUP tag, if any. */ uint32_t end_group; /* Set to field number of END_GROUP tag, if any. */
bool alias;
char patch[32];
upb_arena arena;
jmp_buf err; jmp_buf err;
} upb_decstate; } upb_decstate;
@ -150,7 +154,7 @@ typedef union {
bool bool_val; bool bool_val;
uint32_t uint32_val; uint32_t uint32_val;
uint64_t uint64_val; uint64_t uint64_val;
upb_strview str_val; uint32_t size;
} wireval; } wireval;
static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg, static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
@ -200,41 +204,48 @@ static bool decode_reserve(upb_decstate *d, upb_array *arr, size_t elem) {
return need_realloc; return need_realloc;
} }
typedef struct {
const char *ptr;
uint64_t val;
} decode_vret;
UPB_NOINLINE UPB_NOINLINE
static const char *decode_longvarint64(upb_decstate *d, const char *ptr, static decode_vret decode_longvarint64(const char *ptr, uint64_t val) {
const char *limit, uint64_t *val) { decode_vret ret = {NULL, 0};
uint8_t byte; uint64_t byte;
int bitpos = 0; int i;
uint64_t out = 0; for (i = 1; i < 10; i++) {
byte = (uint8_t)ptr[i];
do { val += (byte - 1) << (i * 7);
if (bitpos >= 70 || ptr == limit) decode_err(d); if (!(byte & 0x80)) {
byte = *ptr; ret.ptr = ptr + i + 1;
out |= (uint64_t)(byte & 0x7F) << bitpos; ret.val = val;
ptr++; return ret;
bitpos += 7; }
} while (byte & 0x80); }
return ret;
*val = out;
return ptr;
} }
UPB_FORCEINLINE UPB_FORCEINLINE
static const char *decode_varint64(upb_decstate *d, const char *ptr, static const char *decode_varint64(upb_decstate *d, const char *ptr,
const char *limit, uint64_t *val) { uint64_t *val) {
if (UPB_LIKELY(ptr < limit && (*ptr & 0x80) == 0)) { uint64_t byte = (uint8_t)*ptr;
*val = (uint8_t)*ptr; if (UPB_LIKELY((byte & 0x80) == 0)) {
*val = byte;
return ptr + 1; return ptr + 1;
} else { } else {
return decode_longvarint64(d, ptr, limit, val); decode_vret res = decode_longvarint64(ptr, byte);
if (!res.ptr) decode_err(d);
*val = res.val;
return res.ptr;
} }
} }
UPB_FORCEINLINE UPB_FORCEINLINE
static const char *decode_varint32(upb_decstate *d, const char *ptr, static const char *decode_varint32(upb_decstate *d, const char *ptr,
const char *limit, uint32_t *val) { uint32_t *val) {
uint64_t u64; uint64_t u64;
ptr = decode_varint64(d, ptr, limit, &u64); ptr = decode_varint64(d, ptr, &u64);
if (u64 > UINT32_MAX) decode_err(d); if (u64 > UINT32_MAX) decode_err(d);
*val = (uint32_t)u64; *val = (uint32_t)u64;
return ptr; return ptr;
@ -287,17 +298,82 @@ static upb_msg *decode_newsubmsg(upb_decstate *d, const upb_msglayout *layout,
return _upb_msg_new_inl(subl, &d->arena); return _upb_msg_new_inl(subl, &d->arena);
} }
static void decode_tosubmsg(upb_decstate *d, upb_msg *submsg, static int decode_pushlimit(upb_decstate *d, const char *ptr, int size) {
const upb_msglayout *layout, int limit = size + (int)(ptr - d->end);
const upb_msglayout_field *field, upb_strview val) { int delta = d->limit - limit;
d->limit = limit;
d->limit_ptr = d->end + UPB_MIN(0, limit);
return delta;
}
static void decode_poplimit(upb_decstate *d, int saved_delta) {
d->limit += saved_delta;
d->limit_ptr = d->end + UPB_MIN(0, d->limit);
}
typedef struct {
bool ok;
const char *ptr;
} decode_doneret;
UPB_NOINLINE
static const char *decode_isdonefallback(upb_decstate *d, const char *ptr,
int overrun) {
if (overrun < d->limit) {
/* Need to copy remaining data into patch buffer. */
UPB_ASSERT(overrun < 16);
memset(d->patch + 16, 0, 16);
memcpy(d->patch, d->end, 16);
ptr = &d->patch[0] + overrun;
d->end = &d->patch[16];
d->limit -= 16;
d->limit_ptr = d->end + d->limit;
d->alias = false;
UPB_ASSERT(ptr < d->limit_ptr);
return ptr;
} else {
decode_err(d);
}
}
UPB_FORCEINLINE
static bool decode_isdone(upb_decstate *d, const char **ptr) {
int overrun = *ptr - d->end;
if (UPB_LIKELY(*ptr < d->limit_ptr)) {
return false;
} else if (UPB_LIKELY(overrun == d->limit)) {
return true;
} else {
*ptr = decode_isdonefallback(d, *ptr, overrun);
return false;
}
}
static const char *decode_readstr(upb_decstate *d, const char *ptr, int size,
upb_strview *str) {
if (d->alias) {
str->data = ptr;
} else {
char *data = upb_arena_malloc(&d->arena, size);
if (!data) decode_err(d);
memcpy(data, ptr, size);
str->data = data;
}
str->size = size;
return ptr + size;
}
static const char *decode_tosubmsg(upb_decstate *d, const char *ptr,
upb_msg *submsg, const upb_msglayout *layout,
const upb_msglayout_field *field, int size) {
const upb_msglayout *subl = layout->submsgs[field->submsg_index]; const upb_msglayout *subl = layout->submsgs[field->submsg_index];
const char *saved_limit = d->limit; int saved_delta = decode_pushlimit(d, ptr, size);
if (--d->depth < 0) decode_err(d); if (--d->depth < 0) decode_err(d);
d->limit = val.data + val.size; ptr = decode_msg(d, ptr, submsg, subl);
decode_msg(d, val.data, submsg, subl); decode_poplimit(d, saved_delta);
d->limit = saved_limit;
if (d->end_group != 0) decode_err(d); if (d->end_group != 0) decode_err(d);
d->depth++; d->depth++;
return ptr;
} }
static const char *decode_group(upb_decstate *d, const char *ptr, static const char *decode_group(upb_decstate *d, const char *ptr,
@ -345,15 +421,14 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr,
memcpy(mem, &val, 1 << op); memcpy(mem, &val, 1 << op);
return ptr; return ptr;
case OP_STRING: case OP_STRING:
decode_verifyutf8(d, val.str_val.data, val.str_val.size); decode_verifyutf8(d, ptr, val.size);
/* Fallthrough. */ /* Fallthrough. */
case OP_BYTES: case OP_BYTES: {
/* Append bytes. */ /* Append bytes. */
mem = upb_strview *str = (upb_strview*)_upb_array_ptr(arr) + arr->len;
UPB_PTR_AT(_upb_array_ptr(arr), arr->len * sizeof(upb_strview), void);
arr->len++; arr->len++;
memcpy(mem, &val, sizeof(upb_strview)); return decode_readstr(d, ptr, val.size, str);
return ptr; }
case OP_SUBMSG: { case OP_SUBMSG: {
/* Append submessage / group. */ /* Append submessage / group. */
upb_msg *submsg = decode_newsubmsg(d, layout, field); upb_msg *submsg = decode_newsubmsg(d, layout, field);
@ -361,26 +436,25 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr,
submsg; submsg;
arr->len++; arr->len++;
if (UPB_UNLIKELY(field->descriptortype == UPB_DTYPE_GROUP)) { if (UPB_UNLIKELY(field->descriptortype == UPB_DTYPE_GROUP)) {
ptr = decode_togroup(d, ptr, submsg, layout, field); return decode_togroup(d, ptr, submsg, layout, field);
} else { } else {
decode_tosubmsg(d, submsg, layout, field, val.str_val); return decode_tosubmsg(d, ptr, submsg, layout, field, val.size);
} }
return ptr;
} }
case OP_FIXPCK_LG2(2): case OP_FIXPCK_LG2(2):
case OP_FIXPCK_LG2(3): { case OP_FIXPCK_LG2(3): {
/* Fixed packed. */ /* Fixed packed. */
int lg2 = op - OP_FIXPCK_LG2(0); int lg2 = op - OP_FIXPCK_LG2(0);
int mask = (1 << lg2) - 1; int mask = (1 << lg2) - 1;
size_t count = val.str_val.size >> lg2; size_t count = val.size >> lg2;
if ((val.str_val.size & mask) != 0) { if ((val.size & mask) != 0) {
decode_err(d); /* Length isn't a round multiple of elem size. */ decode_err(d); /* Length isn't a round multiple of elem size. */
} }
decode_reserve(d, arr, count); decode_reserve(d, arr, count);
mem = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void); mem = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void);
arr->len += count; arr->len += count;
memcpy(mem, val.str_val.data, val.str_val.size); memcpy(mem, ptr, val.size); /* XXX: ptr boundary. */
return ptr; return ptr + val.size;
} }
case OP_VARPCK_LG2(0): case OP_VARPCK_LG2(0):
case OP_VARPCK_LG2(2): case OP_VARPCK_LG2(2):
@ -388,12 +462,11 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr,
/* Varint packed. */ /* Varint packed. */
int lg2 = op - OP_VARPCK_LG2(0); int lg2 = op - OP_VARPCK_LG2(0);
int scale = 1 << lg2; int scale = 1 << lg2;
const char *ptr = val.str_val.data; int saved_limit = decode_pushlimit(d, ptr, val.size);
const char *end = ptr + val.str_val.size;
char *out = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void); char *out = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void);
while (ptr < end) { while (!decode_isdone(d, &ptr)) {
wireval elem; wireval elem;
ptr = decode_varint64(d, ptr, end, &elem.uint64_val); ptr = decode_varint64(d, ptr, &elem.uint64_val);
decode_munge(field->descriptortype, &elem); decode_munge(field->descriptortype, &elem);
if (decode_reserve(d, arr, 1)) { if (decode_reserve(d, arr, 1)) {
out = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void); out = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void);
@ -402,7 +475,7 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr,
memcpy(out, &elem, scale); memcpy(out, &elem, scale);
out += scale; out += scale;
} }
if (ptr != end) decode_err(d); decode_poplimit(d, saved_limit);
return ptr; return ptr;
} }
default: default:
@ -410,9 +483,9 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr,
} }
} }
static void decode_tomap(upb_decstate *d, upb_msg *msg, static const char *decode_tomap(upb_decstate *d, const char *ptr, upb_msg *msg,
const upb_msglayout *layout, const upb_msglayout *layout,
const upb_msglayout_field *field, wireval val) { const upb_msglayout_field *field, wireval val) {
upb_map **map_p = UPB_PTR_AT(msg, field->offset, upb_map *); upb_map **map_p = UPB_PTR_AT(msg, field->offset, upb_map *);
upb_map *map = *map_p; upb_map *map = *map_p;
upb_map_entry ent; upb_map_entry ent;
@ -440,10 +513,9 @@ static void decode_tomap(upb_decstate *d, upb_msg *msg,
ent.v.val = upb_value_ptr(_upb_msg_new(entry->submsgs[0], &d->arena)); ent.v.val = upb_value_ptr(_upb_msg_new(entry->submsgs[0], &d->arena));
} }
decode_tosubmsg(d, &ent.k, layout, field, val.str_val); ptr = decode_tosubmsg(d, ptr, &ent.k, layout, field, val.size);
/* Insert into map. */
_upb_map_set(map, &ent.k, map->key_size, &ent.v, map->val_size, &d->arena); _upb_map_set(map, &ent.k, map->key_size, &ent.v, map->val_size, &d->arena);
return ptr;
} }
static const char *decode_tomsg(upb_decstate *d, const char *ptr, upb_msg *msg, static const char *decode_tomsg(upb_decstate *d, const char *ptr, upb_msg *msg,
@ -477,16 +549,15 @@ static const char *decode_tomsg(upb_decstate *d, const char *ptr, upb_msg *msg,
if (UPB_UNLIKELY(type == UPB_DTYPE_GROUP)) { if (UPB_UNLIKELY(type == UPB_DTYPE_GROUP)) {
ptr = decode_togroup(d, ptr, submsg, layout, field); ptr = decode_togroup(d, ptr, submsg, layout, field);
} else { } else {
decode_tosubmsg(d, submsg, layout, field, val.str_val); ptr = decode_tosubmsg(d, ptr, submsg, layout, field, val.size);
} }
break; break;
} }
case OP_STRING: case OP_STRING:
decode_verifyutf8(d, val.str_val.data, val.str_val.size); decode_verifyutf8(d, ptr, val.size);
/* Fallthrough. */ /* Fallthrough. */
case OP_BYTES: case OP_BYTES:
memcpy(mem, &val, sizeof(upb_strview)); return decode_readstr(d, ptr, val.size, mem);
break;
case OP_SCALAR_LG2(3): case OP_SCALAR_LG2(3):
memcpy(mem, &val, 8); memcpy(mem, &val, 8);
break; break;
@ -505,7 +576,7 @@ static const char *decode_tomsg(upb_decstate *d, const char *ptr, upb_msg *msg,
static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg, static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
const upb_msglayout *layout) { const upb_msglayout *layout) {
while (ptr < d->limit) { while (!decode_isdone(d, &ptr)) {
uint32_t tag; uint32_t tag;
const upb_msglayout_field *field; const upb_msglayout_field *field;
int field_number; int field_number;
@ -514,7 +585,7 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
wireval val; wireval val;
int op; int op;
ptr = decode_varint32(d, ptr, d->limit, &tag); ptr = decode_varint32(d, ptr, &tag);
field_number = tag >> 3; field_number = tag >> 3;
wire_type = tag & 7; wire_type = tag & 7;
@ -522,12 +593,11 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
switch (wire_type) { switch (wire_type) {
case UPB_WIRE_TYPE_VARINT: case UPB_WIRE_TYPE_VARINT:
ptr = decode_varint64(d, ptr, d->limit, &val.uint64_val); ptr = decode_varint64(d, ptr, &val.uint64_val);
op = varint_ops[field->descriptortype]; op = varint_ops[field->descriptortype];
decode_munge(field->descriptortype, &val); decode_munge(field->descriptortype, &val);
break; break;
case UPB_WIRE_TYPE_32BIT: case UPB_WIRE_TYPE_32BIT:
if (d->limit - ptr < 4) decode_err(d);
memcpy(&val.uint32_val, ptr, 4); memcpy(&val.uint32_val, ptr, 4);
val.uint32_val = _upb_be_swap32(val.uint32_val); val.uint32_val = _upb_be_swap32(val.uint32_val);
ptr += 4; ptr += 4;
@ -535,7 +605,6 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
if (((1 << field->descriptortype) & fixed32_ok) == 0) goto unknown; if (((1 << field->descriptortype) & fixed32_ok) == 0) goto unknown;
break; break;
case UPB_WIRE_TYPE_64BIT: case UPB_WIRE_TYPE_64BIT:
if (d->limit - ptr < 8) decode_err(d);
memcpy(&val.uint64_val, ptr, 8); memcpy(&val.uint64_val, ptr, 8);
val.uint64_val = _upb_be_swap64(val.uint64_val); val.uint64_val = _upb_be_swap64(val.uint64_val);
ptr += 8; ptr += 8;
@ -543,16 +612,12 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
if (((1 << field->descriptortype) & fixed64_ok) == 0) goto unknown; if (((1 << field->descriptortype) & fixed64_ok) == 0) goto unknown;
break; break;
case UPB_WIRE_TYPE_DELIMITED: { case UPB_WIRE_TYPE_DELIMITED: {
uint32_t size;
int ndx = field->descriptortype; int ndx = field->descriptortype;
if (_upb_isrepeated(field)) ndx += 18; if (_upb_isrepeated(field)) ndx += 18;
ptr = decode_varint32(d, ptr, d->limit, &size); ptr = decode_varint32(d, ptr, &val.size);
if (size >= INT32_MAX || (size_t)(d->limit - ptr) < size) { if (val.size >= INT32_MAX || ptr - d->end + val.size > d->limit) {
decode_err(d); /* Length overflow. */ decode_err(d); /* Length overflow. */
} }
val.str_val.data = ptr;
val.str_val.size = size;
ptr += size;
op = delim_ops[ndx]; op = delim_ops[ndx];
break; break;
} }
@ -576,7 +641,7 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
ptr = decode_toarray(d, ptr, msg, layout, field, val, op); ptr = decode_toarray(d, ptr, msg, layout, field, val, op);
break; break;
case _UPB_LABEL_MAP: case _UPB_LABEL_MAP:
decode_tomap(d, msg, layout, field, val); ptr = decode_tomap(d, ptr, msg, layout, field, val);
break; break;
default: default:
ptr = decode_tomsg(d, ptr, msg, layout, field, val, op); ptr = decode_tomsg(d, ptr, msg, layout, field, val, op);
@ -590,6 +655,7 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
ptr = decode_group(d, ptr, NULL, NULL, field_number); ptr = decode_group(d, ptr, NULL, NULL, field_number);
} }
if (msg) { if (msg) {
if (wire_type == UPB_WIRE_TYPE_DELIMITED) ptr += val.size;
if (!_upb_msg_addunknown(msg, field_start, ptr - field_start, if (!_upb_msg_addunknown(msg, field_start, ptr - field_start,
&d->arena)) { &d->arena)) {
decode_err(d); decode_err(d);
@ -598,7 +664,6 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
} }
} }
if (ptr != d->limit) decode_err(d);
return ptr; return ptr;
} }
@ -607,9 +672,22 @@ bool upb_decode(const char *buf, size_t size, void *msg, const upb_msglayout *l,
bool ok; bool ok;
upb_decstate state; upb_decstate state;
if (size == 0) return true; if (size == 0) {
return true;
} else if (size < 16) {
memset(&state.patch, 0, 32);
memcpy(&state.patch, buf, size);
buf = state.patch;
state.end = buf + size;
state.limit = 0;
state.alias = false;
} else {
state.end = buf + size - 16;
state.limit = 16;
state.alias = true;
}
state.limit = buf + size; state.limit_ptr = state.end;
state.depth = 64; state.depth = 64;
state.end_group = 0; state.end_group = 0;
state.arena.head = arena->head; state.arena.head = arena->head;

Loading…
Cancel
Save