Merge pull request #324 from haberman/simplemomi

Eliminated bounds checks inside parsing a field.
pull/13171/head
Joshua Haberman 4 years ago committed by GitHub
commit 1bd62e8218
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 224
      upb/decode.c

@ -139,10 +139,14 @@ static const int8_t delim_ops[37] = {
/* Data pertaining to the parse. */
typedef struct {
const char *limit; /* End of delimited region or end of buffer. */
upb_arena arena;
const char *end; /* Can read up to 16 bytes slop beyond this. */
const char *limit_ptr; /* = end + UPB_MIN(limit, 0) */
int limit; /* Submessage limit relative to end. */
int depth;
uint32_t end_group; /* Set to field number of END_GROUP tag, if any. */
bool alias;
char patch[32];
upb_arena arena;
jmp_buf err;
} upb_decstate;
@ -150,7 +154,7 @@ typedef union {
bool bool_val;
uint32_t uint32_val;
uint64_t uint64_val;
upb_strview str_val;
uint32_t size;
} wireval;
static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
@ -200,41 +204,48 @@ static bool decode_reserve(upb_decstate *d, upb_array *arr, size_t elem) {
return need_realloc;
}
typedef struct {
const char *ptr;
uint64_t val;
} decode_vret;
UPB_NOINLINE
static const char *decode_longvarint64(upb_decstate *d, const char *ptr,
const char *limit, uint64_t *val) {
uint8_t byte;
int bitpos = 0;
uint64_t out = 0;
do {
if (bitpos >= 70 || ptr == limit) decode_err(d);
byte = *ptr;
out |= (uint64_t)(byte & 0x7F) << bitpos;
ptr++;
bitpos += 7;
} while (byte & 0x80);
*val = out;
return ptr;
static decode_vret decode_longvarint64(const char *ptr, uint64_t val) {
decode_vret ret = {NULL, 0};
uint64_t byte;
int i;
for (i = 1; i < 10; i++) {
byte = (uint8_t)ptr[i];
val += (byte - 1) << (i * 7);
if (!(byte & 0x80)) {
ret.ptr = ptr + i + 1;
ret.val = val;
return ret;
}
}
return ret;
}
UPB_FORCEINLINE
static const char *decode_varint64(upb_decstate *d, const char *ptr,
const char *limit, uint64_t *val) {
if (UPB_LIKELY(ptr < limit && (*ptr & 0x80) == 0)) {
*val = (uint8_t)*ptr;
uint64_t *val) {
uint64_t byte = (uint8_t)*ptr;
if (UPB_LIKELY((byte & 0x80) == 0)) {
*val = byte;
return ptr + 1;
} else {
return decode_longvarint64(d, ptr, limit, val);
decode_vret res = decode_longvarint64(ptr, byte);
if (!res.ptr) decode_err(d);
*val = res.val;
return res.ptr;
}
}
UPB_FORCEINLINE
static const char *decode_varint32(upb_decstate *d, const char *ptr,
const char *limit, uint32_t *val) {
uint32_t *val) {
uint64_t u64;
ptr = decode_varint64(d, ptr, limit, &u64);
ptr = decode_varint64(d, ptr, &u64);
if (u64 > UINT32_MAX) decode_err(d);
*val = (uint32_t)u64;
return ptr;
@ -287,17 +298,82 @@ static upb_msg *decode_newsubmsg(upb_decstate *d, const upb_msglayout *layout,
return _upb_msg_new_inl(subl, &d->arena);
}
static void decode_tosubmsg(upb_decstate *d, upb_msg *submsg,
const upb_msglayout *layout,
const upb_msglayout_field *field, upb_strview val) {
static int decode_pushlimit(upb_decstate *d, const char *ptr, int size) {
int limit = size + (int)(ptr - d->end);
int delta = d->limit - limit;
d->limit = limit;
d->limit_ptr = d->end + UPB_MIN(0, limit);
return delta;
}
static void decode_poplimit(upb_decstate *d, int saved_delta) {
d->limit += saved_delta;
d->limit_ptr = d->end + UPB_MIN(0, d->limit);
}
typedef struct {
bool ok;
const char *ptr;
} decode_doneret;
UPB_NOINLINE
static const char *decode_isdonefallback(upb_decstate *d, const char *ptr,
int overrun) {
if (overrun < d->limit) {
/* Need to copy remaining data into patch buffer. */
UPB_ASSERT(overrun < 16);
memset(d->patch + 16, 0, 16);
memcpy(d->patch, d->end, 16);
ptr = &d->patch[0] + overrun;
d->end = &d->patch[16];
d->limit -= 16;
d->limit_ptr = d->end + d->limit;
d->alias = false;
UPB_ASSERT(ptr < d->limit_ptr);
return ptr;
} else {
decode_err(d);
}
}
UPB_FORCEINLINE
static bool decode_isdone(upb_decstate *d, const char **ptr) {
int overrun = *ptr - d->end;
if (UPB_LIKELY(*ptr < d->limit_ptr)) {
return false;
} else if (UPB_LIKELY(overrun == d->limit)) {
return true;
} else {
*ptr = decode_isdonefallback(d, *ptr, overrun);
return false;
}
}
static const char *decode_readstr(upb_decstate *d, const char *ptr, int size,
upb_strview *str) {
if (d->alias) {
str->data = ptr;
} else {
char *data = upb_arena_malloc(&d->arena, size);
if (!data) decode_err(d);
memcpy(data, ptr, size);
str->data = data;
}
str->size = size;
return ptr + size;
}
static const char *decode_tosubmsg(upb_decstate *d, const char *ptr,
upb_msg *submsg, const upb_msglayout *layout,
const upb_msglayout_field *field, int size) {
const upb_msglayout *subl = layout->submsgs[field->submsg_index];
const char *saved_limit = d->limit;
int saved_delta = decode_pushlimit(d, ptr, size);
if (--d->depth < 0) decode_err(d);
d->limit = val.data + val.size;
decode_msg(d, val.data, submsg, subl);
d->limit = saved_limit;
ptr = decode_msg(d, ptr, submsg, subl);
decode_poplimit(d, saved_delta);
if (d->end_group != 0) decode_err(d);
d->depth++;
return ptr;
}
static const char *decode_group(upb_decstate *d, const char *ptr,
@ -345,15 +421,14 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr,
memcpy(mem, &val, 1 << op);
return ptr;
case OP_STRING:
decode_verifyutf8(d, val.str_val.data, val.str_val.size);
decode_verifyutf8(d, ptr, val.size);
/* Fallthrough. */
case OP_BYTES:
case OP_BYTES: {
/* Append bytes. */
mem =
UPB_PTR_AT(_upb_array_ptr(arr), arr->len * sizeof(upb_strview), void);
upb_strview *str = (upb_strview*)_upb_array_ptr(arr) + arr->len;
arr->len++;
memcpy(mem, &val, sizeof(upb_strview));
return ptr;
return decode_readstr(d, ptr, val.size, str);
}
case OP_SUBMSG: {
/* Append submessage / group. */
upb_msg *submsg = decode_newsubmsg(d, layout, field);
@ -361,26 +436,25 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr,
submsg;
arr->len++;
if (UPB_UNLIKELY(field->descriptortype == UPB_DTYPE_GROUP)) {
ptr = decode_togroup(d, ptr, submsg, layout, field);
return decode_togroup(d, ptr, submsg, layout, field);
} else {
decode_tosubmsg(d, submsg, layout, field, val.str_val);
return decode_tosubmsg(d, ptr, submsg, layout, field, val.size);
}
return ptr;
}
case OP_FIXPCK_LG2(2):
case OP_FIXPCK_LG2(3): {
/* Fixed packed. */
int lg2 = op - OP_FIXPCK_LG2(0);
int mask = (1 << lg2) - 1;
size_t count = val.str_val.size >> lg2;
if ((val.str_val.size & mask) != 0) {
size_t count = val.size >> lg2;
if ((val.size & mask) != 0) {
decode_err(d); /* Length isn't a round multiple of elem size. */
}
decode_reserve(d, arr, count);
mem = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void);
arr->len += count;
memcpy(mem, val.str_val.data, val.str_val.size);
return ptr;
memcpy(mem, ptr, val.size); /* XXX: ptr boundary. */
return ptr + val.size;
}
case OP_VARPCK_LG2(0):
case OP_VARPCK_LG2(2):
@ -388,12 +462,11 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr,
/* Varint packed. */
int lg2 = op - OP_VARPCK_LG2(0);
int scale = 1 << lg2;
const char *ptr = val.str_val.data;
const char *end = ptr + val.str_val.size;
int saved_limit = decode_pushlimit(d, ptr, val.size);
char *out = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void);
while (ptr < end) {
while (!decode_isdone(d, &ptr)) {
wireval elem;
ptr = decode_varint64(d, ptr, end, &elem.uint64_val);
ptr = decode_varint64(d, ptr, &elem.uint64_val);
decode_munge(field->descriptortype, &elem);
if (decode_reserve(d, arr, 1)) {
out = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void);
@ -402,7 +475,7 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr,
memcpy(out, &elem, scale);
out += scale;
}
if (ptr != end) decode_err(d);
decode_poplimit(d, saved_limit);
return ptr;
}
default:
@ -410,7 +483,7 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr,
}
}
static void decode_tomap(upb_decstate *d, upb_msg *msg,
static const char *decode_tomap(upb_decstate *d, const char *ptr, upb_msg *msg,
const upb_msglayout *layout,
const upb_msglayout_field *field, wireval val) {
upb_map **map_p = UPB_PTR_AT(msg, field->offset, upb_map *);
@ -440,10 +513,9 @@ static void decode_tomap(upb_decstate *d, upb_msg *msg,
ent.v.val = upb_value_ptr(_upb_msg_new(entry->submsgs[0], &d->arena));
}
decode_tosubmsg(d, &ent.k, layout, field, val.str_val);
/* Insert into map. */
ptr = decode_tosubmsg(d, ptr, &ent.k, layout, field, val.size);
_upb_map_set(map, &ent.k, map->key_size, &ent.v, map->val_size, &d->arena);
return ptr;
}
static const char *decode_tomsg(upb_decstate *d, const char *ptr, upb_msg *msg,
@ -477,16 +549,15 @@ static const char *decode_tomsg(upb_decstate *d, const char *ptr, upb_msg *msg,
if (UPB_UNLIKELY(type == UPB_DTYPE_GROUP)) {
ptr = decode_togroup(d, ptr, submsg, layout, field);
} else {
decode_tosubmsg(d, submsg, layout, field, val.str_val);
ptr = decode_tosubmsg(d, ptr, submsg, layout, field, val.size);
}
break;
}
case OP_STRING:
decode_verifyutf8(d, val.str_val.data, val.str_val.size);
decode_verifyutf8(d, ptr, val.size);
/* Fallthrough. */
case OP_BYTES:
memcpy(mem, &val, sizeof(upb_strview));
break;
return decode_readstr(d, ptr, val.size, mem);
case OP_SCALAR_LG2(3):
memcpy(mem, &val, 8);
break;
@ -505,7 +576,7 @@ static const char *decode_tomsg(upb_decstate *d, const char *ptr, upb_msg *msg,
static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
const upb_msglayout *layout) {
while (ptr < d->limit) {
while (!decode_isdone(d, &ptr)) {
uint32_t tag;
const upb_msglayout_field *field;
int field_number;
@ -514,7 +585,7 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
wireval val;
int op;
ptr = decode_varint32(d, ptr, d->limit, &tag);
ptr = decode_varint32(d, ptr, &tag);
field_number = tag >> 3;
wire_type = tag & 7;
@ -522,12 +593,11 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
switch (wire_type) {
case UPB_WIRE_TYPE_VARINT:
ptr = decode_varint64(d, ptr, d->limit, &val.uint64_val);
ptr = decode_varint64(d, ptr, &val.uint64_val);
op = varint_ops[field->descriptortype];
decode_munge(field->descriptortype, &val);
break;
case UPB_WIRE_TYPE_32BIT:
if (d->limit - ptr < 4) decode_err(d);
memcpy(&val.uint32_val, ptr, 4);
val.uint32_val = _upb_be_swap32(val.uint32_val);
ptr += 4;
@ -535,7 +605,6 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
if (((1 << field->descriptortype) & fixed32_ok) == 0) goto unknown;
break;
case UPB_WIRE_TYPE_64BIT:
if (d->limit - ptr < 8) decode_err(d);
memcpy(&val.uint64_val, ptr, 8);
val.uint64_val = _upb_be_swap64(val.uint64_val);
ptr += 8;
@ -543,16 +612,12 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
if (((1 << field->descriptortype) & fixed64_ok) == 0) goto unknown;
break;
case UPB_WIRE_TYPE_DELIMITED: {
uint32_t size;
int ndx = field->descriptortype;
if (_upb_isrepeated(field)) ndx += 18;
ptr = decode_varint32(d, ptr, d->limit, &size);
if (size >= INT32_MAX || (size_t)(d->limit - ptr) < size) {
ptr = decode_varint32(d, ptr, &val.size);
if (val.size >= INT32_MAX || ptr - d->end + val.size > d->limit) {
decode_err(d); /* Length overflow. */
}
val.str_val.data = ptr;
val.str_val.size = size;
ptr += size;
op = delim_ops[ndx];
break;
}
@ -576,7 +641,7 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
ptr = decode_toarray(d, ptr, msg, layout, field, val, op);
break;
case _UPB_LABEL_MAP:
decode_tomap(d, msg, layout, field, val);
ptr = decode_tomap(d, ptr, msg, layout, field, val);
break;
default:
ptr = decode_tomsg(d, ptr, msg, layout, field, val, op);
@ -590,6 +655,7 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
ptr = decode_group(d, ptr, NULL, NULL, field_number);
}
if (msg) {
if (wire_type == UPB_WIRE_TYPE_DELIMITED) ptr += val.size;
if (!_upb_msg_addunknown(msg, field_start, ptr - field_start,
&d->arena)) {
decode_err(d);
@ -598,7 +664,6 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg,
}
}
if (ptr != d->limit) decode_err(d);
return ptr;
}
@ -607,9 +672,22 @@ bool upb_decode(const char *buf, size_t size, void *msg, const upb_msglayout *l,
bool ok;
upb_decstate state;
if (size == 0) return true;
if (size == 0) {
return true;
} else if (size < 16) {
memset(&state.patch, 0, 32);
memcpy(&state.patch, buf, size);
buf = state.patch;
state.end = buf + size;
state.limit = 0;
state.alias = false;
} else {
state.end = buf + size - 16;
state.limit = 16;
state.alias = true;
}
state.limit = buf + size;
state.limit_ptr = state.end;
state.depth = 64;
state.end_group = 0;
state.arena.head = arena->head;

Loading…
Cancel
Save