diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 6cf89e8fe8..6b90ddc220 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -46,6 +46,7 @@ import pickle import pydoc import sys import unittest +from unittest import mock import warnings cmp = lambda x, y: (x > y) - (x < y) @@ -1268,6 +1269,42 @@ class MessageTest(unittest.TestCase): self.assertEqual(bool, type(m.repeated_bool[0])) self.assertEqual(True, m.repeated_bool[0]) + def testEquality(self, message_module): + m = message_module.TestAllTypes() + m2 = message_module.TestAllTypes() + self.assertEqual(m, m) + self.assertEqual(m, m2) + self.assertEqual(m2, m) + + different_m = message_module.TestAllTypes() + different_m.repeated_float.append(1) + self.assertNotEqual(m, different_m) + self.assertNotEqual(different_m, m) + + self.assertIsNotNone(m) + self.assertIsNotNone(m) + self.assertNotEqual(42, m) + self.assertNotEqual(m, 42) + self.assertNotEqual('foo', m) + self.assertNotEqual(m, 'foo') + + self.assertEqual(mock.ANY, m) + self.assertEqual(m, mock.ANY) + + class ComparesWithFoo(object): + + def __eq__(self, other): + if getattr(other, 'optional_string', 'not_foo') == 'foo': + return True + return NotImplemented + + m.optional_string = 'foo' + self.assertEqual(m, ComparesWithFoo()) + self.assertEqual(ComparesWithFoo(), m) + m.optional_string = 'bar' + self.assertNotEqual(m, ComparesWithFoo()) + self.assertNotEqual(ComparesWithFoo(), m) + # Class to test proto2-only features (required, extensions, etc.) @testing_refleaks.TestCase diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index d6b36a620a..2bd8bc228a 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -987,7 +987,7 @@ def _AddEqualsMethod(message_descriptor, cls): def __eq__(self, other): if (not isinstance(other, message_mod.Message) or other.DESCRIPTOR != self.DESCRIPTOR): - return False + return NotImplemented if self is other: return True diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index a57b7f295c..0e04fe4e60 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -2045,25 +2045,26 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } - bool equals = true; - // If other is not a message, it cannot be equal. + // If other is not a message, this implementation doesn't know how to perform + // comparisons. if (!PyObject_TypeCheck(other, CMessage_Type)) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + // Otherwise, we have a CMessage whose message we can inspect. + bool equals = true; + const google::protobuf::Message* other_message = + reinterpret_cast(other)->message; + // If messages don't have the same descriptors, they are not equal. + if (equals && + self->message->GetDescriptor() != other_message->GetDescriptor()) { + equals = false; + } + // Check the message contents. + if (equals && + !google::protobuf::util::MessageDifferencer::Equals( + *self->message, *reinterpret_cast(other)->message)) { equals = false; - } else { - // Otherwise, we have a CMessage whose message we can inspect. - const google::protobuf::Message* other_message = - reinterpret_cast(other)->message; - // If messages don't have the same descriptors, they are not equal. - if (equals && - self->message->GetDescriptor() != other_message->GetDescriptor()) { - equals = false; - } - // Check the message contents. - if (equals && - !google::protobuf::util::MessageDifferencer::Equals( - *self->message, *reinterpret_cast(other)->message)) { - equals = false; - } } if (equals ^ (opid == Py_EQ)) {