From abf5dfbfbc7a1a8131b5a317acc25b26d4956d1a Mon Sep 17 00:00:00 2001 From: Jie Luo Date: Mon, 28 Aug 2023 09:58:39 -0700 Subject: [PATCH] Fix a bug that strips options from descriptor.proto in Pure Python. GetOptions on fields (which parses the _serialized_options) is now called the first time a message is parsed or serialized, instead of at Build time. Note: GetOptions on messages is still called at Build time because of message_set_wire_format. If message options are needed in descriptor.proto, a parse error will be raised in GetOptions(). We can check the file to avoid invoking GetOptions() for descriptor.proto, as long as message_set_wire_format is not needed in descriptor.proto. Options other than message options do not invoke GetOptions() at Build time. PiperOrigin-RevId: 560741182 --- python/google/protobuf/descriptor.py | 32 +++--- python/google/protobuf/descriptor_pool.py | 1 + .../protobuf/internal/python_message.py | 102 +++++++++++++----- .../protobuf/internal/reflection_test.py | 15 --- .../protobuf/internal/unknown_fields_test.py | 33 ------ 5 files changed, 86 insertions(+), 97 deletions(-) diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index fb9632b539..ea8bc22ca2 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -115,16 +115,6 @@ _Deprecated.count = 100 _internal_create_key = object() -def _IsDescriptorBootstrapProto(file): - """Checks if the file descriptor corresponds to our bootstrapped descriptor.proto""" - if file is None: - return False - return ( - file.name == 'net/proto2/proto/descriptor.proto' - or file.name == 'google/protobuf/descriptor.proto' - ) - - class DescriptorBase(metaclass=DescriptorMetaclass): """Descriptors base class. 
@@ -154,9 +144,7 @@ class DescriptorBase(metaclass=DescriptorMetaclass): self.file = file self._options = options self._options_class_name = options_class_name - self._serialized_options = ( - serialized_options if not _IsDescriptorBootstrapProto(file) else None - ) + self._serialized_options = serialized_options # Does this descriptor have non-default options? self.has_options = (self._options is not None) or ( @@ -192,14 +180,15 @@ class DescriptorBase(metaclass=DescriptorMetaclass): raise RuntimeError('Unknown options class name %s!' % (self._options_class_name)) - with _lock: - if self._serialized_options is None: + if self._serialized_options is None: + with _lock: self._options = options_class() - else: - self._options = _ParseOptions(options_class(), - self._serialized_options) + else: + options = _ParseOptions(options_class(), self._serialized_options) + with _lock: + self._options = options - return self._options + return self._options class _NestedDescriptorBase(DescriptorBase): @@ -299,6 +288,7 @@ class Descriptor(_NestedDescriptorBase): oneofs_by_name (dict(str, OneofDescriptor)): Same objects as in :attr:`oneofs`, but indexed by "name" attribute. file (FileDescriptor): Reference to file descriptor. + is_map_entry: If the message type is a map entry. """ @@ -324,6 +314,7 @@ class Descriptor(_NestedDescriptorBase): serialized_start=None, serialized_end=None, syntax=None, + is_map_entry=False, create_key=None): _message.Message._CheckCalledFromGeneratedFile() return _message.default_pool.FindMessageTypeByName(full_name) @@ -336,7 +327,7 @@ class Descriptor(_NestedDescriptorBase): serialized_options=None, is_extendable=True, extension_ranges=None, oneofs=None, file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin - syntax=None, create_key=None): + syntax=None, is_map_entry=False, create_key=None): """Arguments to __init__() are as described in the description of Descriptor fields above. 
@@ -386,6 +377,7 @@ class Descriptor(_NestedDescriptorBase): for oneof in self.oneofs: oneof.containing_type = self self.syntax = syntax or "proto2" + self._is_map_entry = is_map_entry @property def fields_by_camelcase_name(self): diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index c1471b488e..8ba6121214 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -900,6 +900,7 @@ class DescriptorPool(object): serialized_start=None, serialized_end=None, syntax=syntax, + is_map_entry=desc_proto.options.map_entry, # pylint: disable=protected-access create_key=descriptor._internal_create_key) for nested in desc.nested_types: diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 7e2cbc65d3..d6b36a620a 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -183,11 +183,14 @@ class GeneratedProtocolMessageType(type): % (descriptor.full_name)) return - cls._decoders_by_tag = {} + cls._message_set_decoders_by_tag = {} + cls._fields_by_tag = {} if (descriptor.has_options and descriptor.GetOptions().message_set_wire_format): - cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( - decoder.MessageSetItemDecoder(descriptor), None) + cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(descriptor), + None, + ) # Attach stuff to each FieldDescriptor for quick lookup later on. 
for field in descriptor.fields: @@ -278,16 +281,36 @@ def _IsMessageSetExtension(field): def _IsMapField(field): return (field.type == _FieldDescriptor.TYPE_MESSAGE and - field.message_type.has_options and - field.message_type.GetOptions().map_entry) + field.message_type._is_map_entry) def _IsMessageMapField(field): value_type = field.message_type.fields_by_name['value'] return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE - def _AttachFieldHelpers(cls, field_descriptor): + is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED + field_descriptor._default_constructor = _DefaultValueConstructorForField( + field_descriptor + ) + + def AddFieldByTag(wiretype, is_packed): + tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) + cls._fields_by_tag[tag_bytes] = (field_descriptor, is_packed) + + AddFieldByTag( + type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False + ) + + if is_repeated and wire_format.IsTypePackable(field_descriptor.type): + # To support wire compatibility of adding packed = true, add a decoder for + # packed values regardless of the field's options. 
+ AddFieldByTag(wire_format.WIRETYPE_LENGTH_DELIMITED, True) + + +def _MaybeAddEncoder(cls, field_descriptor): + if hasattr(field_descriptor, '_encoder'): + return is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) is_map_entry = _IsMapField(field_descriptor) is_packed = field_descriptor.is_packed @@ -307,11 +330,17 @@ def _AttachFieldHelpers(cls, field_descriptor): field_descriptor._encoder = field_encoder field_descriptor._sizer = sizer - field_descriptor._default_constructor = _DefaultValueConstructorForField( - field_descriptor) - def AddDecoder(wiretype, is_packed): - tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) + +def _MaybeAddDecoder(cls, field_descriptor): + if hasattr(field_descriptor, '_decoders'): + return + + is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED + is_map_entry = _IsMapField(field_descriptor) + field_descriptor._decoders = {} + + def AddDecoder(is_packed): decode_type = field_descriptor.type if (decode_type == _FieldDescriptor.TYPE_ENUM and not field_descriptor.enum_type.is_closed): @@ -343,15 +372,14 @@ def _AttachFieldHelpers(cls, field_descriptor): field_descriptor, field_descriptor._default_constructor, not field_descriptor.has_presence) - cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) + field_descriptor._decoders[is_packed] = field_decoder - AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], - False) + AddDecoder(False) if is_repeated and wire_format.IsTypePackable(field_descriptor.type): # To support wire compatibility of adding packed = true, add a decoder for # packed values regardless of the field's options. 
- AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) + AddDecoder(True) def _AddClassAttributesForNestedExtensions(descriptor, dictionary): @@ -1035,12 +1063,17 @@ def _AddByteSizeMethod(message_descriptor, cls): size = 0 descriptor = self.DESCRIPTOR - if descriptor.GetOptions().map_entry: + if descriptor._is_map_entry: # Fields of map entry should always be serialized. - size = descriptor.fields_by_name['key']._sizer(self.key) - size += descriptor.fields_by_name['value']._sizer(self.value) + key_field = descriptor.fields_by_name['key'] + _MaybeAddEncoder(cls, key_field) + size = key_field._sizer(self.key) + value_field = descriptor.fields_by_name['value'] + _MaybeAddEncoder(cls, value_field) + size += value_field._sizer(self.value) else: for field_descriptor, field_value in self.ListFields(): + _MaybeAddEncoder(cls, field_descriptor) size += field_descriptor._sizer(field_value) for tag_bytes, value_bytes in self._unknown_fields: size += len(tag_bytes) + len(value_bytes) @@ -1083,14 +1116,17 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): deterministic = bool(deterministic) descriptor = self.DESCRIPTOR - if descriptor.GetOptions().map_entry: + if descriptor._is_map_entry: # Fields of map entry should always be serialized. 
- descriptor.fields_by_name['key']._encoder( - write_bytes, self.key, deterministic) - descriptor.fields_by_name['value']._encoder( - write_bytes, self.value, deterministic) + key_field = descriptor.fields_by_name['key'] + _MaybeAddEncoder(cls, key_field) + key_field._encoder(write_bytes, self.key, deterministic) + value_field = descriptor.fields_by_name['value'] + _MaybeAddEncoder(cls, value_field) + value_field._encoder(write_bytes, self.value, deterministic) else: for field_descriptor, field_value in self.ListFields(): + _MaybeAddEncoder(cls, field_descriptor) field_descriptor._encoder(write_bytes, field_value, deterministic) for tag_bytes, value_bytes in self._unknown_fields: write_bytes(tag_bytes) @@ -1118,7 +1154,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls): local_ReadTag = decoder.ReadTag local_SkipField = decoder.SkipField - decoders_by_tag = cls._decoders_by_tag + fields_by_tag = cls._fields_by_tag + message_set_decoders_by_tag = cls._message_set_decoders_by_tag def InternalParse(self, buffer, pos, end): """Create a message from serialized bytes. 
@@ -1141,8 +1178,14 @@ def _AddMergeFromStringMethod(message_descriptor, cls): unknown_field_set = self._unknown_field_set while pos != end: (tag_bytes, new_pos) = local_ReadTag(buffer, pos) - field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) - if field_decoder is None: + field_decoder, field_des = message_set_decoders_by_tag.get( + tag_bytes, (None, None) + ) + if field_decoder: + pos = field_decoder(buffer, new_pos, end, self, field_dict) + continue + field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None)) + if field_des is None: if not self._unknown_fields: # pylint: disable=protected-access self._unknown_fields = [] # pylint: disable=protected-access if unknown_field_set is None: @@ -1171,9 +1214,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls): (tag_bytes, buffer[old_pos:new_pos].tobytes())) pos = new_pos else: + _MaybeAddDecoder(cls, field_des) + field_decoder = field_des._decoders[is_packed] pos = field_decoder(buffer, new_pos, end, self, field_dict) - if field_desc: - self._UpdateOneofState(field_desc) + if field_des.containing_oneof: + self._UpdateOneofState(field_des) return pos cls._InternalParse = InternalParse @@ -1209,8 +1254,7 @@ def _AddIsInitializedMethod(message_descriptor, cls): for field, value in list(self._fields.items()): # dict can change size! 
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: - if (field.message_type.has_options and - field.message_type.GetOptions().map_entry): + if (field.message_type._is_map_entry): continue for element in value: if not element.IsInitialized(): diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index e038927b68..c5a600fa96 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -2053,11 +2053,6 @@ class Proto2ReflectionTest(unittest.TestCase): # dependency on the C++ logging code. self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) - @unittest.skipIf( - api_implementation.Type() == 'python', - 'Options are not supported on descriptor.proto in pure python' - ' (b/296476238).', - ) def testDescriptorProtoHasFileOptions(self): self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) self.assertEqual( @@ -2065,11 +2060,6 @@ class Proto2ReflectionTest(unittest.TestCase): 'com.google.protobuf', ) - @unittest.skipIf( - api_implementation.Type() == 'python', - 'Options are not supported on descriptor.proto in pure python' - ' (b/296476238).', - ) def testDescriptorProtoHasFieldOptions(self): self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) self.assertEqual( @@ -2084,11 +2074,6 @@ class Proto2ReflectionTest(unittest.TestCase): self.assertTrue(packed_desc.has_options) self.assertTrue(packed_desc.GetOptions().packed) - @unittest.skipIf( - api_implementation.Type() == 'python', - 'Options are not supported on descriptor.proto in pure python' - ' (b/296476238).', - ) def testDescriptorProtoHasFeatureOptions(self): self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) self.assertEqual( diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index ec1aa1b457..9a8d7d751a 100755 --- 
a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -177,25 +177,6 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message.ParseFromString(self.all_fields_data) - # InternalCheckUnknownField() is an additional Pure Python check which checks - # a detail of unknown fields. It cannot be used by the C++ - # implementation because some protect members are called. - # The test is added for historical reasons. It is not necessary as - # serialized string is checked. - # TODO(jieluo): Remove message._unknown_fields. - def InternalCheckUnknownField(self, name, expected_value): - if api_implementation.Type() != 'python': - return - field_descriptor = self.descriptor.fields_by_name[name] - wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] - field_tag = encoder.TagBytes(field_descriptor.number, wire_type) - result_dict = {} - for tag_bytes, value in self.empty_message._unknown_fields: - if tag_bytes == field_tag: - decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] - decoder(memoryview(value), 0, len(value), self.all_fields, result_dict) - self.assertEqual(expected_value, result_dict[field_descriptor]) - def CheckUnknownField(self, name, unknown_field_set, expected_value): field_descriptor = self.descriptor.fields_by_name[name] expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[ @@ -223,50 +204,36 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): self.CheckUnknownField('optional_nested_enum', unknown_field_set, self.all_fields.optional_nested_enum) - self.InternalCheckUnknownField('optional_nested_enum', - self.all_fields.optional_nested_enum) # Test repeated enum. self.CheckUnknownField('repeated_nested_enum', unknown_field_set, self.all_fields.repeated_nested_enum) - self.InternalCheckUnknownField('repeated_nested_enum', - self.all_fields.repeated_nested_enum) # Test varint. 
self.CheckUnknownField('optional_int32', unknown_field_set, self.all_fields.optional_int32) - self.InternalCheckUnknownField('optional_int32', - self.all_fields.optional_int32) # Test fixed32. self.CheckUnknownField('optional_fixed32', unknown_field_set, self.all_fields.optional_fixed32) - self.InternalCheckUnknownField('optional_fixed32', - self.all_fields.optional_fixed32) # Test fixed64. self.CheckUnknownField('optional_fixed64', unknown_field_set, self.all_fields.optional_fixed64) - self.InternalCheckUnknownField('optional_fixed64', - self.all_fields.optional_fixed64) # Test length delimited. self.CheckUnknownField('optional_string', unknown_field_set, self.all_fields.optional_string.encode('utf-8')) - self.InternalCheckUnknownField('optional_string', - self.all_fields.optional_string) # Test group. self.CheckUnknownField('optionalgroup', unknown_field_set, (17, 0, 117)) - self.InternalCheckUnknownField('optionalgroup', - self.all_fields.optionalgroup) self.assertEqual(98, len(unknown_field_set))