Comparing a proto message with an object of unknown returns NotImplemented

The Python comparison protocol requires that if an object doesn't know how to
compare itself to an object of a different type, it returns NotImplemented
rather than False. The interpreter will then try performing the comparison using
the other operand. This translates, for protos, to:
If a proto message doesn't know how to compare itself to an object of
non-message type, it returns NotImplemented. This way, the interpreter will then
try performing the comparison using the comparison methods of the other object,
which may know how to compare itself to a message. If not, then Python will
return the combined result (e.g., if both objects don't know how to perform
__eq__, then the equality operator `==` return false).
This change allows one to compare a proto with custom matchers such as mock.ANY
that the message doesn't know how to compare to, regardless of whether
mock.ANY is on the right-hand side or left-hand side of the equality (prior to
this change, it only worked with mock.ANY on the left-hand side).

Fixes https://github.com/protocolbuffers/protobuf/issues/9173

PiperOrigin-RevId: 561728156
pull/13816/head
Protobuf Team Bot 1 year ago committed by Copybara-Service
parent 14222b30f7
commit 12d4f418a7
  1. 37
      python/google/protobuf/internal/message_test.py
  2. 2
      python/google/protobuf/internal/python_message.py
  3. 35
      python/google/protobuf/pyext/message.cc

@ -46,6 +46,7 @@ import pickle
import pydoc import pydoc
import sys import sys
import unittest import unittest
from unittest import mock
import warnings import warnings
cmp = lambda x, y: (x > y) - (x < y) cmp = lambda x, y: (x > y) - (x < y)
@ -1268,6 +1269,42 @@ class MessageTest(unittest.TestCase):
self.assertEqual(bool, type(m.repeated_bool[0])) self.assertEqual(bool, type(m.repeated_bool[0]))
self.assertEqual(True, m.repeated_bool[0]) self.assertEqual(True, m.repeated_bool[0])
def testEquality(self, message_module):
m = message_module.TestAllTypes()
m2 = message_module.TestAllTypes()
self.assertEqual(m, m)
self.assertEqual(m, m2)
self.assertEqual(m2, m)
different_m = message_module.TestAllTypes()
different_m.repeated_float.append(1)
self.assertNotEqual(m, different_m)
self.assertNotEqual(different_m, m)
self.assertIsNotNone(m)
self.assertIsNotNone(m)
self.assertNotEqual(42, m)
self.assertNotEqual(m, 42)
self.assertNotEqual('foo', m)
self.assertNotEqual(m, 'foo')
self.assertEqual(mock.ANY, m)
self.assertEqual(m, mock.ANY)
class ComparesWithFoo(object):
def __eq__(self, other):
if getattr(other, 'optional_string', 'not_foo') == 'foo':
return True
return NotImplemented
m.optional_string = 'foo'
self.assertEqual(m, ComparesWithFoo())
self.assertEqual(ComparesWithFoo(), m)
m.optional_string = 'bar'
self.assertNotEqual(m, ComparesWithFoo())
self.assertNotEqual(ComparesWithFoo(), m)
# Class to test proto2-only features (required, extensions, etc.) # Class to test proto2-only features (required, extensions, etc.)
@testing_refleaks.TestCase @testing_refleaks.TestCase

@ -987,7 +987,7 @@ def _AddEqualsMethod(message_descriptor, cls):
def __eq__(self, other): def __eq__(self, other):
if (not isinstance(other, message_mod.Message) or if (not isinstance(other, message_mod.Message) or
other.DESCRIPTOR != self.DESCRIPTOR): other.DESCRIPTOR != self.DESCRIPTOR):
return False return NotImplemented
if self is other: if self is other:
return True return True

@ -2045,25 +2045,26 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
Py_INCREF(Py_NotImplemented); Py_INCREF(Py_NotImplemented);
return Py_NotImplemented; return Py_NotImplemented;
} }
bool equals = true; // If other is not a message, this implementation doesn't know how to perform
// If other is not a message, it cannot be equal. // comparisons.
if (!PyObject_TypeCheck(other, CMessage_Type)) { if (!PyObject_TypeCheck(other, CMessage_Type)) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
// Otherwise, we have a CMessage whose message we can inspect.
bool equals = true;
const google::protobuf::Message* other_message =
reinterpret_cast<CMessage*>(other)->message;
// If messages don't have the same descriptors, they are not equal.
if (equals &&
self->message->GetDescriptor() != other_message->GetDescriptor()) {
equals = false;
}
// Check the message contents.
if (equals &&
!google::protobuf::util::MessageDifferencer::Equals(
*self->message, *reinterpret_cast<CMessage*>(other)->message)) {
equals = false; equals = false;
} else {
// Otherwise, we have a CMessage whose message we can inspect.
const google::protobuf::Message* other_message =
reinterpret_cast<CMessage*>(other)->message;
// If messages don't have the same descriptors, they are not equal.
if (equals &&
self->message->GetDescriptor() != other_message->GetDescriptor()) {
equals = false;
}
// Check the message contents.
if (equals &&
!google::protobuf::util::MessageDifferencer::Equals(
*self->message, *reinterpret_cast<CMessage*>(other)->message)) {
equals = false;
}
} }
if (equals ^ (opid == Py_EQ)) { if (equals ^ (opid == Py_EQ)) {

Loading…
Cancel
Save