Comparing a proto message with an object of unknown returns NotImplemented

The Python comparison protocol requires that if an object doesn't know how to
compare itself to an object of a different type, it returns NotImplemented
rather than False. The interpreter will then try performing the comparison using
the other operand. This translates, for protos, to:
If a proto message doesn't know how to compare itself to an object of
non-message type, it returns NotImplemented. This way, the interpreter will then
try performing the comparison using the comparison methods of the other object,
which may know how to compare itself to a message. If not, then Python will
return the combined result (e.g., if both objects don't know how to perform
__eq__, then the equality operator `==` return false).
This change allows one to compare a proto with custom matchers such as mock.ANY
that the message doesn't know how to compare to, regardless of whether
mock.ANY is on the right-hand side or left-hand side of the equality (prior to
this change, it only worked with mock.ANY on the left-hand side).

Fixes https://github.com/protocolbuffers/protobuf/issues/9173

PiperOrigin-RevId: 561728156
pull/13816/head
Protobuf Team Bot 1 year ago committed by Copybara-Service
parent 14222b30f7
commit 12d4f418a7
  1. 37
      python/google/protobuf/internal/message_test.py
  2. 2
      python/google/protobuf/internal/python_message.py
  3. 35
      python/google/protobuf/pyext/message.cc

@ -46,6 +46,7 @@ import pickle
import pydoc
import sys
import unittest
from unittest import mock
import warnings
cmp = lambda x, y: (x > y) - (x < y)
@ -1268,6 +1269,42 @@ class MessageTest(unittest.TestCase):
self.assertEqual(bool, type(m.repeated_bool[0]))
self.assertEqual(True, m.repeated_bool[0])
def testEquality(self, message_module):
m = message_module.TestAllTypes()
m2 = message_module.TestAllTypes()
self.assertEqual(m, m)
self.assertEqual(m, m2)
self.assertEqual(m2, m)
different_m = message_module.TestAllTypes()
different_m.repeated_float.append(1)
self.assertNotEqual(m, different_m)
self.assertNotEqual(different_m, m)
self.assertIsNotNone(m)
self.assertIsNotNone(m)
self.assertNotEqual(42, m)
self.assertNotEqual(m, 42)
self.assertNotEqual('foo', m)
self.assertNotEqual(m, 'foo')
self.assertEqual(mock.ANY, m)
self.assertEqual(m, mock.ANY)
class ComparesWithFoo(object):
def __eq__(self, other):
if getattr(other, 'optional_string', 'not_foo') == 'foo':
return True
return NotImplemented
m.optional_string = 'foo'
self.assertEqual(m, ComparesWithFoo())
self.assertEqual(ComparesWithFoo(), m)
m.optional_string = 'bar'
self.assertNotEqual(m, ComparesWithFoo())
self.assertNotEqual(ComparesWithFoo(), m)
# Class to test proto2-only features (required, extensions, etc.)
@testing_refleaks.TestCase

@ -987,7 +987,7 @@ def _AddEqualsMethod(message_descriptor, cls):
def __eq__(self, other):
if (not isinstance(other, message_mod.Message) or
other.DESCRIPTOR != self.DESCRIPTOR):
return False
return NotImplemented
if self is other:
return True

@ -2045,25 +2045,26 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
bool equals = true;
// If other is not a message, it cannot be equal.
// If other is not a message, this implementation doesn't know how to perform
// comparisons.
if (!PyObject_TypeCheck(other, CMessage_Type)) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
// Otherwise, we have a CMessage whose message we can inspect.
bool equals = true;
const google::protobuf::Message* other_message =
reinterpret_cast<CMessage*>(other)->message;
// If messages don't have the same descriptors, they are not equal.
if (equals &&
self->message->GetDescriptor() != other_message->GetDescriptor()) {
equals = false;
}
// Check the message contents.
if (equals &&
!google::protobuf::util::MessageDifferencer::Equals(
*self->message, *reinterpret_cast<CMessage*>(other)->message)) {
equals = false;
} else {
// Otherwise, we have a CMessage whose message we can inspect.
const google::protobuf::Message* other_message =
reinterpret_cast<CMessage*>(other)->message;
// If messages don't have the same descriptors, they are not equal.
if (equals &&
self->message->GetDescriptor() != other_message->GetDescriptor()) {
equals = false;
}
// Check the message contents.
if (equals &&
!google::protobuf::util::MessageDifferencer::Equals(
*self->message, *reinterpret_cast<CMessage*>(other)->message)) {
equals = false;
}
}
if (equals ^ (opid == Py_EQ)) {

Loading…
Cancel
Save