From 12d4f418a7311ed4d381bf82caead11d03ae7911 Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Thu, 31 Aug 2023 12:29:02 -0700 Subject: [PATCH] Comparing a proto message with an object of unknown returns NotImplemented The Python comparison protocol requires that if an object doesn't know how to compare itself to an object of a different type, it returns NotImplemented rather than False. The interpreter will then try performing the comparison using the other operand. This translates, for protos, to: If a proto message doesn't know how to compare itself to an object of non-message type, it returns NotImplemented. This way, the interpreter will then try performing the comparison using the comparison methods of the other object, which may know how to compare itself to a message. If not, then Python will return the combined result (e.g., if both objects don't know how to perform __eq__, then the equality operator `==` return false). This change allows one to compare a proto with custom matchers such as mock.ANY that the message doesn't know how to compare to, regardless of whether mock.ANY is on the right-hand side or left-hand side of the equality (prior to this change, it only worked with mock.ANY on the left-hand side). Fixes https://github.com/protocolbuffers/protobuf/issues/9173 PiperOrigin-RevId: 561728156 --- .../google/protobuf/internal/message_test.py | 37 +++++++++++++++++++ .../protobuf/internal/python_message.py | 2 +- python/google/protobuf/pyext/message.cc | 35 +++++++++--------- 3 files changed, 56 insertions(+), 18 deletions(-) diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 6cf89e8fe8..6b90ddc220 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -46,6 +46,7 @@ import pickle import pydoc import sys import unittest +from unittest import mock import warnings cmp = lambda x, y: (x > y) - (x < y) @@ -1268,6 +1269,42 @@ class MessageTest(unittest.TestCase): self.assertEqual(bool, type(m.repeated_bool[0])) self.assertEqual(True, m.repeated_bool[0]) + def testEquality(self, message_module): + m = message_module.TestAllTypes() + m2 = message_module.TestAllTypes() + self.assertEqual(m, m) + self.assertEqual(m, m2) + self.assertEqual(m2, m) + + different_m = message_module.TestAllTypes() + different_m.repeated_float.append(1) + self.assertNotEqual(m, different_m) + self.assertNotEqual(different_m, m) + + self.assertIsNotNone(m) + self.assertIsNotNone(m) + self.assertNotEqual(42, m) + self.assertNotEqual(m, 42) + self.assertNotEqual('foo', m) + self.assertNotEqual(m, 'foo') + + self.assertEqual(mock.ANY, m) + self.assertEqual(m, mock.ANY) + + class ComparesWithFoo(object): + + def __eq__(self, other): + if getattr(other, 'optional_string', 'not_foo') == 'foo': + return True + return NotImplemented + + m.optional_string = 'foo' + self.assertEqual(m, ComparesWithFoo()) + self.assertEqual(ComparesWithFoo(), m) + m.optional_string = 'bar' + self.assertNotEqual(m, ComparesWithFoo()) + self.assertNotEqual(ComparesWithFoo(), m) + # Class to test proto2-only features (required, extensions, etc.) @testing_refleaks.TestCase diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index d6b36a620a..2bd8bc228a 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -987,7 +987,7 @@ def _AddEqualsMethod(message_descriptor, cls): def __eq__(self, other): if (not isinstance(other, message_mod.Message) or other.DESCRIPTOR != self.DESCRIPTOR): - return False + return NotImplemented if self is other: return True diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index a57b7f295c..0e04fe4e60 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -2045,25 +2045,26 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } - bool equals = true; - // If other is not a message, it cannot be equal. + // If other is not a message, this implementation doesn't know how to perform + // comparisons. if (!PyObject_TypeCheck(other, CMessage_Type)) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + // Otherwise, we have a CMessage whose message we can inspect. + bool equals = true; + const google::protobuf::Message* other_message = + reinterpret_cast(other)->message; + // If messages don't have the same descriptors, they are not equal. + if (equals && + self->message->GetDescriptor() != other_message->GetDescriptor()) { + equals = false; + } + // Check the message contents. + if (equals && + !google::protobuf::util::MessageDifferencer::Equals( + *self->message, *reinterpret_cast(other)->message)) { equals = false; - } else { - // Otherwise, we have a CMessage whose message we can inspect. - const google::protobuf::Message* other_message = - reinterpret_cast(other)->message; - // If messages don't have the same descriptors, they are not equal. - if (equals && - self->message->GetDescriptor() != other_message->GetDescriptor()) { - equals = false; - } - // Check the message contents. - if (equals && - !google::protobuf::util::MessageDifferencer::Equals( - *self->message, *reinterpret_cast(other)->message)) { - equals = false; - } } if (equals ^ (opid == Py_EQ)) {