diff --git a/python/BUILD b/python/BUILD index 5db0894d31..3bb4a1e2b8 100644 --- a/python/BUILD +++ b/python/BUILD @@ -203,6 +203,7 @@ py_extension( "-Wno-pedantic", ], deps = [ + "//:collections", "//:descriptor_upb_proto_reflection", "//:reflection", "//:table", diff --git a/python/map.c b/python/map.c index 35d3d325b2..482e62fb64 100644 --- a/python/map.c +++ b/python/map.c @@ -30,6 +30,7 @@ #include "python/convert.h" #include "python/message.h" #include "python/protobuf.h" +#include "upb/collections.h" // ----------------------------------------------------------------------------- // MapContainer @@ -145,6 +146,21 @@ upb_Map* PyUpb_MapContainer_EnsureReified(PyObject* _self) { return map; } +bool PyUpb_MapContainer_Set(PyUpb_MapContainer* self, upb_Map* map, + upb_MessageValue key, upb_MessageValue val, + upb_Arena* arena) { + switch (upb_Map_Insert(map, key, val, arena)) { + case kUpb_MapInsertStatus_Inserted: + return true; + case kUpb_MapInsertStatus_Replaced: + // We did not insert a new key, undo the previous invalidate. + self->version--; + return true; + case kUpb_MapInsertStatus_OutOfMemory: + return false; + } +} + int PyUpb_MapContainer_AssignSubscript(PyObject* _self, PyObject* key, PyObject* val) { PyUpb_MapContainer* self = (PyUpb_MapContainer*)_self; @@ -159,7 +175,7 @@ int PyUpb_MapContainer_AssignSubscript(PyObject* _self, PyObject* key, if (val) { if (!PyUpb_PyToUpb(val, val_f, &u_val, arena)) return -1; - upb_Map_Set(map, u_key, u_val, arena); + if (!PyUpb_MapContainer_Set(self, map, u_key, u_val, arena)) return -1; } else { if (!upb_Map_Delete(map, u_key)) { PyErr_Format(PyExc_KeyError, "Key not present in map"); @@ -187,7 +203,7 @@ PyObject* PyUpb_MapContainer_Subscript(PyObject* _self, PyObject* key) { } else { memset(&u_val, 0, sizeof(u_val)); } - upb_Map_Set(map, u_key, u_val, arena); + if (!PyUpb_MapContainer_Set(self, map, u_key, u_val, arena)) return false; } return PyUpb_UpbToPy(u_val, val_f, self->arena); } diff --git a/python/pb_unit_tests/message_test_wrapper.py b/python/pb_unit_tests/message_test_wrapper.py index 040ffded0b..bb1b6c7de9 100644 --- a/python/pb_unit_tests/message_test_wrapper.py +++ b/python/pb_unit_tests/message_test_wrapper.py @@ -47,7 +47,6 @@ Proto3Test.testCopyFromBadType.__unittest_expecting_failure__ = True Proto3Test.testMergeFromBadType.__unittest_expecting_failure__ = True Proto2Test.test_documentation.__unittest_expecting_failure__ = True -Proto3Test.testModifyMapEntryWhileIterating.__unittest_expecting_failure__ = True if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/upb/collections.c b/upb/collections.c index a0de0407a5..6bf7fbbbc0 100644 --- a/upb/collections.c +++ b/upb/collections.c @@ -147,9 +147,10 @@ bool upb_Map_Get(const upb_Map* map, upb_MessageValue key, void upb_Map_Clear(upb_Map* map) { _upb_Map_Clear(map); } -bool upb_Map_Set(upb_Map* map, upb_MessageValue key, upb_MessageValue val, - upb_Arena* arena) { - return _upb_Map_Set(map, &key, map->key_size, &val, map->val_size, arena); +upb_MapInsertStatus upb_Map_Insert(upb_Map* map, upb_MessageValue key, + upb_MessageValue val, upb_Arena* arena) { + return (upb_MapInsertStatus)_upb_Map_Insert(map, &key, map->key_size, &val, + map->val_size, arena); } bool upb_Map_Delete(upb_Map* map, upb_MessageValue key) { diff --git a/upb/collections.h b/upb/collections.h index fa28d83cf7..047e14ef68 100644 --- a/upb/collections.h +++ b/upb/collections.h @@ -110,10 +110,28 @@ bool upb_Map_Get(const upb_Map* map, upb_MessageValue key, /* Removes all entries in the map. */ void upb_Map_Clear(upb_Map* map); -/* Sets the given key to the given value. Returns true if this was a new key in - * the map, or false if an existing key was replaced. */ -bool upb_Map_Set(upb_Map* map, upb_MessageValue key, upb_MessageValue val, - upb_Arena* arena); +typedef enum { + // LINT.IfChange + kUpb_MapInsertStatus_Inserted = 0, + kUpb_MapInsertStatus_Replaced = 1, + kUpb_MapInsertStatus_OutOfMemory = 2, + // LINT.ThenChange(//depot/google3/third_party/upb/upb/msg_internal.h) +} upb_MapInsertStatus; + +/* Sets the given key to the given value, returning whether the key was inserted + * or replaced. If the key was inserted, then any existing iterators will be + * invalidated. */ +upb_MapInsertStatus upb_Map_Insert(upb_Map* map, upb_MessageValue key, + upb_MessageValue val, upb_Arena* arena); + +/* Sets the given key to the given value. Returns false if memory allocation + * failed. If the key is newly inserted, then any existing iterators will be + * invalidated. */ +UPB_INLINE bool upb_Map_Set(upb_Map* map, upb_MessageValue key, + upb_MessageValue val, upb_Arena* arena) { + return upb_Map_Insert(map, key, val, arena) != + kUpb_MapInsertStatus_OutOfMemory; +} /* Deletes this key from the table. Returns true if the key was present. */ bool upb_Map_Delete(upb_Map* map, upb_MessageValue key); diff --git a/upb/decode.c b/upb/decode.c index 0985abc965..e6dac273c3 100644 --- a/upb/decode.c +++ b/upb/decode.c @@ -653,7 +653,10 @@ static const char* decode_tomap(upb_Decoder* d, const char* ptr, decode_err(d, kUpb_DecodeStatus_OutOfMemory); } } else { - _upb_Map_Set(map, &ent.k, map->key_size, &ent.v, map->val_size, &d->arena); + if (_upb_Map_Insert(map, &ent.k, map->key_size, &ent.v, map->val_size, + &d->arena) == _kUpb_MapInsertStatus_OutOfMemory) { + decode_err(d, kUpb_DecodeStatus_OutOfMemory); + } } return ptr; } diff --git a/upb/msg_internal.h b/upb/msg_internal.h index b5d8b5c7dd..584ed775fb 100644 --- a/upb/msg_internal.h +++ b/upb/msg_internal.h @@ -667,15 +667,31 @@ UPB_INLINE void* _upb_map_next(const upb_Map* map, size_t* iter) { return (void*)str_tabent(&it); } -UPB_INLINE bool _upb_Map_Set(upb_Map* map, const void* key, size_t key_size, - void* val, size_t val_size, upb_Arena* a) { +typedef enum { + // LINT.IfChange + _kUpb_MapInsertStatus_Inserted = 0, + _kUpb_MapInsertStatus_Replaced = 1, + _kUpb_MapInsertStatus_OutOfMemory = 2, + // LINT.ThenChange(//depot/google3/third_party/upb/upb/collections.h) +} _upb_MapInsertStatus; + +UPB_INLINE _upb_MapInsertStatus _upb_Map_Insert(upb_Map* map, const void* key, + size_t key_size, void* val, + size_t val_size, upb_Arena* a) { upb_StringView strkey = _upb_map_tokey(key, key_size); upb_value tabval = {0}; - if (!_upb_map_tovalue(val, val_size, &tabval, a)) return false; + if (!_upb_map_tovalue(val, val_size, &tabval, a)) { + return _kUpb_MapInsertStatus_OutOfMemory; + } /* TODO(haberman): add overwrite operation to minimize number of lookups. */ - upb_strtable_remove2(&map->table, strkey.data, strkey.size, NULL); - return upb_strtable_insert(&map->table, strkey.data, strkey.size, tabval, a); + bool removed = + upb_strtable_remove2(&map->table, strkey.data, strkey.size, NULL); + if (!upb_strtable_insert(&map->table, strkey.data, strkey.size, tabval, a)) { + return _kUpb_MapInsertStatus_OutOfMemory; + } + return removed ? _kUpb_MapInsertStatus_Replaced + : _kUpb_MapInsertStatus_Inserted; } UPB_INLINE bool _upb_Map_Delete(upb_Map* map, const void* key, @@ -717,7 +733,8 @@ UPB_INLINE bool _upb_msg_map_set(upb_Message* msg, size_t ofs, const void* key, if (!*map) { *map = _upb_Map_New(arena, key_size, val_size); } - return _upb_Map_Set(*map, key, key_size, val, val_size, arena); + return _upb_Map_Insert(*map, key, key_size, val, val_size, arena) != + _kUpb_MapInsertStatus_OutOfMemory; } UPB_INLINE bool _upb_msg_map_delete(upb_Message* msg, size_t ofs,