Fix a python bug that UPB and Python C++ extension assume MessageSet extensions are ordered first

PiperOrigin-RevId: 698430014
pull/19323/head
Jie Luo 4 months ago committed by Copybara-Service
parent f5a293768f
commit 3781f45f39
  1. 7
      python/google/protobuf/internal/message_set_extensions.proto
  2. 17
      python/google/protobuf/internal/text_format_test.py
  3. 25
      python/google/protobuf/pyext/extension_dict.cc
  4. 13
      upb/reflection/def_pool.c

@ -19,6 +19,9 @@ message TestMessageSet {
}
message TestMessageSetExtension1 {
extend TestMessageSet {
optional TestExtension first_extension = 2534113;
}
extend TestMessageSet {
optional TestMessageSetExtension1 message_set_extension = 98418603;
}
@ -36,6 +39,10 @@ message TestMessageSetExtension3 {
optional string text = 35;
}
message TestExtension {
optional string str = 1;
}
extend TestMessageSet {
optional TestMessageSetExtension3 message_set_extension3 = 98418655;
}

@ -1411,6 +1411,23 @@ class Proto2Tests(TextFormatBase):
' text: \"bar\"\n'
'}\n')
def testMessageSetExtensionNotFirst(self):
desc = message_set_extensions_pb2.TestMessageSetExtension1.DESCRIPTOR
self.assertEqual('first_extension', desc.extensions[0].name)
self.assertEqual('message_set_extension', desc.extensions[1].name)
message = message_set_extensions_pb2.TestMessageSet()
ext = (
message_set_extensions_pb2.TestMessageSetExtension1.message_set_extension
)
message.Extensions[ext].i = 123
expected_str = (
'[google.protobuf.internal.TestMessageSetExtension1] {\n i: 123\n}\n'
)
self.CompareToGoldenText(text_format.MessageToString(message), expected_str)
parsed = message_set_extensions_pb2.TestMessageSet()
text_format.Parse(expected_str, parsed)
self.CompareToGoldenText(text_format.MessageToString(parsed), expected_str)
def testPrintMessageSetByFieldNumber(self):
out = text_format.TextWriter(False)
message = unittest_mset_pb2.TestMessageSetContainer()

@ -207,6 +207,21 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
return 0;
}
static const FieldDescriptor* FindMessageSetExtension(
const Descriptor* message_descriptor) {
for (int i = 0; i < message_descriptor->extension_count(); i++) {
const FieldDescriptor* extension = message_descriptor->extension(i);
if (extension->is_extension() &&
extension->containing_type()->options().message_set_wire_format() &&
extension->type() == FieldDescriptor::TYPE_MESSAGE &&
extension->label() == FieldDescriptor::LABEL_OPTIONAL &&
extension->message_type() == message_descriptor) {
return extension;
}
}
return nullptr;
}
PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
char* name;
Py_ssize_t name_size;
@ -221,14 +236,8 @@ PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
// Is is the name of a message set extension?
const Descriptor* message_descriptor =
pool->pool->FindMessageTypeByName(absl::string_view(name, name_size));
if (message_descriptor && message_descriptor->extension_count() > 0) {
const FieldDescriptor* extension = message_descriptor->extension(0);
if (extension->is_extension() &&
extension->containing_type()->options().message_set_wire_format() &&
extension->type() == FieldDescriptor::TYPE_MESSAGE &&
extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
message_extension = extension;
}
if (message_descriptor) {
message_extension = FindMessageSetExtension(message_descriptor);
}
}
if (message_extension == nullptr) {

@ -12,6 +12,7 @@
#include "upb/hash/str_table.h"
#include "upb/mem/alloc.h"
#include "upb/mem/arena.h"
#include "upb/reflection/def.h"
#include "upb/reflection/def_type.h"
#include "upb/reflection/file_def.h"
#include "upb/reflection/internal/def_builder.h"
@ -236,9 +237,15 @@ const upb_FieldDef* upb_DefPool_FindExtensionByNameWithSize(
return _upb_DefType_Unpack(v, UPB_DEFTYPE_FIELD);
case UPB_DEFTYPE_MSG: {
const upb_MessageDef* m = _upb_DefType_Unpack(v, UPB_DEFTYPE_MSG);
return _upb_MessageDef_InMessageSet(m)
? upb_MessageDef_NestedExtension(m, 0)
: NULL;
if (_upb_MessageDef_InMessageSet(m)) {
for (int i = 0; i < upb_MessageDef_NestedExtensionCount(m); i++) {
const upb_FieldDef* ext = upb_MessageDef_NestedExtension(m, i);
if (upb_FieldDef_MessageSubDef(ext) == m) {
return ext;
}
}
}
return NULL;
}
default:
break;

Loading…
Cancel
Save