From 3781f45f390935001a7f1be6c20171f6fba0f56c Mon Sep 17 00:00:00 2001 From: Jie Luo Date: Wed, 20 Nov 2024 10:06:12 -0800 Subject: [PATCH] Fix a Python bug where UPB and the Python C++ extension assume that MessageSet extensions are ordered first PiperOrigin-RevId: 698430014 --- .../internal/message_set_extensions.proto | 7 ++++++ .../protobuf/internal/text_format_test.py | 17 +++++++++++++ .../google/protobuf/pyext/extension_dict.cc | 25 +++++++++++++------ upb/reflection/def_pool.c | 13 +++++++--- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/python/google/protobuf/internal/message_set_extensions.proto b/python/google/protobuf/internal/message_set_extensions.proto index 17330522e9..33db0585bd 100644 --- a/python/google/protobuf/internal/message_set_extensions.proto +++ b/python/google/protobuf/internal/message_set_extensions.proto @@ -19,6 +19,9 @@ message TestMessageSet { } message TestMessageSetExtension1 { + extend TestMessageSet { + optional TestExtension first_extension = 2534113; + } extend TestMessageSet { optional TestMessageSetExtension1 message_set_extension = 98418603; } @@ -36,6 +39,10 @@ message TestMessageSetExtension3 { optional string text = 35; } +message TestExtension { + optional string str = 1; +} + extend TestMessageSet { optional TestMessageSetExtension3 message_set_extension3 = 98418655; } diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index ec7ccb8406..e44f59703e 100644 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -1411,6 +1411,23 @@ class Proto2Tests(TextFormatBase): ' text: \"bar\"\n' '}\n') + def testMessageSetExtensionNotFirst(self): + desc = message_set_extensions_pb2.TestMessageSetExtension1.DESCRIPTOR + self.assertEqual('first_extension', desc.extensions[0].name) + self.assertEqual('message_set_extension', desc.extensions[1].name) + message = 
message_set_extensions_pb2.TestMessageSet() + ext = ( + message_set_extensions_pb2.TestMessageSetExtension1.message_set_extension + ) + message.Extensions[ext].i = 123 + expected_str = ( + '[google.protobuf.internal.TestMessageSetExtension1] {\n i: 123\n}\n' + ) + self.CompareToGoldenText(text_format.MessageToString(message), expected_str) + parsed = message_set_extensions_pb2.TestMessageSet() + text_format.Parse(expected_str, parsed) + self.CompareToGoldenText(text_format.MessageToString(parsed), expected_str) + def testPrintMessageSetByFieldNumber(self): out = text_format.TextWriter(False) message = unittest_mset_pb2.TestMessageSetContainer() diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index 3f0d722370..9871b5f5e7 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -207,6 +207,21 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { return 0; } +static const FieldDescriptor* FindMessageSetExtension( + const Descriptor* message_descriptor) { + for (int i = 0; i < message_descriptor->extension_count(); i++) { + const FieldDescriptor* extension = message_descriptor->extension(i); + if (extension->is_extension() && + extension->containing_type()->options().message_set_wire_format() && + extension->type() == FieldDescriptor::TYPE_MESSAGE && + extension->label() == FieldDescriptor::LABEL_OPTIONAL && + extension->message_type() == message_descriptor) { + return extension; + } + } + return nullptr; +} + PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) { char* name; Py_ssize_t name_size; @@ -221,14 +236,8 @@ PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) { // Is is the name of a message set extension? 
const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName(absl::string_view(name, name_size)); - if (message_descriptor && message_descriptor->extension_count() > 0) { - const FieldDescriptor* extension = message_descriptor->extension(0); - if (extension->is_extension() && - extension->containing_type()->options().message_set_wire_format() && - extension->type() == FieldDescriptor::TYPE_MESSAGE && - extension->label() == FieldDescriptor::LABEL_OPTIONAL) { - message_extension = extension; - } + if (message_descriptor) { + message_extension = FindMessageSetExtension(message_descriptor); } } if (message_extension == nullptr) { diff --git a/upb/reflection/def_pool.c b/upb/reflection/def_pool.c index 0250d1b0d2..eac97d0053 100644 --- a/upb/reflection/def_pool.c +++ b/upb/reflection/def_pool.c @@ -12,6 +12,7 @@ #include "upb/hash/str_table.h" #include "upb/mem/alloc.h" #include "upb/mem/arena.h" +#include "upb/reflection/def.h" #include "upb/reflection/def_type.h" #include "upb/reflection/file_def.h" #include "upb/reflection/internal/def_builder.h" @@ -236,9 +237,15 @@ const upb_FieldDef* upb_DefPool_FindExtensionByNameWithSize( return _upb_DefType_Unpack(v, UPB_DEFTYPE_FIELD); case UPB_DEFTYPE_MSG: { const upb_MessageDef* m = _upb_DefType_Unpack(v, UPB_DEFTYPE_MSG); - return _upb_MessageDef_InMessageSet(m) - ? upb_MessageDef_NestedExtension(m, 0) - : NULL; + if (_upb_MessageDef_InMessageSet(m)) { + for (int i = 0; i < upb_MessageDef_NestedExtensionCount(m); i++) { + const upb_FieldDef* ext = upb_MessageDef_NestedExtension(m, i); + if (upb_FieldDef_MessageSubDef(ext) == m) { + return ext; + } + } + } + return NULL; } default: break;