diff --git a/upb/decode.c b/upb/decode.c index 32dda2f3c4..347f964a91 100644 --- a/upb/decode.c +++ b/upb/decode.c @@ -139,10 +139,14 @@ static const int8_t delim_ops[37] = { /* Data pertaining to the parse. */ typedef struct { - const char *limit; /* End of delimited region or end of buffer. */ - upb_arena arena; + const char *end; /* Can read up to 16 bytes slop beyond this. */ + const char *limit_ptr; /* = end + UPB_MIN(limit, 0) */ + int limit; /* Submessage limit relative to end. */ int depth; uint32_t end_group; /* Set to field number of END_GROUP tag, if any. */ + bool alias; + char patch[32]; + upb_arena arena; jmp_buf err; } upb_decstate; @@ -150,7 +154,7 @@ typedef union { bool bool_val; uint32_t uint32_val; uint64_t uint64_val; - upb_strview str_val; + uint32_t size; } wireval; static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg, @@ -200,41 +204,48 @@ static bool decode_reserve(upb_decstate *d, upb_array *arr, size_t elem) { return need_realloc; } +typedef struct { + const char *ptr; + uint64_t val; +} decode_vret; + UPB_NOINLINE -static const char *decode_longvarint64(upb_decstate *d, const char *ptr, - const char *limit, uint64_t *val) { - uint8_t byte; - int bitpos = 0; - uint64_t out = 0; - - do { - if (bitpos >= 70 || ptr == limit) decode_err(d); - byte = *ptr; - out |= (uint64_t)(byte & 0x7F) << bitpos; - ptr++; - bitpos += 7; - } while (byte & 0x80); - - *val = out; - return ptr; +static decode_vret decode_longvarint64(const char *ptr, uint64_t val) { + decode_vret ret = {NULL, 0}; + uint64_t byte; + int i; + for (i = 1; i < 10; i++) { + byte = (uint8_t)ptr[i]; + val += (byte - 1) << (i * 7); + if (!(byte & 0x80)) { + ret.ptr = ptr + i + 1; + ret.val = val; + return ret; + } + } + return ret; } UPB_FORCEINLINE static const char *decode_varint64(upb_decstate *d, const char *ptr, - const char *limit, uint64_t *val) { - if (UPB_LIKELY(ptr < limit && (*ptr & 0x80) == 0)) { - *val = (uint8_t)*ptr; + uint64_t *val) { + uint64_t byte = (uint8_t)*ptr; + if (UPB_LIKELY((byte & 0x80) == 0)) { + *val = byte; return ptr + 1; } else { - return decode_longvarint64(d, ptr, limit, val); + decode_vret res = decode_longvarint64(ptr, byte); + if (!res.ptr) decode_err(d); + *val = res.val; + return res.ptr; } } UPB_FORCEINLINE static const char *decode_varint32(upb_decstate *d, const char *ptr, - const char *limit, uint32_t *val) { + uint32_t *val) { uint64_t u64; - ptr = decode_varint64(d, ptr, limit, &u64); + ptr = decode_varint64(d, ptr, &u64); if (u64 > UINT32_MAX) decode_err(d); *val = (uint32_t)u64; return ptr; @@ -287,17 +298,82 @@ static upb_msg *decode_newsubmsg(upb_decstate *d, const upb_msglayout *layout, return _upb_msg_new_inl(subl, &d->arena); } -static void decode_tosubmsg(upb_decstate *d, upb_msg *submsg, - const upb_msglayout *layout, - const upb_msglayout_field *field, upb_strview val) { +static int decode_pushlimit(upb_decstate *d, const char *ptr, int size) { + int limit = size + (int)(ptr - d->end); + int delta = d->limit - limit; + d->limit = limit; + d->limit_ptr = d->end + UPB_MIN(0, limit); + return delta; +} + +static void decode_poplimit(upb_decstate *d, int saved_delta) { + d->limit += saved_delta; + d->limit_ptr = d->end + UPB_MIN(0, d->limit); +} + +typedef struct { + bool ok; + const char *ptr; +} decode_doneret; + +UPB_NOINLINE +static const char *decode_isdonefallback(upb_decstate *d, const char *ptr, + int overrun) { + if (overrun < d->limit) { + /* Need to copy remaining data into patch buffer. */ + UPB_ASSERT(overrun < 16); + memset(d->patch + 16, 0, 16); + memcpy(d->patch, d->end, 16); + ptr = &d->patch[0] + overrun; + d->end = &d->patch[16]; + d->limit -= 16; + d->limit_ptr = d->end + d->limit; + d->alias = false; + UPB_ASSERT(ptr < d->limit_ptr); + return ptr; + } else { + decode_err(d); + } +} + +UPB_FORCEINLINE +static bool decode_isdone(upb_decstate *d, const char **ptr) { + int overrun = *ptr - d->end; + if (UPB_LIKELY(*ptr < d->limit_ptr)) { + return false; + } else if (UPB_LIKELY(overrun == d->limit)) { + return true; + } else { + *ptr = decode_isdonefallback(d, *ptr, overrun); + return false; + } +} + +static const char *decode_readstr(upb_decstate *d, const char *ptr, int size, + upb_strview *str) { + if (d->alias) { + str->data = ptr; + } else { + char *data = upb_arena_malloc(&d->arena, size); + if (!data) decode_err(d); + memcpy(data, ptr, size); + str->data = data; + } + str->size = size; + return ptr + size; +} + +static const char *decode_tosubmsg(upb_decstate *d, const char *ptr, + upb_msg *submsg, const upb_msglayout *layout, + const upb_msglayout_field *field, int size) { const upb_msglayout *subl = layout->submsgs[field->submsg_index]; - const char *saved_limit = d->limit; + int saved_delta = decode_pushlimit(d, ptr, size); if (--d->depth < 0) decode_err(d); - d->limit = val.data + val.size; - decode_msg(d, val.data, submsg, subl); - d->limit = saved_limit; + ptr = decode_msg(d, ptr, submsg, subl); + decode_poplimit(d, saved_delta); if (d->end_group != 0) decode_err(d); d->depth++; + return ptr; } static const char *decode_group(upb_decstate *d, const char *ptr, @@ -345,15 +421,14 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr, memcpy(mem, &val, 1 << op); return ptr; case OP_STRING: - decode_verifyutf8(d, val.str_val.data, val.str_val.size); + decode_verifyutf8(d, ptr, val.size); /* Fallthrough. */ - case OP_BYTES: + case OP_BYTES: { /* Append bytes. */ - mem = - UPB_PTR_AT(_upb_array_ptr(arr), arr->len * sizeof(upb_strview), void); + upb_strview *str = (upb_strview*)_upb_array_ptr(arr) + arr->len; arr->len++; - memcpy(mem, &val, sizeof(upb_strview)); - return ptr; + return decode_readstr(d, ptr, val.size, str); + } case OP_SUBMSG: { /* Append submessage / group. */ upb_msg *submsg = decode_newsubmsg(d, layout, field); @@ -361,26 +436,25 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr, submsg; arr->len++; if (UPB_UNLIKELY(field->descriptortype == UPB_DTYPE_GROUP)) { - ptr = decode_togroup(d, ptr, submsg, layout, field); + return decode_togroup(d, ptr, submsg, layout, field); } else { - decode_tosubmsg(d, submsg, layout, field, val.str_val); + return decode_tosubmsg(d, ptr, submsg, layout, field, val.size); } - return ptr; } case OP_FIXPCK_LG2(2): case OP_FIXPCK_LG2(3): { /* Fixed packed. */ int lg2 = op - OP_FIXPCK_LG2(0); int mask = (1 << lg2) - 1; - size_t count = val.str_val.size >> lg2; - if ((val.str_val.size & mask) != 0) { + size_t count = val.size >> lg2; + if ((val.size & mask) != 0) { decode_err(d); /* Length isn't a round multiple of elem size. */ } decode_reserve(d, arr, count); mem = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void); arr->len += count; - memcpy(mem, val.str_val.data, val.str_val.size); - return ptr; + memcpy(mem, ptr, val.size); /* XXX: ptr boundary. */ + return ptr + val.size; } case OP_VARPCK_LG2(0): case OP_VARPCK_LG2(2): @@ -388,12 +462,11 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr, /* Varint packed. */ int lg2 = op - OP_VARPCK_LG2(0); int scale = 1 << lg2; - const char *ptr = val.str_val.data; - const char *end = ptr + val.str_val.size; + int saved_limit = decode_pushlimit(d, ptr, val.size); char *out = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void); - while (ptr < end) { + while (!decode_isdone(d, &ptr)) { wireval elem; - ptr = decode_varint64(d, ptr, end, &elem.uint64_val); + ptr = decode_varint64(d, ptr, &elem.uint64_val); decode_munge(field->descriptortype, &elem); if (decode_reserve(d, arr, 1)) { out = UPB_PTR_AT(_upb_array_ptr(arr), arr->len << lg2, void); @@ -402,7 +475,7 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr, memcpy(out, &elem, scale); out += scale; } - if (ptr != end) decode_err(d); + decode_poplimit(d, saved_limit); return ptr; } default: @@ -410,9 +483,9 @@ static const char *decode_toarray(upb_decstate *d, const char *ptr, } } -static void decode_tomap(upb_decstate *d, upb_msg *msg, - const upb_msglayout *layout, - const upb_msglayout_field *field, wireval val) { +static const char *decode_tomap(upb_decstate *d, const char *ptr, upb_msg *msg, + const upb_msglayout *layout, + const upb_msglayout_field *field, wireval val) { upb_map **map_p = UPB_PTR_AT(msg, field->offset, upb_map *); upb_map *map = *map_p; upb_map_entry ent; @@ -440,10 +513,9 @@ static void decode_tomap(upb_decstate *d, upb_msg *msg, ent.v.val = upb_value_ptr(_upb_msg_new(entry->submsgs[0], &d->arena)); } - decode_tosubmsg(d, &ent.k, layout, field, val.str_val); - - /* Insert into map. */ + ptr = decode_tosubmsg(d, ptr, &ent.k, layout, field, val.size); _upb_map_set(map, &ent.k, map->key_size, &ent.v, map->val_size, &d->arena); + return ptr; } static const char *decode_tomsg(upb_decstate *d, const char *ptr, upb_msg *msg, @@ -477,16 +549,15 @@ static const char *decode_tomsg(upb_decstate *d, const char *ptr, upb_msg *msg, if (UPB_UNLIKELY(type == UPB_DTYPE_GROUP)) { ptr = decode_togroup(d, ptr, submsg, layout, field); } else { - decode_tosubmsg(d, submsg, layout, field, val.str_val); + ptr = decode_tosubmsg(d, ptr, submsg, layout, field, val.size); } break; } case OP_STRING: - decode_verifyutf8(d, val.str_val.data, val.str_val.size); + decode_verifyutf8(d, ptr, val.size); /* Fallthrough. */ case OP_BYTES: - memcpy(mem, &val, sizeof(upb_strview)); - break; + return decode_readstr(d, ptr, val.size, mem); case OP_SCALAR_LG2(3): memcpy(mem, &val, 8); break; @@ -505,7 +576,7 @@ static const char *decode_tomsg(upb_decstate *d, const char *ptr, upb_msg *msg, static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg, const upb_msglayout *layout) { - while (ptr < d->limit) { + while (!decode_isdone(d, &ptr)) { uint32_t tag; const upb_msglayout_field *field; int field_number; @@ -514,7 +585,7 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg, wireval val; int op; - ptr = decode_varint32(d, ptr, d->limit, &tag); + ptr = decode_varint32(d, ptr, &tag); field_number = tag >> 3; wire_type = tag & 7; @@ -522,12 +593,11 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg, switch (wire_type) { case UPB_WIRE_TYPE_VARINT: - ptr = decode_varint64(d, ptr, d->limit, &val.uint64_val); + ptr = decode_varint64(d, ptr, &val.uint64_val); op = varint_ops[field->descriptortype]; decode_munge(field->descriptortype, &val); break; case UPB_WIRE_TYPE_32BIT: - if (d->limit - ptr < 4) decode_err(d); memcpy(&val.uint32_val, ptr, 4); val.uint32_val = _upb_be_swap32(val.uint32_val); ptr += 4; @@ -535,7 +605,6 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg, if (((1 << field->descriptortype) & fixed32_ok) == 0) goto unknown; break; case UPB_WIRE_TYPE_64BIT: - if (d->limit - ptr < 8) decode_err(d); memcpy(&val.uint64_val, ptr, 8); val.uint64_val = _upb_be_swap64(val.uint64_val); ptr += 8; @@ -543,16 +612,12 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg, if (((1 << field->descriptortype) & fixed64_ok) == 0) goto unknown; break; case UPB_WIRE_TYPE_DELIMITED: { - uint32_t size; int ndx = field->descriptortype; if (_upb_isrepeated(field)) ndx += 18; - ptr = decode_varint32(d, ptr, d->limit, &size); - if (size >= INT32_MAX || (size_t)(d->limit - ptr) < size) { + ptr = decode_varint32(d, ptr, &val.size); + if (val.size >= INT32_MAX || ptr - d->end + val.size > d->limit) { decode_err(d); /* Length overflow. */ } - val.str_val.data = ptr; - val.str_val.size = size; - ptr += size; op = delim_ops[ndx]; break; } @@ -576,7 +641,7 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg, ptr = decode_toarray(d, ptr, msg, layout, field, val, op); break; case _UPB_LABEL_MAP: - decode_tomap(d, msg, layout, field, val); + ptr = decode_tomap(d, ptr, msg, layout, field, val); break; default: ptr = decode_tomsg(d, ptr, msg, layout, field, val, op); @@ -590,6 +655,7 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg, ptr = decode_group(d, ptr, NULL, NULL, field_number); } if (msg) { + if (wire_type == UPB_WIRE_TYPE_DELIMITED) ptr += val.size; if (!_upb_msg_addunknown(msg, field_start, ptr - field_start, &d->arena)) { decode_err(d); @@ -598,7 +664,6 @@ static const char *decode_msg(upb_decstate *d, const char *ptr, upb_msg *msg, } } - if (ptr != d->limit) decode_err(d); return ptr; } @@ -607,9 +672,22 @@ bool upb_decode(const char *buf, size_t size, void *msg, const upb_msglayout *l, bool ok; upb_decstate state; - if (size == 0) return true; + if (size == 0) { + return true; + } else if (size < 16) { + memset(&state.patch, 0, 32); + memcpy(&state.patch, buf, size); + buf = state.patch; + state.end = buf + size; + state.limit = 0; + state.alias = false; + } else { + state.end = buf + size - 16; + state.limit = 16; + state.alias = true; + } - state.limit = buf + size; + state.limit_ptr = state.end; state.depth = 64; state.end_group = 0; state.arena.head = arena->head;