diff --git a/python/extension_dict.c b/python/extension_dict.c index bb58603815..c2bd4d6572 100644 --- a/python/extension_dict.c +++ b/python/extension_dict.c @@ -88,24 +88,9 @@ static void PyUpb_ExtensionDict_Dealloc(PyUpb_ExtensionDict* self) { PyUpb_Dealloc(self); } -static const upb_fielddef* PyUpb_ExtensionDict_GetExtensionDef(PyObject* key) { - const upb_fielddef* f = PyUpb_FieldDescriptor_GetDef(key); - if (!f) { - PyErr_Clear(); - PyErr_Format(PyExc_KeyError, "Object %R is not a field descriptor\n", key); - return NULL; - } - if (!upb_fielddef_isextension(f)) { - PyErr_Format(PyExc_KeyError, "Field %s is not an extension\n", - upb_fielddef_fullname(f)); - return NULL; - } - return f; -} - static int PyUpb_ExtensionDict_Contains(PyObject* _self, PyObject* key) { PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self; - const upb_fielddef* f = PyUpb_ExtensionDict_GetExtensionDef(key); + const upb_fielddef* f = PyUpb_CMessage_GetExtensionDef(self->msg, key); if (!f) return -1; upb_msg* msg = PyUpb_CMessage_GetIfReified(self->msg); if (!msg) return 0; @@ -125,7 +110,7 @@ static Py_ssize_t PyUpb_ExtensionDict_Length(PyObject* _self) { static PyObject* PyUpb_ExtensionDict_Subscript(PyObject* _self, PyObject* key) { PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self; - const upb_fielddef* f = PyUpb_ExtensionDict_GetExtensionDef(key); + const upb_fielddef* f = PyUpb_CMessage_GetExtensionDef(self->msg, key); if (!f) return NULL; return PyUpb_CMessage_GetFieldValue(self->msg, f); } @@ -133,7 +118,7 @@ static PyObject* PyUpb_ExtensionDict_Subscript(PyObject* _self, PyObject* key) { static int PyUpb_ExtensionDict_AssignSubscript(PyObject* _self, PyObject* key, PyObject* val) { PyUpb_ExtensionDict* self = (PyUpb_ExtensionDict*)_self; - const upb_fielddef* f = PyUpb_ExtensionDict_GetExtensionDef(key); + const upb_fielddef* f = PyUpb_CMessage_GetExtensionDef(self->msg, key); if (!f) return -1; if (val) { return PyUpb_CMessage_SetFieldValue(self->msg, f, val); diff --git a/python/message.c b/python/message.c index 7a3c8f6254..736358c29e 100644 --- a/python/message.c +++ b/python/message.c @@ -1068,13 +1068,8 @@ void PyUpb_CMessage_DoClearField(PyObject* _self, const upb_fielddef* f) { static PyObject* PyUpb_CMessage_ClearExtension(PyObject* _self, PyObject* arg) { PyUpb_CMessage* self = (void*)_self; PyUpb_CMessage_AssureReified(self); - const upb_msgdef* msgdef = _PyUpb_CMessage_GetMsgdef(self); - const upb_fielddef* f = PyUpb_FieldDescriptor_GetDef(arg); + const upb_fielddef* f = PyUpb_CMessage_GetExtensionDef(_self, arg); if (!f) return NULL; - if (upb_fielddef_containingtype(f) != msgdef) { - PyErr_Format(PyExc_ValueError, "Extension doesn't match (%s vs %s)", - upb_msgdef_fullname(msgdef), upb_fielddef_fullname(f)); - } PyUpb_CMessage_DoClearField(_self, f); Py_RETURN_NONE; } @@ -1159,11 +1154,38 @@ err: goto done; } +const upb_fielddef* PyUpb_CMessage_GetExtensionDef(PyObject* _self, PyObject* key) { + const upb_fielddef* f = PyUpb_FieldDescriptor_GetDef(key); + if (!f) { + PyErr_Clear(); + PyErr_Format(PyExc_KeyError, "Object %R is not a field descriptor\n", key); + return NULL; + } + if (!upb_fielddef_isextension(f)) { + PyErr_Format(PyExc_KeyError, "Field %s is not an extension\n", + upb_fielddef_fullname(f)); + return NULL; + } + const upb_msgdef* msgdef = PyUpb_CMessage_GetMsgdef(_self); + if (upb_fielddef_containingtype(f) != msgdef) { + PyErr_Format(PyExc_KeyError, "Extension doesn't match (%s vs %s)", + upb_msgdef_fullname(msgdef), upb_fielddef_fullname(f)); + return NULL; + } + return f; +} + + static PyObject* PyUpb_CMessage_HasExtension(PyObject* _self, PyObject* ext_desc) { upb_msg* msg = PyUpb_CMessage_GetIfReified(_self); - const upb_fielddef* f = PyUpb_FieldDescriptor_GetDef(ext_desc); + const upb_fielddef* f = PyUpb_CMessage_GetExtensionDef(_self, ext_desc); if (!f) return NULL; + if (upb_fielddef_isseq(f)) { + PyErr_SetString(PyExc_KeyError, + "Field is repeated. A singular method is required."); + return NULL; + } if (!msg) Py_RETURN_FALSE; return PyBool_FromLong(upb_msg_has(msg, f)); } diff --git a/python/message.h b/python/message.h index 3c4474743d..f67bf63ffc 100644 --- a/python/message.h +++ b/python/message.h @@ -69,6 +69,11 @@ PyObject* PyUpb_CMessage_SerializeToString(PyObject* self, PyObject* args, int PyUpb_CMessage_InitAttributes(PyObject* _self, PyObject* args, PyObject* kwargs); +// Checks that `key` is a field descriptor for an extension type, and that the +// extendee is this message. Otherwise returns NULL and sets a KeyError. +const upb_fielddef* PyUpb_CMessage_GetExtensionDef(PyObject* _self, + PyObject* key); + // Clears the given field in this message. void PyUpb_CMessage_DoClearField(PyObject* _self, const upb_fielddef* f); diff --git a/python/pb_unit_tests/reflection_test_wrapper.py b/python/pb_unit_tests/reflection_test_wrapper.py index ae9470d90b..e8a5d802e1 100644 --- a/python/pb_unit_tests/reflection_test_wrapper.py +++ b/python/pb_unit_tests/reflection_test_wrapper.py @@ -26,8 +26,6 @@ from google.protobuf.internal import reflection_test import unittest -#reflection_test.Proto2ReflectionTest.testExtensionDelete.__unittest_expecting_failure__ = True -reflection_test.Proto2ReflectionTest.testExtensionFailureModes.__unittest_expecting_failure__ = True reflection_test.Proto2ReflectionTest.testExtensionIter.__unittest_expecting_failure__ = True reflection_test.Proto2ReflectionTest.testIsInitialized.__unittest_expecting_failure__ = True reflection_test.Proto2ReflectionTest.testListFieldsAndExtensions.__unittest_expecting_failure__ = True