diff --git a/BUILD b/BUILD index 7c1707d1ab..ceeae4fab9 100644 --- a/BUILD +++ b/BUILD @@ -60,7 +60,6 @@ cc_library( srcs = [ "upb/decode.c", "upb/encode.c", - "upb/generated_util.h", "upb/msg.c", "upb/msg.h", "upb/port.c", @@ -91,7 +90,6 @@ cc_library( cc_library( name = "generated_code_support__only_for_generated_code_do_not_use__i_give_permission_to_break_me", hdrs = [ - "upb/generated_util.h", "upb/msg.h", ], copts = select({ diff --git a/CMakeLists.txt b/CMakeLists.txt index 1235bdb098..a88176389f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,7 +63,6 @@ enable_testing() add_library(upb upb/decode.c upb/encode.c - upb/generated_util.h upb/msg.c upb/msg.h upb/port.c diff --git a/generated_for_cmake/google/protobuf/descriptor.upb.h b/generated_for_cmake/google/protobuf/descriptor.upb.h index 33e020011d..5baad08810 100644 --- a/generated_for_cmake/google/protobuf/descriptor.upb.h +++ b/generated_for_cmake/google/protobuf/descriptor.upb.h @@ -9,7 +9,6 @@ #ifndef GOOGLE_PROTOBUF_DESCRIPTOR_PROTO_UPB_H_ #define GOOGLE_PROTOBUF_DESCRIPTOR_PROTO_UPB_H_ -#include "upb/generated_util.h" #include "upb/msg.h" #include "upb/decode.h" #include "upb/encode.h" diff --git a/tests/bindings/lua/test_upb.lua b/tests/bindings/lua/test_upb.lua index 07f7ed0f15..25ec0c1721 100644 --- a/tests/bindings/lua/test_upb.lua +++ b/tests/bindings/lua/test_upb.lua @@ -49,6 +49,18 @@ function test_msg_map() assert_equal(10, msg.map_int32_int32[5]) assert_equal(12, msg.map_int32_int32[6]) + -- Test overwrite. + msg.map_int32_int32[5] = 20 + assert_equal(20, msg.map_int32_int32[5]) + assert_equal(12, msg.map_int32_int32[6]) + msg.map_int32_int32[5] = 10 + + -- Test delete. + msg.map_int32_int32[5] = nil + assert_nil(msg.map_int32_int32[5]) + assert_equal(12, msg.map_int32_int32[6]) + msg.map_int32_int32[5] = 10 + local serialized = upb.encode(msg) assert_true(#serialized > 0) local msg2 = upb.decode(test_messages_proto3.TestAllTypesProto3, serialized) @@ -64,6 +76,18 @@ function test_msg_string_map() assert_equal("bar", msg.map_string_string["foo"]) assert_equal("quux", msg.map_string_string["baz"]) + -- Test overwrite. + msg.map_string_string["foo"] = "123" + assert_equal("123", msg.map_string_string["foo"]) + assert_equal("quux", msg.map_string_string["baz"]) + msg.map_string_string["foo"] = "bar" + + -- Test delete + msg.map_string_string["foo"] = nil + assert_nil(msg.map_string_string["foo"]) + assert_equal("quux", msg.map_string_string["baz"]) + msg.map_string_string["foo"] = "bar" + local serialized = upb.encode(msg) assert_true(#serialized > 0) local msg2 = upb.decode(test_messages_proto3.TestAllTypesProto3, serialized) diff --git a/tests/bindings/lua/test_upb.pb.lua b/tests/bindings/lua/test_upb.pb.lua deleted file mode 100644 index ea6de09989..0000000000 --- a/tests/bindings/lua/test_upb.pb.lua +++ /dev/null @@ -1,80 +0,0 @@ - --- Require "pb" first to ensure that the transitive require of "upb" is --- handled properly by the "pb" module. -local pb = require "upb.pb" -local upb = require "upb" -local lunit = require "lunit" - -if _VERSION >= 'Lua 5.2' then - _ENV = lunit.module("testupb_pb", "seeall") -else - module("testupb_pb", lunit.testcase, package.seeall) -end - -local symtab = upb.SymbolTable{ - upb.MessageDef{full_name = "TestMessage", fields = { - upb.FieldDef{name = "i32", number = 1, type = upb.TYPE_INT32}, - upb.FieldDef{name = "u32", number = 2, type = upb.TYPE_UINT32}, - upb.FieldDef{name = "i64", number = 3, type = upb.TYPE_INT64}, - upb.FieldDef{name = "u64", number = 4, type = upb.TYPE_UINT64}, - upb.FieldDef{name = "dbl", number = 5, type = upb.TYPE_DOUBLE}, - upb.FieldDef{name = "flt", number = 6, type = upb.TYPE_FLOAT}, - upb.FieldDef{name = "bool", number = 7, type = upb.TYPE_BOOL}, - } - } -} - -local factory = upb.MessageFactory(symtab); -local TestMessage = factory:get_message_class("TestMessage") - -function test_parse_primitive() - local binary_pb = - "\008\128\128\128\128\002\016\128\128\128\128\004\024\128\128" - .. "\128\128\128\128\128\002\032\128\128\128\128\128\128\128\001\041\000" - .. "\000\000\000\000\000\248\063\053\000\000\096\064\056\001" - local msg = TestMessage() - pb.decode(msg, binary_pb) - assert_equal(536870912, msg.i32) - assert_equal(1073741824, msg.u32) - assert_equal(1125899906842624, msg.i64) - assert_equal(562949953421312, msg.u64) - assert_equal(1.5, msg.dbl) - assert_equal(3.5, msg.flt) - assert_equal(true, msg.bool) - - local encoded = pb.encode(msg) - local msg2 = TestMessage() - pb.decode(msg2, encoded) - assert_equal(536870912, msg.i32) - assert_equal(1073741824, msg.u32) - assert_equal(1125899906842624, msg.i64) - assert_equal(562949953421312, msg.u64) - assert_equal(1.5, msg.dbl) - assert_equal(3.5, msg.flt) - assert_equal(true, msg.bool) -end - -function test_parse_string() - local symtab = upb.SymbolTable{ - upb.MessageDef{full_name = "TestMessage", fields = { - upb.FieldDef{name = "str", number = 1, type = upb.TYPE_STRING}, - } - } - } - - local factory = upb.MessageFactory(symtab); - local TestMessage = factory:get_message_class("TestMessage") - - local binary_pb = "\010\005Hello" - msg = TestMessage() - pb.decode(msg, binary_pb) - -- TODO(haberman): re-enable when this stuff works better. - -- assert_equal("Hello", msg.str) -end - - -local stats = lunit.main() - -if stats.failed > 0 or stats.errors > 0 then - error("One or more errors in test suite") -end diff --git a/upb/bindings/lua/msg.c b/upb/bindings/lua/msg.c index b611cee8e6..aee9238e68 100644 --- a/upb/bindings/lua/msg.c +++ b/upb/bindings/lua/msg.c @@ -592,7 +592,7 @@ static int lupb_map_newindex(lua_State *L) { upb_msgval key = lupb_tomsgval(L, lmap->key_type, 2, 1, LUPB_REF); if (lua_isnil(L, 3)) { - upb_map_delete(map, key, lupb_arenaget(L, 1)); + upb_map_delete(map, key); } else { upb_msgval val = lupb_tomsgval(L, lmap->value_type, 3, 1, LUPB_COPY); upb_map_set(map, key, val, lupb_arenaget(L, 1)); @@ -603,18 +603,19 @@ static int lupb_map_newindex(lua_State *L) { static int lupb_mapiter_next(lua_State *L) { int map = lua_upvalueindex(2); - upb_mapiter *i = lua_touserdata(L, lua_upvalueindex(1)); + size_t *iter = lua_touserdata(L, lua_upvalueindex(1)); lupb_map *lmap = lupb_map_check(L, map); - if (upb_mapiter_done(i)) { + if (upb_mapiter_next(lmap->map, iter)) { + upb_msgval key = upb_mapiter_key(lmap->map, *iter); + upb_msgval val = upb_mapiter_value(lmap->map, *iter); + lupb_pushmsgval(L, map, lmap->key_type, key); + lupb_pushmsgval(L, map, lmap->value_type, val); + return 2; + } else { return 0; } - lupb_pushmsgval(L, map, lmap->key_type, upb_mapiter_key(i)); - lupb_pushmsgval(L, map, lmap->value_type, upb_mapiter_value(i)); - upb_mapiter_next(i); - - return 2; } /** @@ -624,13 +625,13 @@ static int lupb_mapiter_next(lua_State *L) { * pairs(map) */ static int lupb_map_pairs(lua_State *L) { - lupb_map *lmap = lupb_map_check(L, 1); - upb_mapiter *i = lua_newuserdata(L, upb_mapiter_sizeof()); + lupb_map_check(L, 1); + size_t *iter = lua_newuserdata(L, sizeof(*iter)); - upb_mapiter_begin(i, lmap->map); + *iter = UPB_MAP_BEGIN; lua_pushvalue(L, 1); - /* Upvalues are [upb_mapiter, lupb_map]. */ + /* Upvalues are [iter, lupb_map]. */ lua_pushcclosure(L, &lupb_mapiter_next, 2); return 1; diff --git a/upb/decode.c b/upb/decode.c index dedf60c3bc..042045da2d 100644 --- a/upb/decode.c +++ b/upb/decode.c @@ -28,6 +28,29 @@ const uint8_t desctype_to_fieldtype[] = { UPB_TYPE_INT64, /* SINT64 */ }; +/* Maps descriptor type -> upb map size. */ +const uint8_t desctype_to_mapsize[] = { + UPB_WIRE_TYPE_END_GROUP, /* ENDGROUP */ + 8, /* DOUBLE */ + 4, /* FLOAT */ + 8, /* INT64 */ + 8, /* UINT64 */ + 4, /* INT32 */ + 8, /* FIXED64 */ + 4, /* FIXED32 */ + 1, /* BOOL */ + UPB_MAPTYPE_STRING, /* STRING */ + sizeof(void*), /* GROUP */ + sizeof(void*), /* MESSAGE */ + UPB_MAPTYPE_STRING, /* BYTES */ + 4, /* UINT32 */ + 4, /* ENUM */ + 4, /* SFIXED32 */ + 8, /* SFIXED64 */ + 4, /* SINT32 */ + 8, /* SINT64 */ +}; + /* Data pertaining to the parse. */ typedef struct { const char *ptr; /* Current parsing position. */ @@ -464,10 +487,7 @@ static bool upb_decode_toarray(upb_decstate *d, upb_decframe *frame, static bool upb_decode_mapfield(upb_decstate *d, upb_decframe *frame, const upb_msglayout_field *field, int len) { upb_map *map = *(upb_map**)&frame->msg[field->offset]; - upb_alloc *alloc = upb_arena_alloc(d->arena); const upb_msglayout *entry = frame->layout->submsgs[field->submsg_index]; - upb_strview key; - upb_strtable *t; /* The compiler ensures that all map entry messages have this layout. */ struct map_entry { @@ -486,13 +506,13 @@ static bool upb_decode_mapfield(upb_decstate *d, upb_decframe *frame, /* Lazily create map. */ const upb_msglayout_field *key_field = &entry->fields[0]; const upb_msglayout_field *val_field = &entry->fields[1]; - upb_fieldtype_t key_type = desctype_to_fieldtype[key_field->descriptortype]; - upb_fieldtype_t val_type = desctype_to_fieldtype[val_field->descriptortype]; + char key_size = desctype_to_mapsize[key_field->descriptortype]; + char val_size = desctype_to_mapsize[val_field->descriptortype]; UPB_ASSERT(key_field->number == 1); UPB_ASSERT(val_field->number == 2); UPB_ASSERT(key_field->offset == 0); UPB_ASSERT(val_field->offset == sizeof(upb_strview)); - map = upb_map_new(frame->state->arena, key_type, val_type); + map = _upb_map_new(frame->state->arena, key_size, val_size); *(upb_map**)&frame->msg[field->offset] = map; } @@ -501,25 +521,7 @@ static bool upb_decode_mapfield(upb_decstate *d, upb_decframe *frame, CHK(upb_decode_msgfield(d, &ent.k, entry, len)); /* Insert into map. */ - t = &map->table; - - if (map->key_size_lg2 == UPB_MAPTYPE_STRING) { - key = ent.k.str; - } else { - key.data = (const char*)&ent.k; - key.size = 1 << map->key_size_lg2; - } - - if (map->val_size_lg2 == UPB_MAPTYPE_STRING) { - upb_strview* val_view = upb_arena_malloc(d->arena, sizeof(*val_view)); - CHK(val_view); - *val_view = ent.v.str; - ent.v.val = upb_value_ptr(val_view); - } - - /* Have to remove first, since upb's table won't overwrite. */ - upb_strtable_remove3(t, key.data, key.size, NULL, alloc); - upb_strtable_insert3(t, key.data, key.size, ent.v.val, alloc); + _upb_map_set(map, &ent.k, map->key_size, &ent.v, map->val_size, d->arena); return true; } diff --git a/upb/encode.c b/upb/encode.c index 9314c1d500..d9adbff596 100644 --- a/upb/encode.c +++ b/upb/encode.c @@ -330,10 +330,10 @@ static bool upb_encode_map(upb_encstate *e, const char *field_mem, size_t size; upb_strview key = upb_strtable_iter_key(&i); const upb_value val = upb_strtable_iter_value(&i); - const void* keyp = - map->key_size_lg2 == UPB_MAPTYPE_STRING ? (void*)&key : key.data; - const void* valp = - map->val_size_lg2 == UPB_MAPTYPE_STRING ? upb_value_getptr(val) : &val; + const void *keyp = + map->key_size == UPB_MAPTYPE_STRING ? (void *)&key : key.data; + const void *valp = + map->val_size == UPB_MAPTYPE_STRING ? upb_value_getptr(val) : &val; CHK(upb_encode_scalarfield(e, valp, entry, val_field, false)); CHK(upb_encode_scalarfield(e, keyp, entry, key_field, false)); diff --git a/upb/generated_util.h b/upb/generated_util.h deleted file mode 100644 index ac01a27b5c..0000000000 --- a/upb/generated_util.h +++ /dev/null @@ -1,89 +0,0 @@ -/* -** Functions for use by generated code. These are not public and users must -** not call them directly. -*/ - -#ifndef UPB_GENERATED_UTIL_H_ -#define UPB_GENERATED_UTIL_H_ - -#include -#include "upb/msg.h" - -#include "upb/port_def.inc" - -#define PTR_AT(msg, ofs, type) (type*)((const char*)msg + ofs) - -UPB_INLINE const void *_upb_array_accessor(const void *msg, size_t ofs, - size_t *size) { - const upb_array *arr = *PTR_AT(msg, ofs, const upb_array*); - if (arr) { - if (size) *size = arr->len; - return _upb_array_constptr(arr); - } else { - if (size) *size = 0; - return NULL; - } -} - -UPB_INLINE void *_upb_array_mutable_accessor(void *msg, size_t ofs, - size_t *size) { - upb_array *arr = *PTR_AT(msg, ofs, upb_array*); - if (arr) { - if (size) *size = arr->len; - return _upb_array_ptr(arr); - } else { - if (size) *size = 0; - return NULL; - } -} - -UPB_INLINE void *_upb_array_resize_accessor(void *msg, size_t ofs, size_t size, - upb_fieldtype_t type, - upb_arena *arena) { - upb_array **arr_ptr = PTR_AT(msg, ofs, upb_array*); - upb_array *arr = *arr_ptr; - if (!arr || arr->size < size) { - return _upb_array_resize_fallback(arr_ptr, size, type, arena); - } - arr->len = size; - return _upb_array_ptr(arr); -} - -UPB_INLINE bool _upb_array_append_accessor(void *msg, size_t ofs, - size_t elem_size, - upb_fieldtype_t type, - const void *value, - upb_arena *arena) { - upb_array **arr_ptr = PTR_AT(msg, ofs, upb_array*); - upb_array *arr = *arr_ptr; - void* ptr; - if (!arr || arr->len == arr->size) { - return _upb_array_append_fallback(arr_ptr, value, type, arena); - } - ptr = _upb_array_ptr(arr); - memcpy(PTR_AT(ptr, arr->len * elem_size, char), value, elem_size); - arr->len++; - return true; -} - -UPB_INLINE bool _upb_has_field(const void *msg, size_t idx) { - return (*PTR_AT(msg, idx / 8, const char) & (1 << (idx % 8))) != 0; -} - -UPB_INLINE bool _upb_sethas(const void *msg, size_t idx) { - return (*PTR_AT(msg, idx / 8, char)) |= (char)(1 << (idx % 8)); -} - -UPB_INLINE bool _upb_clearhas(const void *msg, size_t idx) { - return (*PTR_AT(msg, idx / 8, char)) &= (char)(~(1 << (idx % 8))); -} - -UPB_INLINE bool _upb_has_oneof_field(const void *msg, size_t case_ofs, int32_t num) { - return *PTR_AT(msg, case_ofs, int32_t) == num; -} - -#undef PTR_AT - -#include "upb/port_undef.inc" - -#endif /* UPB_GENERATED_UTIL_H_ */ diff --git a/upb/msg.c b/upb/msg.c index b27fa41a92..bd3d347d58 100644 --- a/upb/msg.c +++ b/upb/msg.c @@ -24,22 +24,6 @@ static char _upb_fieldtype_to_sizelg2[12] = { UPB_SIZE(3, 4), /* UPB_TYPE_BYTES */ }; -/* Strings/bytes are special-cased in maps. */ -static char _upb_fieldtype_to_mapsizelg2[12] = { - 0, - 0, /* UPB_TYPE_BOOL */ - 2, /* UPB_TYPE_FLOAT */ - 2, /* UPB_TYPE_INT32 */ - 2, /* UPB_TYPE_UINT32 */ - 2, /* UPB_TYPE_ENUM */ - UPB_SIZE(2, 3), /* UPB_TYPE_MESSAGE */ - 3, /* UPB_TYPE_DOUBLE */ - 3, /* UPB_TYPE_INT64 */ - 3, /* UPB_TYPE_UINT64 */ - UPB_MAPTYPE_STRING, /* UPB_TYPE_STRING */ - UPB_MAPTYPE_STRING, /* UPB_TYPE_BYTES */ -}; - static uintptr_t tag_arrptr(void* ptr, int elem_size_lg2) { UPB_ASSERT(elem_size_lg2 <= 4); return (uintptr_t)ptr | elem_size_lg2; @@ -172,8 +156,7 @@ void *_upb_array_resize_fallback(upb_array **arr_ptr, size_t size, /** upb_map *******************************************************************/ -upb_map *upb_map_new(upb_arena *a, upb_fieldtype_t key_type, - upb_fieldtype_t value_type) { +upb_map *_upb_map_new(upb_arena *a, size_t key_size, size_t value_size) { upb_map *map = upb_arena_malloc(a, sizeof(upb_map)); if (!map) { @@ -181,11 +164,10 @@ upb_map *upb_map_new(upb_arena *a, upb_fieldtype_t key_type, } upb_strtable_init2(&map->table, UPB_CTYPE_INT32, upb_arena_alloc(a)); - map->key_size_lg2 = _upb_fieldtype_to_mapsizelg2[key_type]; - map->val_size_lg2 = _upb_fieldtype_to_mapsizelg2[value_type]; + map->key_size = key_size; + map->val_size = value_size; return map; } - #undef VOIDPTR_AT diff --git a/upb/msg.h b/upb/msg.h index 6b6dafa19e..da226f917a 100644 --- a/upb/msg.h +++ b/upb/msg.h @@ -1,6 +1,6 @@ /* ** Our memory representation for parsing tables and messages themselves. -** Functions in this file are used by generated code and possible reflection. +** Functions in this file are used by generated code and possibly reflection. ** ** The definitions in this file are internal to upb. **/ @@ -20,6 +20,8 @@ extern "C" { #endif +#define PTR_AT(msg, ofs, type) (type*)((const char*)msg + ofs) + typedef void upb_msg; /** upb_msglayout *************************************************************/ @@ -86,6 +88,22 @@ void upb_msg_addunknown(upb_msg *msg, const char *data, size_t len, /* Returns a reference to the message's unknown data. */ const char *upb_msg_getunknown(const upb_msg *msg, size_t *len); +UPB_INLINE bool _upb_has_field(const void *msg, size_t idx) { + return (*PTR_AT(msg, idx / 8, const char) & (1 << (idx % 8))) != 0; +} + +UPB_INLINE bool _upb_sethas(const void *msg, size_t idx) { + return (*PTR_AT(msg, idx / 8, char)) |= (char)(1 << (idx % 8)); +} + +UPB_INLINE bool _upb_clearhas(const void *msg, size_t idx) { + return (*PTR_AT(msg, idx / 8, char)) &= (char)(~(1 << (idx % 8))); +} + +UPB_INLINE bool _upb_has_oneof_field(const void *msg, size_t case_ofs, int32_t num) { + return *PTR_AT(msg, case_ofs, int32_t) == num; +} + /** upb_array *****************************************************************/ /* Our internal representation for repeated fields. */ @@ -115,21 +133,243 @@ void *_upb_array_resize_fallback(upb_array **arr_ptr, size_t size, bool _upb_array_append_fallback(upb_array **arr_ptr, const void *value, upb_fieldtype_t type, upb_arena *arena); +UPB_INLINE const void *_upb_array_accessor(const void *msg, size_t ofs, + size_t *size) { + const upb_array *arr = *PTR_AT(msg, ofs, const upb_array*); + if (arr) { + if (size) *size = arr->len; + return _upb_array_constptr(arr); + } else { + if (size) *size = 0; + return NULL; + } +} + +UPB_INLINE void *_upb_array_mutable_accessor(void *msg, size_t ofs, + size_t *size) { + upb_array *arr = *PTR_AT(msg, ofs, upb_array*); + if (arr) { + if (size) *size = arr->len; + return _upb_array_ptr(arr); + } else { + if (size) *size = 0; + return NULL; + } +} + +UPB_INLINE void *_upb_array_resize_accessor(void *msg, size_t ofs, size_t size, + upb_fieldtype_t type, + upb_arena *arena) { + upb_array **arr_ptr = PTR_AT(msg, ofs, upb_array*); + upb_array *arr = *arr_ptr; + if (!arr || arr->size < size) { + return _upb_array_resize_fallback(arr_ptr, size, type, arena); + } + arr->len = size; + return _upb_array_ptr(arr); +} + + +UPB_INLINE bool _upb_array_append_accessor(void *msg, size_t ofs, + size_t elem_size, + upb_fieldtype_t type, + const void *value, + upb_arena *arena) { + upb_array **arr_ptr = PTR_AT(msg, ofs, upb_array*); + upb_array *arr = *arr_ptr; + void* ptr; + if (!arr || arr->len == arr->size) { + return _upb_array_append_fallback(arr_ptr, value, type, arena); + } + ptr = _upb_array_ptr(arr); + memcpy(PTR_AT(ptr, arr->len * elem_size, char), value, elem_size); + arr->len++; + return true; +} + /** upb_map *******************************************************************/ /* Right now we use strmaps for everything. We'll likely want to use * integer-specific maps for integer-keyed maps.*/ typedef struct { - /* We should pack these better and move them into table to avoid padding. */ - char key_size_lg2; - char val_size_lg2; + /* Size of key and val, based on the map type. Strings are represented as '0' + * because they must be handled specially. */ + char key_size; + char val_size; upb_strtable table; } upb_map; /* Creates a new map on the given arena with this key/value type. */ -upb_map *upb_map_new(upb_arena *a, upb_fieldtype_t key_type, - upb_fieldtype_t value_type); +upb_map *_upb_map_new(upb_arena *a, size_t key_size, size_t value_size); + +/* Converting between internal table representation and user values. + * + * _upb_map_tokey() and _upb_map_fromkey() are inverses. + * _upb_map_tovalue() and _upb_map_fromvalue() are inverses. + * + * These functions account for the fact that strings are treated differently + * from other types when stored in a map. + */ + +UPB_INLINE upb_strview _upb_map_tokey(const void *key, size_t size) { + if (size == UPB_MAPTYPE_STRING) { + return *(upb_strview*)key; + } else { + return upb_strview_make((const char*)key, size); + } +} + +UPB_INLINE void _upb_map_fromkey(upb_strview key, void* out, size_t size) { + if (size == UPB_MAPTYPE_STRING) { + memcpy(out, &key, sizeof(key)); + } else { + memcpy(out, key.data, size); + } +} + +UPB_INLINE upb_value _upb_map_tovalue(const void *val, size_t size, + upb_arena *a) { + upb_value ret = {0}; + if (size == UPB_MAPTYPE_STRING) { + upb_strview *strp = (upb_strview*)upb_arena_malloc(a, sizeof(*strp)); + *strp = *(upb_strview*)val; + memcpy(&ret, &strp, sizeof(strp)); + } else { + memcpy(&ret, val, size); + } + return ret; +} + +UPB_INLINE void _upb_map_fromvalue(upb_value val, void* out, size_t size) { + if (size == UPB_MAPTYPE_STRING) { + const upb_strview *strp = (const upb_strview*)upb_value_getptr(val); + memcpy(out, strp, sizeof(upb_strview)); + } else { + memcpy(out, &val, size); + } +} + +/* Map operations, shared by reflection and generated code. */ + +UPB_INLINE size_t _upb_map_size(const upb_map *map) { + return map->table.t.count; +} + +UPB_INLINE bool _upb_map_get(const upb_map *map, const void *key, + size_t key_size, void *val, size_t val_size) { + upb_value tabval; + upb_strview k = _upb_map_tokey(key, key_size); + bool ret = upb_strtable_lookup2(&map->table, k.data, k.size, &tabval); + if (ret) { + _upb_map_fromvalue(tabval, val, val_size); + } + return ret; +} + +UPB_INLINE void* _upb_map_next(const upb_map *map, size_t *iter) { + upb_strtable_iter it = {&map->table, *iter}; + upb_strtable_next(&it); + if (upb_strtable_done(&it)) return NULL; + *iter = it.index; + return (void*)str_tabent(&it); +} + +UPB_INLINE bool _upb_map_set(upb_map *map, const void *key, size_t key_size, + void *val, size_t val_size, upb_arena *arena) { + upb_strview strkey = _upb_map_tokey(key, key_size); + upb_value tabval = _upb_map_tovalue(val, val_size, arena); + upb_alloc *a = upb_arena_alloc(arena); + + /* TODO(haberman): add overwrite operation to minimize number of lookups. */ + upb_strtable_remove3(&map->table, strkey.data, strkey.size, NULL, a); + return upb_strtable_insert3(&map->table, strkey.data, strkey.size, tabval, a); +} + +UPB_INLINE bool _upb_map_delete(upb_map *map, const void *key, size_t key_size) { + upb_strview k = _upb_map_tokey(key, key_size); + return upb_strtable_remove3(&map->table, k.data, k.size, NULL, NULL); +} + +UPB_INLINE void _upb_map_clear(upb_map *map) { + upb_strtable_clear(&map->table); +} + +/* Message map operations, these get the map from the message first. */ + +UPB_INLINE size_t _upb_msg_map_size(const upb_msg *msg, size_t ofs) { + upb_map *map = UPB_FIELD_AT(msg, upb_map *, ofs); + return map ? _upb_map_size(map) : 0; +} + +UPB_INLINE bool _upb_msg_map_get(const upb_msg *msg, size_t ofs, + const void *key, size_t key_size, void *val, + size_t val_size) { + upb_map *map = UPB_FIELD_AT(msg, upb_map *, ofs); + if (!map) return false; + return _upb_map_get(map, key, key_size, val, val_size); +} + +UPB_INLINE void *_upb_msg_map_next(const upb_msg *msg, size_t ofs, + size_t *iter) { + upb_map *map = UPB_FIELD_AT(msg, upb_map *, ofs); + if (!map) return NULL; + return _upb_map_next(map, iter); +} + +UPB_INLINE bool _upb_msg_map_set(upb_msg *msg, size_t ofs, const void *key, + size_t key_size, void *val, size_t val_size, + upb_arena *arena) { + upb_map **map = PTR_AT(msg, ofs, upb_map *); + if (!*map) { + *map = _upb_map_new(arena, key_size, val_size); + } + return _upb_map_set(*map, key, key_size, val, val_size, arena); +} + +UPB_INLINE bool _upb_msg_map_delete(upb_msg *msg, size_t ofs, const void *key, + size_t key_size) { + upb_map *map = UPB_FIELD_AT(msg, upb_map *, ofs); + if (!map) return false; + return _upb_map_delete(map, key, key_size); +} + +UPB_INLINE void _upb_msg_map_clear(upb_msg *msg, size_t ofs) { + upb_map *map = UPB_FIELD_AT(msg, upb_map *, ofs); + if (!map) return; + _upb_map_clear(map); +} + +/* Accessing map key/value from a pointer, used by generated code only. */ + +UPB_INLINE void _upb_msg_map_key(const void* msg, void* key, size_t size) { + const upb_tabent *ent = (const upb_tabent*)msg; + uint32_t u32len; + upb_strview k = {upb_tabstr(ent->key, &u32len)}; + k.size = u32len; + _upb_map_fromkey(k, key, size); +} + +UPB_INLINE void _upb_msg_map_value(const void* msg, void* val, size_t size) { + const upb_tabent *ent = (const upb_tabent*)msg; + upb_value v; + _upb_value_setval(&v, ent->val.val); + _upb_map_fromvalue(v, val, size); +} + +UPB_INLINE void _upb_msg_map_set_value(void* msg, const void* val, size_t size) { + upb_tabent *ent = (upb_tabent*)msg; + /* This is like _upb_map_tovalue() except the entry already exists so we can + * reuse the allocated upb_strview for string fields. */ + if (size == UPB_MAPTYPE_STRING) { + upb_strview *strp = (upb_strview*)ent->val.val; + memcpy(strp, val, sizeof(*strp)); + } else { + memcpy(&ent->val.val, val, size); + } +} + +#undef PTR_AT #ifdef __cplusplus } /* extern "C" */ diff --git a/upb/port_def.inc b/upb/port_def.inc index 138c7a0c4f..a8e5070695 100644 --- a/upb/port_def.inc +++ b/upb/port_def.inc @@ -45,7 +45,7 @@ UPB_FIELD_AT(msg, int, case_offset) = case_val; \ UPB_FIELD_AT(msg, fieldtype, offset) = value; -#define UPB_MAPTYPE_STRING 4 +#define UPB_MAPTYPE_STRING 0 /* UPB_INLINE: inline if possible, emit standalone code if required. */ #ifdef __cplusplus diff --git a/upb/port_undef.inc b/upb/port_undef.inc index 103180b7fb..6a4daa5076 100644 --- a/upb/port_undef.inc +++ b/upb/port_undef.inc @@ -1,5 +1,6 @@ /* See port_def.inc. This should #undef all macros #defined there. */ +#undef UPB_MAPTYPE_STRING #undef UPB_SIZE #undef UPB_FIELD_AT #undef UPB_READ_ONEOF diff --git a/upb/reflection.c b/upb/reflection.c index e0a6299bb2..11db23b999 100644 --- a/upb/reflection.c +++ b/upb/reflection.c @@ -7,7 +7,7 @@ #include "upb/port_def.inc" -char field_size[] = { +static char field_size[] = { 0,/* 0 */ 8, /* UPB_DESCRIPTOR_TYPE_DOUBLE */ 4, /* UPB_DESCRIPTOR_TYPE_FLOAT */ @@ -29,6 +29,22 @@ char field_size[] = { 8, /* UPB_DESCRIPTOR_TYPE_SINT64 */ }; +/* Strings/bytes are special-cased in maps. */ +static char _upb_fieldtype_to_mapsize[12] = { + 0, + 1, /* UPB_TYPE_BOOL */ + 4, /* UPB_TYPE_FLOAT */ + 4, /* UPB_TYPE_INT32 */ + 4, /* UPB_TYPE_UINT32 */ + 4, /* UPB_TYPE_ENUM */ + sizeof(void*), /* UPB_TYPE_MESSAGE */ + 8, /* UPB_TYPE_DOUBLE */ + 8, /* UPB_TYPE_INT64 */ + 8, /* UPB_TYPE_UINT64 */ + 0, /* UPB_TYPE_STRING */ + 0, /* UPB_TYPE_BYTES */ +}; + /** upb_msg *******************************************************************/ /* If we always read/write as a consistent type to each address, this shouldn't @@ -66,7 +82,8 @@ upb_msgval upb_msg_get(const upb_msg *msg, const upb_fielddef *f) { const char *mem = PTR_AT(msg, field->offset, char); upb_msgval val; if (field->presence == 0 || upb_msg_has(msg, f)) { - int size = upb_fielddef_isseq(f) ? sizeof(void*) : field_size[field->descriptortype]; + int size = upb_fielddef_isseq(f) ? sizeof(void *) + : field_size[field->descriptortype]; memcpy(&val, mem, size); } else { /* TODO(haberman): change upb_fielddef to not require this switch(). */ @@ -132,7 +149,8 @@ void upb_msg_set(upb_msg *msg, const upb_fielddef *f, upb_msgval val, upb_arena *a) { const upb_msglayout_field *field = upb_fielddef_layout(f); char *mem = PTR_AT(msg, field->offset, char); - int size = upb_fielddef_isseq(f) ? sizeof(void*) : field_size[field->descriptortype]; + int size = upb_fielddef_isseq(f) ? sizeof(void *) + : field_size[field->descriptortype]; memcpy(mem, &val, size); if (in_oneof(field)) { *oneofcase(msg, field) = field->number; @@ -180,123 +198,48 @@ bool upb_array_resize(upb_array *arr, size_t size, upb_arena *arena) { /** upb_map *******************************************************************/ -size_t upb_map_size(const upb_map *map) { - return upb_strtable_count(&map->table); +upb_map *upb_map_new(upb_arena *a, upb_fieldtype_t key_type, + upb_fieldtype_t value_type) { + return _upb_map_new(a, _upb_fieldtype_to_mapsize[key_type], + _upb_fieldtype_to_mapsize[value_type]); } -static upb_strview upb_map_tokey(int size_lg2, upb_msgval *key) { - if (size_lg2 == UPB_MAPTYPE_STRING) { - return key->str_val; - } else { - return upb_strview_make((const char*)key, 1 << size_lg2); - } -} - -static upb_msgval upb_map_fromvalue(int size_lg2, upb_value val) { - upb_msgval ret; - if (size_lg2 == UPB_MAPTYPE_STRING) { - upb_strview *strp = upb_value_getptr(val); - ret.str_val = *strp; - } else { - memcpy(&ret, &val, 8); - } - return ret; -} - -static upb_value upb_map_tovalue(int size_lg2, upb_msgval val, upb_arena *a) { - upb_value ret; - if (size_lg2 == UPB_MAPTYPE_STRING) { - upb_strview *strp = upb_arena_malloc(a, sizeof(*strp)); - *strp = val.str_val; - ret = upb_value_ptr(strp); - } else { - memcpy(&ret, &val, 8); - } - return ret; +size_t upb_map_size(const upb_map *map) { + return _upb_map_size(map); } bool upb_map_get(const upb_map *map, upb_msgval key, upb_msgval *val) { - upb_strview strkey = upb_map_tokey(map->key_size_lg2, &key); - upb_value tabval; - bool ret; - - ret = upb_strtable_lookup2(&map->table, strkey.data, strkey.size, &tabval); - if (ret) { - *val = upb_map_fromvalue(map->val_size_lg2, tabval); - } - - return ret; + return _upb_map_get(map, &key, map->key_size, val, map->val_size); } bool upb_map_set(upb_map *map, upb_msgval key, upb_msgval val, upb_arena *arena) { - upb_strview strkey = upb_map_tokey(map->key_size_lg2, &key); - upb_value tabval = upb_map_tovalue(map->val_size_lg2, val, arena); - upb_alloc *a = upb_arena_alloc(arena); - - /* TODO(haberman): add overwrite operation to minimize number of lookups. */ - if (upb_strtable_lookup2(&map->table, strkey.data, strkey.size, NULL)) { - upb_strtable_remove3(&map->table, strkey.data, strkey.size, NULL, a); - } - - return upb_strtable_insert3(&map->table, strkey.data, strkey.size, tabval, a); + return _upb_map_set(map, &key, map->key_size, &val, map->val_size, arena); } -bool upb_map_delete(upb_map *map, upb_msgval key, upb_arena *arena) { - upb_strview strkey = upb_map_tokey(map->key_size_lg2, &key); - upb_alloc *a = upb_arena_alloc(arena); - return upb_strtable_remove3(&map->table, strkey.data, strkey.size, NULL, a); +bool upb_map_delete(upb_map *map, upb_msgval key) { + return _upb_map_delete(map, &key, map->key_size); } -/** upb_mapiter ***************************************************************/ - -struct upb_mapiter { - upb_strtable_iter iter; - char key_size_lg2; - char val_size_lg2; -}; - -size_t upb_mapiter_sizeof(void) { - return sizeof(upb_mapiter); +bool upb_mapiter_next(const upb_map *map, size_t *iter) { + return _upb_map_next(map, iter); } -void upb_mapiter_begin(upb_mapiter *i, upb_map *map) { - upb_strtable_begin(&i->iter, &map->table); - i->key_size_lg2 = map->key_size_lg2; - i->val_size_lg2 = map->val_size_lg2; -} - -void upb_mapiter_free(upb_mapiter *i, upb_alloc *a) { - upb_free(a, i); -} - -void upb_mapiter_next(upb_mapiter *i) { - upb_strtable_next(&i->iter); -} - -bool upb_mapiter_done(const upb_mapiter *i) { - return upb_strtable_done(&i->iter); -} - -upb_msgval upb_mapiter_key(const upb_mapiter *i) { - upb_strview key = upb_strtable_iter_key(&i->iter); +/* Returns the key and value for this entry of the map. */ +upb_msgval upb_mapiter_key(const upb_map *map, size_t iter) { + upb_strtable_iter i = {&map->table, iter}; + upb_strview key = upb_strtable_iter_key(&i); upb_msgval ret; - if (i->key_size_lg2 == UPB_MAPTYPE_STRING) { - ret.str_val = key; - } else { - memcpy(&ret, key.data, 1 << i->key_size_lg2); - } + _upb_map_fromkey(key, &ret, map->key_size); return ret; } -upb_msgval upb_mapiter_value(const upb_mapiter *i) { - return upb_map_fromvalue(i->val_size_lg2, upb_strtable_iter_value(&i->iter)); -} - -void upb_mapiter_setdone(upb_mapiter *i) { - upb_strtable_iter_setdone(&i->iter); +upb_msgval upb_mapiter_value(const upb_map *map, size_t iter) { + upb_strtable_iter i = {&map->table, iter}; + upb_value val = upb_strtable_iter_value(&i); + upb_msgval ret; + _upb_map_fromvalue(val, &ret, map->val_size); + return ret; } -bool upb_mapiter_isequal(const upb_mapiter *i1, const upb_mapiter *i2) { - return upb_strtable_iter_isequal(&i1->iter, &i2->iter); -} +/* void upb_mapiter_setvalue(upb_map *map, size_t iter, upb_msgval value); */ diff --git a/upb/reflection.h b/upb/reflection.h index 58fa7ee7a2..f284925904 100644 --- a/upb/reflection.h +++ b/upb/reflection.h @@ -72,6 +72,10 @@ bool upb_array_resize(upb_array *array, size_t size, upb_arena *arena); /** upb_map *******************************************************************/ +/* Creates a new map on the given arena with the given key/value size. */ +upb_map *upb_map_new(upb_arena *a, upb_fieldtype_t key_type, + upb_fieldtype_t value_type); + /* Returns the number of entries in the map. */ size_t upb_map_size(const upb_map *map); @@ -89,48 +93,30 @@ bool upb_map_set(upb_map *map, upb_msgval key, upb_msgval val, upb_arena *arena); /* Deletes this key from the table. Returns true if the key was present. */ -/* TODO(haberman): can |arena| be removed once upb_table is arena-only and no - * longer tries to free keys? */ -bool upb_map_delete(upb_map *map, upb_msgval key, upb_arena *arena); - -/** upb_mapiter ***************************************************************/ - -/* For iterating over a map. Map iterators are invalidated by mutations to the - * map, but an invalidated iterator will never return junk or crash the process - * (this is an important property when exposing iterators to interpreted - * languages like Ruby, PHP, etc). An invalidated iterator may return entries - * that were already returned though, and if you keep invalidating the iterator - * during iteration, the program may enter an infinite loop. */ -struct upb_mapiter; -typedef struct upb_mapiter upb_mapiter; - -size_t upb_mapiter_sizeof(void); - -/* Starts iteration. If the map is mutable then we can modify entries while - * iterating. */ -void upb_mapiter_constbegin(upb_mapiter *i, const upb_map *map); -void upb_mapiter_begin(upb_mapiter *i, upb_map *map); - -/* Sets the iterator to "done" state. This will return "true" from - * upb_mapiter_done() and will compare equal to other "done" iterators. */ -void upb_mapiter_setdone(upb_mapiter *i); - -/* Advances to the next entry. The iterator must not be done. */ -void upb_mapiter_next(upb_mapiter *i); +bool upb_map_delete(upb_map *map, upb_msgval key); + +/* Map iteration: + * + * size_t iter = UPB_MAP_BEGIN; + * while (upb_mapiter_next(map, &iter)) { + * upb_msgval key = upb_mapiter_key(map, iter); + * upb_msgval val = upb_mapiter_value(map, iter); + * + * // If mutating is desired. + * upb_mapiter_setvalue(map, iter, value2); + * } + */ + +/* Advances to the next entry. Returns false if no more entries are present. */ +bool upb_mapiter_next(const upb_map *map, size_t *iter); /* Returns the key and value for this entry of the map. */ -upb_msgval upb_mapiter_key(const upb_mapiter *i); -upb_msgval upb_mapiter_value(const upb_mapiter *i); +upb_msgval upb_mapiter_key(const upb_map *map, size_t iter); +upb_msgval upb_mapiter_value(const upb_map *map, size_t iter); /* Sets the value for this entry. The iterator must not be done, and the * iterator must not have been initialized const. */ -void upb_mapiter_setvalue(const upb_mapiter *i, upb_msgval value); - -/* Returns true if the iterator is done. */ -bool upb_mapiter_done(const upb_mapiter *i); - -/* Compares two iterators for equality. */ -bool upb_mapiter_isequal(const upb_mapiter *i1, const upb_mapiter *i2); +void upb_mapiter_setvalue(upb_map *map, size_t iter, upb_msgval value); #include "upb/port_undef.inc" diff --git a/upb/table.c b/upb/table.c index 5b48bb7e50..1d01a223d0 100644 --- a/upb/table.c +++ b/upb/table.c @@ -284,6 +284,12 @@ bool upb_strtable_init2(upb_strtable *t, upb_ctype_t ctype, upb_alloc *a) { return init(&t->t, 2, a); } +void upb_strtable_clear(upb_strtable *t) { + size_t bytes = upb_table_size(&t->t) * sizeof(upb_tabent); + t->t.count = 0; + memset((char*)t->t.entries, 0, bytes); +} + void upb_strtable_uninit2(upb_strtable *t, upb_alloc *a) { size_t i; for (i = 0; i < upb_table_size(&t->t); i++) @@ -342,7 +348,10 @@ bool upb_strtable_remove3(upb_strtable *t, const char *key, size_t len, uint32_t hash = upb_murmur_hash2(key, len, 0); upb_tabkey tabkey; if (rm(&t->t, strkey2(key, len), val, &tabkey, hash, &streql)) { - upb_free(alloc, (void*)tabkey); + if (alloc) { + /* Arena-based allocs don't need to free and won't pass this. */ + upb_free(alloc, (void*)tabkey); + } return true; } else { return false; @@ -351,10 +360,6 @@ bool upb_strtable_remove3(upb_strtable *t, const char *key, size_t len, /* Iteration */ -static const upb_tabent *str_tabent(const upb_strtable_iter *i) { - return &i->t->t.entries[i->index]; -} - void upb_strtable_begin(upb_strtable_iter *i, const upb_strtable *t) { i->t = t; i->index = begin(&t->t); diff --git a/upb/table.int.h b/upb/table.int.h index d571a3498a..75575eb7d5 100644 --- a/upb/table.int.h +++ b/upb/table.int.h @@ -262,6 +262,7 @@ upb_inttable *upb_inttable_pack(const upb_inttable *t, void *p, size_t *ofs, size_t size); upb_strtable *upb_strtable_pack(const upb_strtable *t, void *p, size_t *ofs, size_t size); +void upb_strtable_clear(upb_strtable *t); /* Inserts the given key into the hashtable with the given value. The key must * not already exist in the hash table. For string tables, the key must be @@ -451,6 +452,10 @@ typedef struct { bool array_part; } upb_inttable_iter; +UPB_INLINE const upb_tabent *str_tabent(const upb_strtable_iter *i) { + return &i->t->t.entries[i->index]; +} + void upb_inttable_begin(upb_inttable_iter *i, const upb_inttable *t); void upb_inttable_next(upb_inttable_iter *i); bool upb_inttable_done(const upb_inttable_iter *i); diff --git a/upb/upb.h b/upb/upb.h index b64e4b1697..011103d949 100644 --- a/upb/upb.h +++ b/upb/upb.h @@ -358,6 +358,8 @@ typedef enum { UPB_DESCRIPTOR_TYPE_SINT64 = 18 } upb_descriptortype_t; +#define UPB_MAP_BEGIN -1 + #include "upb/port_undef.inc" #endif /* UPB_H_ */ diff --git a/upbc/generator.cc b/upbc/generator.cc index 09ed1ee7f4..50fb42d564 100644 --- a/upbc/generator.cc +++ b/upbc/generator.cc @@ -331,21 +331,24 @@ void GenerateMessageInHeader(const protobuf::Descriptor* message, Output& output output("/* $0 */\n\n", message->full_name()); std::string msgname = ToCIdent(message->full_name()); - output( - "UPB_INLINE $0 *$0_new(upb_arena *arena) {\n" - " return ($0 *)_upb_msg_new(&$1, arena);\n" - "}\n" - "UPB_INLINE $0 *$0_parse(const char *buf, size_t size,\n" - " upb_arena *arena) {\n" - " $0 *ret = $0_new(arena);\n" - " return (ret && upb_decode(buf, size, ret, &$1, arena)) ? ret : NULL;\n" - "}\n" - "UPB_INLINE char *$0_serialize(const $0 *msg, upb_arena *arena, size_t " - "*len) {\n" - " return upb_encode(msg, &$1, arena, len);\n" - "}\n" - "\n", - MessageName(message), MessageInit(message)); + + if (!message->options().map_entry()) { + output( + "UPB_INLINE $0 *$0_new(upb_arena *arena) {\n" + " return ($0 *)_upb_msg_new(&$1, arena);\n" + "}\n" + "UPB_INLINE $0 *$0_parse(const char *buf, size_t size,\n" + " upb_arena *arena) {\n" + " $0 *ret = $0_new(arena);\n" + " return (ret && upb_decode(buf, size, ret, &$1, arena)) ? ret : NULL;\n" + "}\n" + "UPB_INLINE char *$0_serialize(const $0 *msg, upb_arena *arena, size_t " + "*len) {\n" + " return upb_encode(msg, &$1, arena, len);\n" + "}\n" + "\n", + MessageName(message), MessageInit(message)); + } for (int i = 0; i < message->oneof_decl_count(); i++) { const protobuf::OneofDescriptor* oneof = message->oneof_decl(i); @@ -386,7 +389,42 @@ void GenerateMessageInHeader(const protobuf::Descriptor* message, Output& output } // Generate getter. - if (field->is_repeated()) { + if (field->is_map()) { + const protobuf::Descriptor* entry = field->message_type(); + const protobuf::FieldDescriptor* key = entry->FindFieldByNumber(1); + const protobuf::FieldDescriptor* val = entry->FindFieldByNumber(2); + output( + "UPB_INLINE size_t $0_$1_size(const $0 *msg) {" + "return _upb_msg_map_size(msg, $2); }\n", + msgname, field->name(), GetSizeInit(layout.GetFieldOffset(field))); + output( + "UPB_INLINE bool $0_$1_get(const $0 *msg, $2 key, $3 *val) { " + "return _upb_msg_map_get(msg, $4, &key, $5, val, $6); }\n", + msgname, field->name(), CType(key), CType(val), + GetSizeInit(layout.GetFieldOffset(field)), + key->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(key)", + val->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(*val)"); + output( + "UPB_INLINE $0 $1_$2_next(const $1 *msg, size_t* iter) { " + "return ($0)_upb_msg_map_next(msg, $3, iter); }\n", + CTypeConst(field), msgname, field->name(), + GetSizeInit(layout.GetFieldOffset(field))); + } else if (message->options().map_entry()) { + output( + "UPB_INLINE $0 $1_$2(const $1 *msg) {\n" + " $3 ret;\n" + " _upb_msg_map_$2(msg, &ret, $4);\n" + " return ret;\n" + "}\n", + CTypeConst(field), msgname, field->name(), CType(field), + field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(ret)"); + } else if (field->is_repeated()) { output( "UPB_INLINE $0 const* $1_$2(const $1 *msg, size_t *len) { " "return ($0 const*)_upb_array_accessor(msg, $3, len); }\n", @@ -414,11 +452,39 @@ void GenerateMessageInHeader(const protobuf::Descriptor* message, Output& output // Generate mutable methods. for (auto field : FieldNumberOrder(message)) { - if (message->options().map_entry() && field->name() == "key") { - // Emit nothing, map keys cannot be changed directly. Users must use - // the mutators of the map itself. - } else if (field->is_map()) { + if (field->is_map()) { // TODO(haberman): add map-based mutators. + const protobuf::Descriptor* entry = field->message_type(); + const protobuf::FieldDescriptor* key = entry->FindFieldByNumber(1); + const protobuf::FieldDescriptor* val = entry->FindFieldByNumber(2); + output( + "UPB_INLINE void $0_$1_clear($0 *msg) { _upb_msg_map_clear(msg, $2); }\n", + msgname, field->name(), + GetSizeInit(layout.GetFieldOffset(field))); + output( + "UPB_INLINE bool $0_$1_set($0 *msg, $2 key, $3 val, upb_arena *a) { " + "return _upb_msg_map_set(msg, $4, &key, $5, &val, $6, a); }\n", + msgname, field->name(), CType(key), CType(val), + GetSizeInit(layout.GetFieldOffset(field)), + key->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(key)", + val->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(val)"); + output( + "UPB_INLINE bool $0_$1_delete($0 *msg, $2 key) { " + "return _upb_msg_map_delete(msg, $3, &key, $4); }\n", + msgname, field->name(), CType(key), + GetSizeInit(layout.GetFieldOffset(field)), + key->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(key)"); + output( + "UPB_INLINE $0 $1_$2_nextmutable($1 *msg, size_t* iter) { " + "return ($0)_upb_msg_map_next(msg, $3, iter); }\n", + CType(field), msgname, field->name(), + GetSizeInit(layout.GetFieldOffset(field))); } else if (field->is_repeated()) { output( "UPB_INLINE $0* $1_mutable_$2($1 *msg, size_t *len) {\n" @@ -461,9 +527,24 @@ void GenerateMessageInHeader(const protobuf::Descriptor* message, Output& output } } else { // Non-repeated field. + if (message->options().map_entry() && field->name() == "key") { + // Key cannot be mutated. + continue; + } + + // The common function signature for all setters. Varying implementations + // follow. output("UPB_INLINE void $0_set_$1($0 *msg, $2 value) {\n", msgname, field->name(), CType(field)); - if (field->containing_oneof()) { + + if (message->options().map_entry()) { + output( + " _upb_msg_map_set_value(msg, &value, $0);\n" + "}\n", + field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(" + CType(field) + ")"); + } else if (field->containing_oneof()) { output( " UPB_WRITE_ONEOF(msg, $0, $1, value, $2, $3);\n" "}\n", @@ -479,7 +560,9 @@ void GenerateMessageInHeader(const protobuf::Descriptor* message, Output& output "}\n", CType(field), GetSizeInit(layout.GetFieldOffset(field))); } - if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + + if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE && + !message->options().map_entry()) { output( "UPB_INLINE struct $0* $1_mutable_$2($1 *msg, upb_arena *arena) {\n" " struct $0* sub = (struct $0*)$1_$2(msg);\n" @@ -504,7 +587,6 @@ void WriteHeader(const protobuf::FileDescriptor* file, Output& output) { output( "#ifndef $0_UPB_H_\n" "#define $0_UPB_H_\n\n" - "#include \"upb/generated_util.h\"\n" "#include \"upb/msg.h\"\n" "#include \"upb/decode.h\"\n" "#include \"upb/encode.h\"\n\n",