Change extension_handle to field_descriptor in python `HasExtension()` and `ClearExtension()`

PiperOrigin-RevId: 499349711
pull/11453/head
Jie Luo 2 years ago committed by Copybara-Service
parent 1aef0a4006
commit f75fd051d6
  1. 30
      python/google/protobuf/internal/python_message.py
  2. 10
      python/google/protobuf/message.py

@ -769,12 +769,12 @@ def _AddPropertiesForExtensions(descriptor, cls):
def _AddStaticMethods(cls):
# TODO(robinson): This probably needs to be thread-safe(?)
def RegisterExtension(extension_handle):
extension_handle.containing_type = cls.DESCRIPTOR
def RegisterExtension(field_descriptor):
field_descriptor.containing_type = cls.DESCRIPTOR
# TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
# pylint: disable=protected-access
cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(extension_handle)
_AttachFieldHelpers(cls, extension_handle)
cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(field_descriptor)
_AttachFieldHelpers(cls, field_descriptor)
cls.RegisterExtension = staticmethod(RegisterExtension)
def FromString(s):
@ -886,28 +886,28 @@ def _AddClearFieldMethod(message_descriptor, cls):
def _AddClearExtensionMethod(cls):
"""Helper for _AddMessageMethods()."""
def ClearExtension(self, extension_handle):
extension_dict._VerifyExtensionHandle(self, extension_handle)
def ClearExtension(self, field_descriptor):
extension_dict._VerifyExtensionHandle(self, field_descriptor)
# Similar to ClearField(), above.
if extension_handle in self._fields:
del self._fields[extension_handle]
if field_descriptor in self._fields:
del self._fields[field_descriptor]
self._Modified()
cls.ClearExtension = ClearExtension
def _AddHasExtensionMethod(cls):
"""Helper for _AddMessageMethods()."""
def HasExtension(self, extension_handle):
extension_dict._VerifyExtensionHandle(self, extension_handle)
if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
raise KeyError('"%s" is repeated.' % extension_handle.full_name)
def HasExtension(self, field_descriptor):
extension_dict._VerifyExtensionHandle(self, field_descriptor)
if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
raise KeyError('"%s" is repeated.' % field_descriptor.full_name)
if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
value = self._fields.get(extension_handle)
if field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
value = self._fields.get(field_descriptor)
return value is not None and value._is_present_in_parent
else:
return extension_handle in self._fields
return field_descriptor in self._fields
cls.HasExtension = HasExtension
def _InternalUnpackAny(msg):

@ -312,13 +312,13 @@ class Message(object):
"""
raise NotImplementedError
def HasExtension(self, extension_handle):
def HasExtension(self, field_descriptor):
"""Checks if a certain extension is present for this message.
Extensions are retrieved using the :attr:`Extensions` mapping (if present).
Args:
extension_handle: The handle for the extension to check.
field_descriptor: The field descriptor for the extension to check.
Returns:
bool: Whether the extension is present for this message.
@ -330,11 +330,11 @@ class Message(object):
"""
raise NotImplementedError
def ClearExtension(self, extension_handle):
def ClearExtension(self, field_descriptor):
"""Clears the contents of a given extension.
Args:
extension_handle: The handle for the extension to clear.
field_descriptor: The field descriptor for the extension to clear.
"""
raise NotImplementedError
@ -368,7 +368,7 @@ class Message(object):
raise NotImplementedError
@staticmethod
def RegisterExtension(extension_handle):
def RegisterExtension(field_descriptor):
raise NotImplementedError
def _SetListener(self, message_listener):

Loading…
Cancel
Save