From 81da6b999a8229942436f6c203a20633c65ebd26 Mon Sep 17 00:00:00 2001 From: Jie Luo Date: Tue, 12 Nov 2024 09:41:02 -0800 Subject: [PATCH] Breaking Change: Python setdefault behavior change for map field. -setdefault will be similar with dict for ScalarMap. But both key and value must be set. -setdefault will be rejected for MessageMap. PiperOrigin-RevId: 695768629 --- python/google/protobuf/internal/containers.py | 13 +++++ .../google/protobuf/internal/message_test.py | 28 ++++++++++ python/google/protobuf/pyext/map_container.cc | 39 ++++++++++++++ python/map.c | 51 ++++++++++++++++++- 4 files changed, 129 insertions(+), 2 deletions(-) diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 23357816f6..7298bc5c7a 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -412,6 +412,13 @@ class ScalarMap(MutableMapping[_K, _V]): def __repr__(self) -> str: return repr(self._values) + def setdefault(self, key: _K, value: Optional[_V] = None) -> _V: + if value == None: + raise ValueError('The value for scalar map setdefault must be set.') + if key not in self._values: + self.__setitem__(key, value) + return self[key] + def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None: self._values.update(other._values) self._message_listener.Modified() @@ -526,6 +533,12 @@ class MessageMap(MutableMapping[_K, _V]): def __repr__(self) -> str: return repr(self._values) + def setdefault(self, key: _K, value: Optional[_V] = None) -> _V: + raise NotImplementedError( + 'Set message map value directly is not supported, call' + ' my_map[key].foo = 5' + ) + def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None: # pylint: disable=protected-access for key in other._values: diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 3a9852b570..43f0a10dc7 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -1900,6 +1900,26 @@ class Proto3Test(unittest.TestCase): self.assertEqual(msg1.map_int32_int32, msg2.map_int32_int32) + def testScalarMapSetdefault(self): + msg = map_unittest_pb2.TestMap() + value = msg.map_int32_int32.setdefault(123, 888) + self.assertEqual(value, 888) + self.assertEqual(msg.map_int32_int32[123], 888) + value = msg.map_int32_int32.setdefault(123, 777) + self.assertEqual(value, 888) + + with self.assertRaises(ValueError): + value = msg.map_int32_int32.setdefault(1001) + self.assertNotIn(1001, msg.map_int32_int32) + with self.assertRaises(TypeError): + value = msg.map_int32_int32.setdefault() + with self.assertRaises(TypeError): + value = msg.map_int32_int32.setdefault(1, 2, 3) + with self.assertRaises(TypeError): + value = msg.map_int32_int32.setdefault("1", 2) + with self.assertRaises(TypeError): + value = msg.map_int32_int32.setdefault(1, "2") + def testMessageMapComparison(self): msg1 = map_unittest_pb2.TestMap() msg2 = map_unittest_pb2.TestMap() @@ -1907,6 +1927,14 @@ class Proto3Test(unittest.TestCase): self.assertEqual(msg1.map_int32_foreign_message, msg2.map_int32_foreign_message) + def testMessageMapSetdefault(self): + msg = map_unittest_pb2.TestMap() + msg.map_int32_foreign_message[123].c = 888 + with self.assertRaises(NotImplementedError): + msg.map_int32_foreign_message.setdefault( + 1, msg.map_int32_foreign_message[123] + ) + def testMapGet(self): # Need to test that get() properly returns the default, even though the dict # has defaultdict-like semantics. diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index 90c01228cc..6322982a39 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -438,6 +438,35 @@ int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key, } } +static PyObject* ScalarMapSetdefault(PyObject* self, PyObject* args) { + PyObject* key = nullptr; + PyObject* default_value = Py_None; + + if (!PyArg_UnpackTuple(args, "setdefault", 1, 2, &key, &default_value)) { + return nullptr; + } + + if (default_value == Py_None) { + PyErr_Format(PyExc_ValueError, + "The value for scalar map setdefault must be set."); + return nullptr; + } + + ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key)); + if (is_present == nullptr) { + return nullptr; + } + if (PyObject_IsTrue(is_present.get())) { + return MapReflectionFriend::ScalarMapGetItem(self, key); + } + + if (MapReflectionFriend::ScalarMapSetItem(self, key, default_value) < 0) { + return nullptr; + } + Py_INCREF(default_value); + return default_value; +} + static PyObject* ScalarMapGet(PyObject* self, PyObject* args, PyObject* kwargs) { static const char* kwlist[] = {"key", "default", nullptr}; @@ -512,6 +541,8 @@ static PyMethodDef ScalarMapMethods[] = { "Tests whether a key is a member of the map."}, {"clear", (PyCFunction)Clear, METH_NOARGS, "Removes all elements from the map."}, + {"setdefault", (PyCFunction)ScalarMapSetdefault, METH_VARARGS, + "If the key does not exist, insert the key, with the specified value"}, {"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS, "Gets the value for the given key if present, or otherwise a default"}, {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS, @@ -685,6 +716,12 @@ PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) { return PyObject_Repr(dict.get()); } +static PyObject* MessageMapSetdefault(PyObject* self, PyObject* args) { + PyErr_Format(PyExc_NotImplementedError, + "Set message map value directly is not supported."); + return nullptr; +} + PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) { static const char* kwlist[] = {"key", "default", nullptr}; PyObject* key; @@ -729,6 +766,8 @@ static PyMethodDef MessageMapMethods[] = { "Tests whether the map contains this element."}, {"clear", (PyCFunction)Clear, METH_NOARGS, "Removes all elements from the map."}, + {"setdefault", (PyCFunction)MessageMapSetdefault, METH_VARARGS, + "setdefault is disallowed in MessageMap."}, {"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS, "Gets the value for the given key if present, or otherwise a default"}, {"get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O, diff --git a/python/map.c b/python/map.c index 1e2b3806b6..4fe7c19114 100644 --- a/python/map.c +++ b/python/map.c @@ -217,6 +217,48 @@ static PyObject* PyUpb_MapContainer_Clear(PyObject* _self, PyObject* key) { Py_RETURN_NONE; } +static PyObject* PyUpb_ScalarMapContainer_Setdefault(PyObject* _self, + PyObject* args) { + PyObject* key; + PyObject* default_value = Py_None; + + if (!PyArg_UnpackTuple(args, "setdefault", 1, 2, &key, &default_value)) { + return NULL; + } + + if (default_value == Py_None) { + PyErr_Format(PyExc_ValueError, + "The value for scalar map setdefault must be set."); + return NULL; + } + + PyUpb_MapContainer* self = (PyUpb_MapContainer*)_self; + upb_Map* map = PyUpb_MapContainer_EnsureReified(_self); + const upb_FieldDef* f = PyUpb_MapContainer_GetField(self); + const upb_MessageDef* entry_m = upb_FieldDef_MessageSubDef(f); + const upb_FieldDef* key_f = upb_MessageDef_Field(entry_m, 0); + const upb_FieldDef* val_f = upb_MessageDef_Field(entry_m, 1); + upb_MessageValue u_key, u_val; + if (!PyUpb_PyToUpb(key, key_f, &u_key, NULL)) return NULL; + if (upb_Map_Get(map, u_key, &u_val)) { + return PyUpb_UpbToPy(u_val, val_f, self->arena); + } + + upb_Arena* arena = PyUpb_Arena_Get(self->arena); + if (!PyUpb_PyToUpb(default_value, val_f, &u_val, arena)) return NULL; + if (!PyUpb_MapContainer_Set(self, map, u_key, u_val, arena)) return NULL; + + Py_INCREF(default_value); + return default_value; +} + +static PyObject* PyUpb_MessageMapContainer_Setdefault(PyObject* self, + PyObject* args) { + PyErr_Format(PyExc_NotImplementedError, + "Set message map value directly is not supported."); + return NULL; +} + static PyObject* PyUpb_MapContainer_Get(PyObject* _self, PyObject* args, PyObject* kwargs) { PyUpb_MapContainer* self = (PyUpb_MapContainer*)_self; @@ -331,6 +373,9 @@ PyObject* PyUpb_MapContainer_GetOrCreateWrapper(upb_Map* map, static PyMethodDef PyUpb_ScalarMapContainer_Methods[] = { {"clear", PyUpb_MapContainer_Clear, METH_NOARGS, "Removes all elements from the map."}, + {"setdefault", (PyCFunction)PyUpb_ScalarMapContainer_Setdefault, + METH_VARARGS, + "If the key does not exist, insert the key, with the specified value"}, {"get", (PyCFunction)PyUpb_MapContainer_Get, METH_VARARGS | METH_KEYWORDS, "Gets the value for the given key if present, or otherwise a default"}, {"GetEntryClass", PyUpb_MapContainer_GetEntryClass, METH_NOARGS, @@ -373,6 +418,8 @@ static PyType_Spec PyUpb_ScalarMapContainer_Spec = { static PyMethodDef PyUpb_MessageMapContainer_Methods[] = { {"clear", PyUpb_MapContainer_Clear, METH_NOARGS, "Removes all elements from the map."}, + {"setdefault", (PyCFunction)PyUpb_MessageMapContainer_Setdefault, + METH_VARARGS, "setdefault is disallowed in MessageMap."}, {"get", (PyCFunction)PyUpb_MapContainer_Get, METH_VARARGS | METH_KEYWORDS, "Gets the value for the given key if present, or otherwise a default"}, {"get_or_create", PyUpb_MapContainer_Subscript, METH_O, @@ -480,8 +527,8 @@ bool PyUpb_Map_Init(PyObject* m) { PyObject* base = GetMutableMappingBase(); if (!base) return false; - const char* methods[] = {"keys", "items", "values", "__eq__", "__ne__", - "pop", "popitem", "update", "setdefault", NULL}; + const char* methods[] = {"keys", "items", "values", "__eq__", "__ne__", + "pop", "popitem", "update", NULL}; state->message_map_container_type = PyUpb_AddClassWithRegister( m, &PyUpb_MessageMapContainer_Spec, base, methods);