Fix a bug that strips options from descriptor.proto in Pure Python.

GetOptions() on fields (which parses _serialized_options) is now called the first time a message is parsed or serialized, instead of at build time.

Note: GetOptions() on messages is still called at build time because of message_set_wire_format. If message options are needed in descriptor.proto, a parse error will be raised in GetOptions(). We can check the file to avoid invoking GetOptions() for descriptor.proto, as long as message_set_wire_format is not needed in descriptor.proto.

Options other than message options do not invoke GetOptions() at build time.

PiperOrigin-RevId: 560741182
pull/13699/head
Jie Luo 1 year ago committed by Copybara-Service
parent d2a2dc9bbc
commit abf5dfbfbc
  1. 32
      python/google/protobuf/descriptor.py
  2. 1
      python/google/protobuf/descriptor_pool.py
  3. 102
      python/google/protobuf/internal/python_message.py
  4. 15
      python/google/protobuf/internal/reflection_test.py
  5. 33
      python/google/protobuf/internal/unknown_fields_test.py

@ -115,16 +115,6 @@ _Deprecated.count = 100
_internal_create_key = object() _internal_create_key = object()
def _IsDescriptorBootstrapProto(file):
"""Checks if the file descriptor corresponds to our bootstrapped descriptor.proto"""
if file is None:
return False
return (
file.name == 'net/proto2/proto/descriptor.proto'
or file.name == 'google/protobuf/descriptor.proto'
)
class DescriptorBase(metaclass=DescriptorMetaclass): class DescriptorBase(metaclass=DescriptorMetaclass):
"""Descriptors base class. """Descriptors base class.
@ -154,9 +144,7 @@ class DescriptorBase(metaclass=DescriptorMetaclass):
self.file = file self.file = file
self._options = options self._options = options
self._options_class_name = options_class_name self._options_class_name = options_class_name
self._serialized_options = ( self._serialized_options = serialized_options
serialized_options if not _IsDescriptorBootstrapProto(file) else None
)
# Does this descriptor have non-default options? # Does this descriptor have non-default options?
self.has_options = (self._options is not None) or ( self.has_options = (self._options is not None) or (
@ -192,14 +180,15 @@ class DescriptorBase(metaclass=DescriptorMetaclass):
raise RuntimeError('Unknown options class name %s!' % raise RuntimeError('Unknown options class name %s!' %
(self._options_class_name)) (self._options_class_name))
with _lock: if self._serialized_options is None:
if self._serialized_options is None: with _lock:
self._options = options_class() self._options = options_class()
else: else:
self._options = _ParseOptions(options_class(), options = _ParseOptions(options_class(), self._serialized_options)
self._serialized_options) with _lock:
self._options = options
return self._options return self._options
class _NestedDescriptorBase(DescriptorBase): class _NestedDescriptorBase(DescriptorBase):
@ -299,6 +288,7 @@ class Descriptor(_NestedDescriptorBase):
oneofs_by_name (dict(str, OneofDescriptor)): Same objects as in oneofs_by_name (dict(str, OneofDescriptor)): Same objects as in
:attr:`oneofs`, but indexed by "name" attribute. :attr:`oneofs`, but indexed by "name" attribute.
file (FileDescriptor): Reference to file descriptor. file (FileDescriptor): Reference to file descriptor.
is_map_entry: If the message type is a map entry.
""" """
@ -324,6 +314,7 @@ class Descriptor(_NestedDescriptorBase):
serialized_start=None, serialized_start=None,
serialized_end=None, serialized_end=None,
syntax=None, syntax=None,
is_map_entry=False,
create_key=None): create_key=None):
_message.Message._CheckCalledFromGeneratedFile() _message.Message._CheckCalledFromGeneratedFile()
return _message.default_pool.FindMessageTypeByName(full_name) return _message.default_pool.FindMessageTypeByName(full_name)
@ -336,7 +327,7 @@ class Descriptor(_NestedDescriptorBase):
serialized_options=None, serialized_options=None,
is_extendable=True, extension_ranges=None, oneofs=None, is_extendable=True, extension_ranges=None, oneofs=None,
file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin
syntax=None, create_key=None): syntax=None, is_map_entry=False, create_key=None):
"""Arguments to __init__() are as described in the description """Arguments to __init__() are as described in the description
of Descriptor fields above. of Descriptor fields above.
@ -386,6 +377,7 @@ class Descriptor(_NestedDescriptorBase):
for oneof in self.oneofs: for oneof in self.oneofs:
oneof.containing_type = self oneof.containing_type = self
self.syntax = syntax or "proto2" self.syntax = syntax or "proto2"
self._is_map_entry = is_map_entry
@property @property
def fields_by_camelcase_name(self): def fields_by_camelcase_name(self):

@ -900,6 +900,7 @@ class DescriptorPool(object):
serialized_start=None, serialized_start=None,
serialized_end=None, serialized_end=None,
syntax=syntax, syntax=syntax,
is_map_entry=desc_proto.options.map_entry,
# pylint: disable=protected-access # pylint: disable=protected-access
create_key=descriptor._internal_create_key) create_key=descriptor._internal_create_key)
for nested in desc.nested_types: for nested in desc.nested_types:

@ -183,11 +183,14 @@ class GeneratedProtocolMessageType(type):
% (descriptor.full_name)) % (descriptor.full_name))
return return
cls._decoders_by_tag = {} cls._message_set_decoders_by_tag = {}
cls._fields_by_tag = {}
if (descriptor.has_options and if (descriptor.has_options and
descriptor.GetOptions().message_set_wire_format): descriptor.GetOptions().message_set_wire_format):
cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
decoder.MessageSetItemDecoder(descriptor), None) decoder.MessageSetItemDecoder(descriptor),
None,
)
# Attach stuff to each FieldDescriptor for quick lookup later on. # Attach stuff to each FieldDescriptor for quick lookup later on.
for field in descriptor.fields: for field in descriptor.fields:
@ -278,16 +281,36 @@ def _IsMessageSetExtension(field):
def _IsMapField(field): def _IsMapField(field):
return (field.type == _FieldDescriptor.TYPE_MESSAGE and return (field.type == _FieldDescriptor.TYPE_MESSAGE and
field.message_type.has_options and field.message_type._is_map_entry)
field.message_type.GetOptions().map_entry)
def _IsMessageMapField(field): def _IsMessageMapField(field):
value_type = field.message_type.fields_by_name['value'] value_type = field.message_type.fields_by_name['value']
return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
def _AttachFieldHelpers(cls, field_descriptor): def _AttachFieldHelpers(cls, field_descriptor):
is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
field_descriptor._default_constructor = _DefaultValueConstructorForField(
field_descriptor
)
def AddFieldByTag(wiretype, is_packed):
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
cls._fields_by_tag[tag_bytes] = (field_descriptor, is_packed)
AddFieldByTag(
type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False
)
if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
# To support wire compatibility of adding packed = true, add a decoder for
# packed values regardless of the field's options.
AddFieldByTag(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
def _MaybeAddEncoder(cls, field_descriptor):
if hasattr(field_descriptor, '_encoder'):
return
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
is_map_entry = _IsMapField(field_descriptor) is_map_entry = _IsMapField(field_descriptor)
is_packed = field_descriptor.is_packed is_packed = field_descriptor.is_packed
@ -307,11 +330,17 @@ def _AttachFieldHelpers(cls, field_descriptor):
field_descriptor._encoder = field_encoder field_descriptor._encoder = field_encoder
field_descriptor._sizer = sizer field_descriptor._sizer = sizer
field_descriptor._default_constructor = _DefaultValueConstructorForField(
field_descriptor)
def AddDecoder(wiretype, is_packed):
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) def _MaybeAddDecoder(cls, field_descriptor):
if hasattr(field_descriptor, '_decoders'):
return
is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
is_map_entry = _IsMapField(field_descriptor)
field_descriptor._decoders = {}
def AddDecoder(is_packed):
decode_type = field_descriptor.type decode_type = field_descriptor.type
if (decode_type == _FieldDescriptor.TYPE_ENUM and if (decode_type == _FieldDescriptor.TYPE_ENUM and
not field_descriptor.enum_type.is_closed): not field_descriptor.enum_type.is_closed):
@ -343,15 +372,14 @@ def _AttachFieldHelpers(cls, field_descriptor):
field_descriptor, field_descriptor._default_constructor, field_descriptor, field_descriptor._default_constructor,
not field_descriptor.has_presence) not field_descriptor.has_presence)
cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) field_descriptor._decoders[is_packed] = field_decoder
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], AddDecoder(False)
False)
if is_repeated and wire_format.IsTypePackable(field_descriptor.type): if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
# To support wire compatibility of adding packed = true, add a decoder for # To support wire compatibility of adding packed = true, add a decoder for
# packed values regardless of the field's options. # packed values regardless of the field's options.
AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) AddDecoder(True)
def _AddClassAttributesForNestedExtensions(descriptor, dictionary): def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
@ -1035,12 +1063,17 @@ def _AddByteSizeMethod(message_descriptor, cls):
size = 0 size = 0
descriptor = self.DESCRIPTOR descriptor = self.DESCRIPTOR
if descriptor.GetOptions().map_entry: if descriptor._is_map_entry:
# Fields of map entry should always be serialized. # Fields of map entry should always be serialized.
size = descriptor.fields_by_name['key']._sizer(self.key) key_field = descriptor.fields_by_name['key']
size += descriptor.fields_by_name['value']._sizer(self.value) _MaybeAddEncoder(cls, key_field)
size = key_field._sizer(self.key)
value_field = descriptor.fields_by_name['value']
_MaybeAddEncoder(cls, value_field)
size += value_field._sizer(self.value)
else: else:
for field_descriptor, field_value in self.ListFields(): for field_descriptor, field_value in self.ListFields():
_MaybeAddEncoder(cls, field_descriptor)
size += field_descriptor._sizer(field_value) size += field_descriptor._sizer(field_value)
for tag_bytes, value_bytes in self._unknown_fields: for tag_bytes, value_bytes in self._unknown_fields:
size += len(tag_bytes) + len(value_bytes) size += len(tag_bytes) + len(value_bytes)
@ -1083,14 +1116,17 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
deterministic = bool(deterministic) deterministic = bool(deterministic)
descriptor = self.DESCRIPTOR descriptor = self.DESCRIPTOR
if descriptor.GetOptions().map_entry: if descriptor._is_map_entry:
# Fields of map entry should always be serialized. # Fields of map entry should always be serialized.
descriptor.fields_by_name['key']._encoder( key_field = descriptor.fields_by_name['key']
write_bytes, self.key, deterministic) _MaybeAddEncoder(cls, key_field)
descriptor.fields_by_name['value']._encoder( key_field._encoder(write_bytes, self.key, deterministic)
write_bytes, self.value, deterministic) value_field = descriptor.fields_by_name['value']
_MaybeAddEncoder(cls, value_field)
value_field._encoder(write_bytes, self.value, deterministic)
else: else:
for field_descriptor, field_value in self.ListFields(): for field_descriptor, field_value in self.ListFields():
_MaybeAddEncoder(cls, field_descriptor)
field_descriptor._encoder(write_bytes, field_value, deterministic) field_descriptor._encoder(write_bytes, field_value, deterministic)
for tag_bytes, value_bytes in self._unknown_fields: for tag_bytes, value_bytes in self._unknown_fields:
write_bytes(tag_bytes) write_bytes(tag_bytes)
@ -1118,7 +1154,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag fields_by_tag = cls._fields_by_tag
message_set_decoders_by_tag = cls._message_set_decoders_by_tag
def InternalParse(self, buffer, pos, end): def InternalParse(self, buffer, pos, end):
"""Create a message from serialized bytes. """Create a message from serialized bytes.
@ -1141,8 +1178,14 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
unknown_field_set = self._unknown_field_set unknown_field_set = self._unknown_field_set
while pos != end: while pos != end:
(tag_bytes, new_pos) = local_ReadTag(buffer, pos) (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) field_decoder, field_des = message_set_decoders_by_tag.get(
if field_decoder is None: tag_bytes, (None, None)
)
if field_decoder:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
continue
field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
if field_des is None:
if not self._unknown_fields: # pylint: disable=protected-access if not self._unknown_fields: # pylint: disable=protected-access
self._unknown_fields = [] # pylint: disable=protected-access self._unknown_fields = [] # pylint: disable=protected-access
if unknown_field_set is None: if unknown_field_set is None:
@ -1171,9 +1214,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
(tag_bytes, buffer[old_pos:new_pos].tobytes())) (tag_bytes, buffer[old_pos:new_pos].tobytes()))
pos = new_pos pos = new_pos
else: else:
_MaybeAddDecoder(cls, field_des)
field_decoder = field_des._decoders[is_packed]
pos = field_decoder(buffer, new_pos, end, self, field_dict) pos = field_decoder(buffer, new_pos, end, self, field_dict)
if field_desc: if field_des.containing_oneof:
self._UpdateOneofState(field_desc) self._UpdateOneofState(field_des)
return pos return pos
cls._InternalParse = InternalParse cls._InternalParse = InternalParse
@ -1209,8 +1254,7 @@ def _AddIsInitializedMethod(message_descriptor, cls):
for field, value in list(self._fields.items()): # dict can change size! for field, value in list(self._fields.items()): # dict can change size!
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED: if field.label == _FieldDescriptor.LABEL_REPEATED:
if (field.message_type.has_options and if (field.message_type._is_map_entry):
field.message_type.GetOptions().map_entry):
continue continue
for element in value: for element in value:
if not element.IsInitialized(): if not element.IsInitialized():

@ -2053,11 +2053,6 @@ class Proto2ReflectionTest(unittest.TestCase):
# dependency on the C++ logging code. # dependency on the C++ logging code.
self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
@unittest.skipIf(
api_implementation.Type() == 'python',
'Options are not supported on descriptor.proto in pure python'
' (b/296476238).',
)
def testDescriptorProtoHasFileOptions(self): def testDescriptorProtoHasFileOptions(self):
self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
self.assertEqual( self.assertEqual(
@ -2065,11 +2060,6 @@ class Proto2ReflectionTest(unittest.TestCase):
'com.google.protobuf', 'com.google.protobuf',
) )
@unittest.skipIf(
api_implementation.Type() == 'python',
'Options are not supported on descriptor.proto in pure python'
' (b/296476238).',
)
def testDescriptorProtoHasFieldOptions(self): def testDescriptorProtoHasFieldOptions(self):
self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
self.assertEqual( self.assertEqual(
@ -2084,11 +2074,6 @@ class Proto2ReflectionTest(unittest.TestCase):
self.assertTrue(packed_desc.has_options) self.assertTrue(packed_desc.has_options)
self.assertTrue(packed_desc.GetOptions().packed) self.assertTrue(packed_desc.GetOptions().packed)
@unittest.skipIf(
api_implementation.Type() == 'python',
'Options are not supported on descriptor.proto in pure python'
' (b/296476238).',
)
def testDescriptorProtoHasFeatureOptions(self): def testDescriptorProtoHasFeatureOptions(self):
self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
self.assertEqual( self.assertEqual(

@ -177,25 +177,6 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data) self.empty_message.ParseFromString(self.all_fields_data)
# InternalCheckUnknownField() is an additional Pure Python check which checks
# a detail of unknown fields. It cannot be used by the C++
# implementation because some protect members are called.
# The test is added for historical reasons. It is not necessary as
# serialized string is checked.
# TODO(jieluo): Remove message._unknown_fields.
def InternalCheckUnknownField(self, name, expected_value):
if api_implementation.Type() != 'python':
return
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
result_dict = {}
for tag_bytes, value in self.empty_message._unknown_fields:
if tag_bytes == field_tag:
decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
self.assertEqual(expected_value, result_dict[field_descriptor])
def CheckUnknownField(self, name, unknown_field_set, expected_value): def CheckUnknownField(self, name, unknown_field_set, expected_value):
field_descriptor = self.descriptor.fields_by_name[name] field_descriptor = self.descriptor.fields_by_name[name]
expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[ expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
@ -223,50 +204,36 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
self.CheckUnknownField('optional_nested_enum', self.CheckUnknownField('optional_nested_enum',
unknown_field_set, unknown_field_set,
self.all_fields.optional_nested_enum) self.all_fields.optional_nested_enum)
self.InternalCheckUnknownField('optional_nested_enum',
self.all_fields.optional_nested_enum)
# Test repeated enum. # Test repeated enum.
self.CheckUnknownField('repeated_nested_enum', self.CheckUnknownField('repeated_nested_enum',
unknown_field_set, unknown_field_set,
self.all_fields.repeated_nested_enum) self.all_fields.repeated_nested_enum)
self.InternalCheckUnknownField('repeated_nested_enum',
self.all_fields.repeated_nested_enum)
# Test varint. # Test varint.
self.CheckUnknownField('optional_int32', self.CheckUnknownField('optional_int32',
unknown_field_set, unknown_field_set,
self.all_fields.optional_int32) self.all_fields.optional_int32)
self.InternalCheckUnknownField('optional_int32',
self.all_fields.optional_int32)
# Test fixed32. # Test fixed32.
self.CheckUnknownField('optional_fixed32', self.CheckUnknownField('optional_fixed32',
unknown_field_set, unknown_field_set,
self.all_fields.optional_fixed32) self.all_fields.optional_fixed32)
self.InternalCheckUnknownField('optional_fixed32',
self.all_fields.optional_fixed32)
# Test fixed64. # Test fixed64.
self.CheckUnknownField('optional_fixed64', self.CheckUnknownField('optional_fixed64',
unknown_field_set, unknown_field_set,
self.all_fields.optional_fixed64) self.all_fields.optional_fixed64)
self.InternalCheckUnknownField('optional_fixed64',
self.all_fields.optional_fixed64)
# Test length delimited. # Test length delimited.
self.CheckUnknownField('optional_string', self.CheckUnknownField('optional_string',
unknown_field_set, unknown_field_set,
self.all_fields.optional_string.encode('utf-8')) self.all_fields.optional_string.encode('utf-8'))
self.InternalCheckUnknownField('optional_string',
self.all_fields.optional_string)
# Test group. # Test group.
self.CheckUnknownField('optionalgroup', self.CheckUnknownField('optionalgroup',
unknown_field_set, unknown_field_set,
(17, 0, 117)) (17, 0, 117))
self.InternalCheckUnknownField('optionalgroup',
self.all_fields.optionalgroup)
self.assertEqual(98, len(unknown_field_set)) self.assertEqual(98, len(unknown_field_set))

Loading…
Cancel
Save