From 382f92a87fb9fb87912e0af2fc8fda702a3ea0b6 Mon Sep 17 00:00:00 2001 From: Joshua Haberman Date: Tue, 10 Dec 2019 16:36:11 -0800 Subject: [PATCH] Maps encode and decode successfully! --- tests/bindings/lua/test_upb.lua | 334 ++++++-------------------------- upb/decode.c | 68 ++++++- upb/def.c | 7 +- upb/encode.c | 9 +- upb/msg.c | 17 -- upb/msg.h | 17 +- upb/reflection.c | 1 - upb/reflection.h | 6 +- upb/upb.h | 2 - upbc/message_layout.cc | 11 +- 10 files changed, 155 insertions(+), 317 deletions(-) diff --git a/tests/bindings/lua/test_upb.lua b/tests/bindings/lua/test_upb.lua index 2e404d0149..07f7ed0f15 100644 --- a/tests/bindings/lua/test_upb.lua +++ b/tests/bindings/lua/test_upb.lua @@ -42,285 +42,6 @@ function test_def_readers() assert_equal(2, e:value("BAZ")) end ---[[ - - -function test_enumdef() -function test_iteration() - -- Test that we cannot crash the process even if we modify the set of fields - -- during iteration. - local md = upb.MessageDef{full_name = "TestMessage"} - - for i=1,10 do - md:add(upb.FieldDef{ - name = "field" .. tostring(i), - number = 1000 - i, - type = upb.TYPE_INT32 - }) - end - - local add = #md - for f in md:fields() do - if add > 0 then - add = add - 1 - for i=10000,11000 do - local field_name = "field" .. tostring(i) - -- We want to add fields to the table to trigger a table resize, - -- but we must skip it if the field name or number already exists - -- otherwise it will raise an error. - if md:field(field_name) == nil and - md:field(i) == nil then - md:add(upb.FieldDef{ - name = field_name, - number = i, - type = upb.TYPE_INT32 - }) - end - end - end - end - - -- Test that iterators don't crash the process even if the MessageDef goes - -- out of scope. - -- - -- Note: have previously verified that this can indeed crash the process if - -- we do not explicitly add a reference from the iterator to the underlying - -- MessageDef. - local iter = md:fields() - md = nil - collectgarbage() - while iter() do - end - - local ed = upb.EnumDef{ - values = { - {"FOO", 1}, - {"BAR", 77}, - } - } - iter = ed:values() - ed = nil - collectgarbage() - while iter() do - end -end - -function test_msgdef_setters() - local md = upb.MessageDef() - md:set_full_name("Message1") - assert_equal("Message1", md:full_name()) - local f = upb.FieldDef{name = "field1", number = 3, type = upb.TYPE_DOUBLE} - md:add(f) - assert_equal(1, #md) - assert_equal(f, md:field("field1")) -end - -function test_msgdef_errors() - assert_error(function() upb.MessageDef{bad_initializer_key = 5} end) - local md = upb.MessageDef() - assert_error(function() - -- Duplicate field number. - upb.MessageDef{ - fields = { - upb.FieldDef{name = "field1", number = 1, type = upb.TYPE_INT32}, - upb.FieldDef{name = "field2", number = 1, type = upb.TYPE_INT32} - } - } - end) - assert_error(function() - -- Duplicate field name. - upb.MessageDef{ - fields = { - upb.FieldDef{name = "field1", number = 1, type = upb.TYPE_INT32}, - upb.FieldDef{name = "field1", number = 2, type = upb.TYPE_INT32} - } - } - end) - - assert_error(function() - -- Duplicate field name. - upb.MessageDef{ - fields = { - upb.OneofDef{name = "field1", fields = { - upb.FieldDef{name = "field2", number = 1, type = upb.TYPE_INT32}, - }}, - upb.FieldDef{name = "field2", number = 2, type = upb.TYPE_INT32} - } - } - end) - - -- attempt to set a name with embedded NULLs. - assert_error_match("names cannot have embedded NULLs", function() - md:set_full_name("abc\0def") - end) - - upb.freeze(md) - -- Attempt to mutate frozen MessageDef. - assert_error_match("frozen", function() - md:add(upb.FieldDef{name = "field1", number = 1, type = upb.TYPE_INT32}) - end) - assert_error_match("frozen", function() - md:set_full_name("abc") - end) - - -- Attempt to freeze a msgdef without freezing its subdef. - assert_error_match("is not frozen or being frozen", function() - m1 = upb.MessageDef() - upb.freeze( - upb.MessageDef{ - fields = { - upb.FieldDef{name = "f1", number = 1, type = upb.TYPE_MESSAGE, - subdef = m1} - } - } - ) - end) -end - -function test_symtab() - local empty = upb.SymbolTable() - assert_equal(0, #iter_to_array(empty:defs(upb.DEF_ANY))) - assert_equal(0, #iter_to_array(empty:defs(upb.DEF_MSG))) - assert_equal(0, #iter_to_array(empty:defs(upb.DEF_ENUM))) - - local symtab = upb.SymbolTable{ - upb.MessageDef{full_name = "TestMessage"}, - upb.MessageDef{full_name = "ContainingMessage", fields = { - upb.FieldDef{name = "field1", number = 1, type = upb.TYPE_INT32}, - upb.FieldDef{name = "field2", number = 2, type = upb.TYPE_MESSAGE, - subdef_name = ".TestMessage"} - } - } - } - - local msgdef1 = symtab:lookup("TestMessage") - local msgdef2 = symtab:lookup("ContainingMessage") - assert_not_nil(msgdef1) - assert_not_nil(msgdef2) - assert_equal(msgdef1, msgdef2:field("field2"):subdef()) - assert_true(msgdef1:is_frozen()) - assert_true(msgdef2:is_frozen()) - - symtab:add{ - upb.MessageDef{full_name = "ContainingMessage2", fields = { - upb.FieldDef{name = "field5", number = 5, type = upb.TYPE_MESSAGE, - subdef = msgdef2} - } - } - } - - local msgdef3 = symtab:lookup("ContainingMessage2") - assert_not_nil(msgdef3) - assert_equal(msgdef3:field("field5"):subdef(), msgdef2) -end - -function test_msg_primitives() - local function test_for_numeric_type(upb_type, val, too_big, too_small, bad3) - local symtab = upb.SymbolTable{ - upb.MessageDef{full_name = "TestMessage", fields = { - upb.FieldDef{name = "f", number = 1, type = upb_type}, - } - } - } - - factory = upb.MessageFactory(symtab) - TestMessage = factory:get_message_class("TestMessage") - msg = TestMessage() - - -- Defaults to zero - assert_equal(0, msg.f) - - msg.f = 0 - assert_equal(0, msg.f) - - msg.f = val - assert_equal(val, msg.f) - - local errmsg = "not an integer or out of range" - if too_small then - assert_error_match(errmsg, function() msg.f = too_small end) - end - if too_big then - assert_error_match(errmsg, function() msg.f = too_big end) - end - if bad3 then - assert_error_match(errmsg, function() msg.f = bad3 end) - end - - -- Can't assign other Lua types. - errmsg = "bad argument #3" - assert_error_match(errmsg, function() msg.f = "abc" end) - assert_error_match(errmsg, function() msg.f = true end) - assert_error_match(errmsg, function() msg.f = false end) - assert_error_match(errmsg, function() msg.f = nil end) - assert_error_match(errmsg, function() msg.f = {} end) - assert_error_match(errmsg, function() msg.f = print end) - assert_error_match(errmsg, function() msg.f = array end) - end - - local symtab = upb.SymbolTable{ - upb.MessageDef{full_name = "TestMessage", fields = { - upb.FieldDef{ - name = "i32", number = 1, type = upb.TYPE_INT32, default = 1}, - upb.FieldDef{ - name = "u32", number = 2, type = upb.TYPE_UINT32, default = 2}, - upb.FieldDef{ - name = "i64", number = 3, type = upb.TYPE_INT64, default = 3}, - upb.FieldDef{ - name = "u64", number = 4, type = upb.TYPE_UINT64, default = 4}, - upb.FieldDef{ - name = "dbl", number = 5, type = upb.TYPE_DOUBLE, default = 5}, - upb.FieldDef{ - name = "flt", number = 6, type = upb.TYPE_FLOAT, default = 6}, - upb.FieldDef{ - name = "bool", number = 7, type = upb.TYPE_BOOL, default = true}, - } - } - } - - factory = upb.MessageFactory(symtab) - TestMessage = factory:get_message_class("TestMessage") - msg = TestMessage() - - -- Unset member returns default value. - -- TODO(haberman): re-enable these when we have descriptor-based reflection. - -- assert_equal(1, msg.i32) - -- assert_equal(2, msg.u32) - -- assert_equal(3, msg.i64) - -- assert_equal(4, msg.u64) - -- assert_equal(5, msg.dbl) - -- assert_equal(6, msg.flt) - -- assert_equal(true, msg.bool) - - -- Attempts to access non-existent fields fail. - assert_error_match("no such field", function() msg.no_such = 1 end) - - msg.i32 = 10 - msg.u32 = 20 - msg.i64 = 30 - msg.u64 = 40 - msg.dbl = 50 - msg.flt = 60 - msg.bool = true - - assert_equal(10, msg.i32) - assert_equal(20, msg.u32) - assert_equal(30, msg.i64) - assert_equal(40, msg.u64) - assert_equal(50, msg.dbl) - assert_equal(60, msg.flt) - assert_equal(true, msg.bool) - - test_for_numeric_type(upb.TYPE_UINT32, 2^32 - 1, 2^32, -1, 5.1) - test_for_numeric_type(upb.TYPE_UINT64, 2^62, 2^64, -1, bad64) - test_for_numeric_type(upb.TYPE_INT32, 2^31 - 1, 2^31, -2^31 - 1, 5.1) - test_for_numeric_type(upb.TYPE_INT64, 2^61, 2^63, -2^64, bad64) - test_for_numeric_type(upb.TYPE_FLOAT, 2^20) - test_for_numeric_type(upb.TYPE_DOUBLE, 10^101) -end - -==]] - function test_msg_map() msg = test_messages_proto3.TestAllTypesProto3() msg.map_int32_int32[5] = 10 @@ -335,6 +56,21 @@ function test_msg_map() assert_equal(12, msg2.map_int32_int32[6]) end +function test_msg_string_map() + msg = test_messages_proto3.TestAllTypesProto3() + msg.map_string_string["foo"] = "bar" + msg.map_string_string["baz"] = "quux" + assert_nil(msg.map_string_string["abc"]) + assert_equal("bar", msg.map_string_string["foo"]) + assert_equal("quux", msg.map_string_string["baz"]) + + local serialized = upb.encode(msg) + assert_true(#serialized > 0) + local msg2 = upb.decode(test_messages_proto3.TestAllTypesProto3, serialized) + assert_equal("bar", msg2.map_string_string["foo"]) + assert_equal("quux", msg2.map_string_string["baz"]) +end + function test_msg_array() msg = test_messages_proto3.TestAllTypesProto3() @@ -472,6 +208,46 @@ local numeric_types = { }, } +function test_msg_primitives() + local msg = test_messages_proto3.TestAllTypesProto3{ + optional_int32 = 10, + optional_uint32 = 20, + optional_int64 = 30, + optional_uint64 = 40, + optional_double = 50, + optional_float = 60, + optional_sint32 = 70, + optional_sint64 = 80, + optional_fixed32 = 90, + optional_fixed64 = 100, + optional_sfixed32 = 110, + optional_sfixed64 = 120, + optional_bool = true, + optional_string = "abc", + optional_nested_message = test_messages_proto3['TestAllTypesProto3.NestedMessage']{a = 123}, + } + + -- Attempts to access non-existent fields fail. + assert_error_match("no such field", function() msg.no_such = 1 end) + + assert_equal(10, msg.optional_int32) + assert_equal(20, msg.optional_uint32) + assert_equal(30, msg.optional_int64) + assert_equal(40, msg.optional_uint64) + assert_equal(50, msg.optional_double) + assert_equal(60, msg.optional_float) + assert_equal(70, msg.optional_sint32) + assert_equal(80, msg.optional_sint64) + assert_equal(90, msg.optional_fixed32) + assert_equal(100, msg.optional_fixed64) + assert_equal(110, msg.optional_sfixed32) + assert_equal(120, msg.optional_sfixed64) + assert_equal(true, msg.optional_bool) + assert_equal("abc", msg.optional_string) + assert_equal(123, msg.optional_nested_message.a) +end + + function test_string_array() local function test_for_string_type(upb_type) local array = upb.Array(upb_type) diff --git a/upb/decode.c b/upb/decode.c index 10342c85ff..0ddc29c5d2 100644 --- a/upb/decode.c +++ b/upb/decode.c @@ -6,7 +6,7 @@ #include "upb/port_def.inc" /* Maps descriptor type -> upb field type. */ -const uint8_t upb_desctype_to_fieldtype[] = { +const uint8_t desctype_to_fieldtype[] = { UPB_WIRE_TYPE_END_GROUP, /* ENDGROUP */ UPB_TYPE_DOUBLE, /* DOUBLE */ UPB_TYPE_FLOAT, /* FLOAT */ @@ -198,7 +198,7 @@ static upb_array *upb_getorcreatearr(upb_decframe *frame, upb_array *arr = upb_getarr(frame, field); if (!arr) { - upb_fieldtype_t type = upb_desctype_to_fieldtype[field->descriptortype]; + upb_fieldtype_t type = desctype_to_fieldtype[field->descriptortype]; arr = upb_array_new(frame->state->arena, type); CHK(arr); *(upb_array**)&frame->msg[field->offset] = arr; @@ -463,6 +463,63 @@ static bool upb_decode_toarray(upb_decstate *d, upb_decframe *frame, UPB_UNREACHABLE(); } +static bool upb_decode_mapfield(upb_decstate *d, upb_decframe *frame, + const upb_msglayout_field *field, int len) { + /* Max map entry size is string key/val. */ + size_t size = sizeof(upb_msg_internal) + (sizeof(upb_strview) * 2); + char submsg[size]; + char *submsg_ptr = &submsg[sizeof(upb_msg_internal)]; + upb_map *map = *(upb_map**)&frame->msg[field->offset]; + upb_alloc *alloc = upb_arena_alloc(d->arena); + const upb_msglayout *entry = frame->layout->submsgs[field->submsg_index]; + upb_value val; + const char *key; + size_t key_size; + + if (!map) { + /* Lazily create map. */ + const upb_msglayout_field *key_field = &entry->fields[0]; + const upb_msglayout_field *val_field = &entry->fields[1]; + upb_fieldtype_t key_type = desctype_to_fieldtype[key_field->descriptortype]; + upb_fieldtype_t val_type = desctype_to_fieldtype[val_field->descriptortype]; + UPB_ASSERT(key_field->number == 1); + UPB_ASSERT(val_field->number == 2); + UPB_ASSERT(key_field->offset == 0); + UPB_ASSERT(val_field->offset == sizeof(upb_strview)); + map = upb_map_new(frame->state->arena, key_type, val_type); + *(upb_map**)&frame->msg[field->offset] = map; + } + + /* Parse map entry. */ + memset(&submsg, 0, size); + CHK(upb_decode_msgfield(d, submsg_ptr, entry, len)); + + /* Insert into map. */ + if (map->key_size_lg2 == UPB_MAPTYPE_STRING) { + const upb_strview* key_view = (const upb_strview*)submsg_ptr; + key = key_view->data; + key_size = key_view->size; + } else { + key = submsg_ptr; + key_size = 1 << map->key_size_lg2; + } + + if (map->val_size_lg2 == UPB_MAPTYPE_STRING) { + upb_strview* val_view = upb_arena_malloc(d->arena, sizeof(*val_view)); + CHK(val_view); + memcpy(val_view, submsg_ptr + sizeof(upb_strview), sizeof(upb_strview)); + memset(&val, 0, sizeof(val)); + memcpy(&val, &val_view, sizeof(void*)); + } else { + memcpy(&val, submsg_ptr + sizeof(upb_strview), 8); + } + + if (!upb_strtable_lookup2(&map->table, key, key_size, NULL)) { + upb_strtable_insert3(&map->table, key, key_size, val, alloc); + } + return true; +} + static bool upb_decode_delimitedfield(upb_decstate *d, upb_decframe *frame, const upb_msglayout_field *field) { int len; @@ -472,12 +529,7 @@ static bool upb_decode_delimitedfield(upb_decstate *d, upb_decframe *frame, if (field->label == UPB_LABEL_REPEATED) { return upb_decode_toarray(d, frame, field, len); } else if (field->label == UPB_LABEL_MAP) { - /* Max map entry size is string key/val. */ - char submsg[sizeof(upb_strview) * 2]; - const upb_msglayout *layout = frame->layout->submsgs[field->submsg_index]; - CHK(upb_decode_msgfield(d, &submsg, layout, len)); - /* TODO: insert into map. */ - return true; + return upb_decode_mapfield(d, frame, field, len); } else { switch (field->descriptortype) { case UPB_DESCRIPTOR_TYPE_STRING: diff --git a/upb/def.c b/upb/def.c index 559179bf97..a4a2200e3a 100644 --- a/upb/def.c +++ b/upb/def.c @@ -868,7 +868,12 @@ static size_t upb_msgval_sizeof(upb_fieldtype_t type) { } static uint8_t upb_msg_fielddefsize(const upb_fielddef *f) { - if (upb_fielddef_isseq(f)) { + if (upb_msgdef_mapentry(upb_fielddef_containingtype(f))) { + // Map entries aren't actually stored, they are only used during parsing. + // For parsing, it helps a lot if all map entry messages have the same + // layout. + return sizeof(upb_strview); + } else if (upb_fielddef_isseq(f)) { return sizeof(void*); } else { return upb_msgval_sizeof(upb_fielddef_type(f)); diff --git a/upb/encode.c b/upb/encode.c index f61427b47d..9314c1d500 100644 --- a/upb/encode.c +++ b/upb/encode.c @@ -330,10 +330,13 @@ static bool upb_encode_map(upb_encstate *e, const char *field_mem, size_t size; upb_strview key = upb_strtable_iter_key(&i); const upb_value val = upb_strtable_iter_value(&i); + const void* keyp = + map->key_size_lg2 == UPB_MAPTYPE_STRING ? (void*)&key : key.data; + const void* valp = + map->val_size_lg2 == UPB_MAPTYPE_STRING ? upb_value_getptr(val) : &val; - /* XXX; string key/value */ - CHK(upb_encode_scalarfield(e, &val, entry, val_field, false)); - CHK(upb_encode_scalarfield(e, key.data, entry, key_field, false)); + CHK(upb_encode_scalarfield(e, valp, entry, val_field, false)); + CHK(upb_encode_scalarfield(e, keyp, entry, key_field, false)); size = (e->limit - e->ptr) - pre_len; CHK(upb_put_varint(e, size)); CHK(upb_put_tag(e, f->number, UPB_WIRE_TYPE_DELIMITED)); diff --git a/upb/msg.c b/upb/msg.c index cd33e0cca6..b27fa41a92 100644 --- a/upb/msg.c +++ b/upb/msg.c @@ -9,23 +9,6 @@ /** upb_msg *******************************************************************/ -/* Internal members of a upb_msg. We can change this without breaking binary - * compatibility. We put these before the user's data. The user's upb_msg* - * points after the upb_msg_internal. */ - -/* Used when a message is not extendable. */ -typedef struct { - char *unknown; - size_t unknown_len; - size_t unknown_size; -} upb_msg_internal; - -/* Used when a message is extendable. */ -typedef struct { - upb_inttable *extdict; - upb_msg_internal base; -} upb_msg_internal_withext; - static char _upb_fieldtype_to_sizelg2[12] = { 0, 0, /* UPB_TYPE_BOOL */ diff --git a/upb/msg.h b/upb/msg.h index 71bbac67c3..6b6dafa19e 100644 --- a/upb/msg.h +++ b/upb/msg.h @@ -55,7 +55,22 @@ typedef struct upb_msglayout { /** upb_msg *******************************************************************/ -/* Representation is in msg.c for now. */ +/* Internal members of a upb_msg. We can change this without breaking binary + * compatibility. We put these before the user's data. The user's upb_msg* + * points after the upb_msg_internal. */ + +/* Used when a message is not extendable. */ +typedef struct { + char *unknown; + size_t unknown_len; + size_t unknown_size; +} upb_msg_internal; + +/* Used when a message is extendable. */ +typedef struct { + upb_inttable *extdict; + upb_msg_internal base; +} upb_msg_internal_withext; /* Maps upb_fieldtype_t -> memory size. */ extern char _upb_fieldtype_to_size[12]; diff --git a/upb/reflection.c b/upb/reflection.c index 756fd7a07a..f1e20e59d4 100644 --- a/upb/reflection.c +++ b/upb/reflection.c @@ -223,7 +223,6 @@ bool upb_map_get(const upb_map *map, upb_msgval key, upb_msgval *val) { ret = upb_strtable_lookup2(&map->table, strkey.data, strkey.size, &tabval); if (ret) { *val = upb_map_fromvalue(map->val_size_lg2, tabval); - memcpy(val, &tabval, sizeof(tabval)); } return ret; diff --git a/upb/reflection.h b/upb/reflection.h index d6825dbc82..58fa7ee7a2 100644 --- a/upb/reflection.h +++ b/upb/reflection.h @@ -1,6 +1,6 @@ -#ifndef UPB_LEGACY_MSG_REFLECTION_H_ -#define UPB_LEGACY_MSG_REFLECTION_H_ +#ifndef UPB_REFLECTION_H_ +#define UPB_REFLECTION_H_ #include "upb/def.h" #include "upb/msg.h" @@ -134,4 +134,4 @@ bool upb_mapiter_isequal(const upb_mapiter *i1, const upb_mapiter *i2); #include "upb/port_undef.inc" -#endif /* UPB_LEGACY_MSG_REFLECTION_H_ */ +#endif /* UPB_REFLECTION_H_ */ diff --git a/upb/upb.h b/upb/upb.h index e4c8adea67..aec97e0611 100644 --- a/upb/upb.h +++ b/upb/upb.h @@ -358,8 +358,6 @@ typedef enum { UPB_DESCRIPTOR_TYPE_SINT64 = 18 } upb_descriptortype_t; -extern const uint8_t upb_desctype_to_fieldtype[]; - #include "upb/port_undef.inc" #endif /* UPB_H_ */ diff --git a/upbc/message_layout.cc b/upbc/message_layout.cc index f0a68725c2..bf3eb2b3bd 100644 --- a/upbc/message_layout.cc +++ b/upbc/message_layout.cc @@ -1,5 +1,6 @@ #include "upbc/message_layout.h" +#include "google/protobuf/descriptor.pb.h" namespace upbc { @@ -25,12 +26,18 @@ MessageLayout::Size MessageLayout::Place( bool MessageLayout::HasHasbit(const protobuf::FieldDescriptor* field) { return field->file()->syntax() == protobuf::FileDescriptor::SYNTAX_PROTO2 && field->label() != protobuf::FieldDescriptor::LABEL_REPEATED && - !field->containing_oneof(); + !field->containing_oneof() && + !field->containing_type()->options().map_entry(); } MessageLayout::SizeAndAlign MessageLayout::SizeOf( const protobuf::FieldDescriptor* field) { - if (field->is_repeated()) { + if (field->containing_type()->options().map_entry()) { + // Map entries aren't actually stored, they are only used during parsing. + // For parsing, it helps a lot if all map entry messages have the same + // layout. + return {{8, 16}, {4, 8}}; // upb_stringview + } else if (field->is_repeated()) { return {{4, 8}, {4, 8}}; // Pointer to array object. } else { return SizeOfUnwrapped(field);