am 4b5874fa: Merge "Correctness: floating point equality using bits instead of ==."

* commit '4b5874fad099faefb469c632e4c7b854cea733ae':
  Correctness: floating point equality using bits instead of ==.
pull/91/head
Max Cai 11 years ago committed by Android Git Automerger
commit d44a519d8f
  1. 131
      java/src/test/java/com/google/protobuf/NanoTest.java
  2. 98
      src/google/protobuf/compiler/javanano/javanano_primitive_field.cc
  3. 2
      src/google/protobuf/unittest_accessors_nano.proto
  4. 2
      src/google/protobuf/unittest_has_nano.proto

@ -2886,13 +2886,6 @@ public class NanoTest extends TestCase {
TestAllTypesNano.BAR,
TestAllTypesNano.BAZ
};
// We set the _nan fields to something other than nan, because equality
// is defined for nan such that Float.NaN != Float.NaN, which makes any
// instance of TestAllTypesNano unequal to any other instance unless
// these fields are set. This is also the behavior of the regular java
// generator when the value of a field is NaN.
message.defaultFloatNan = 1.0f;
message.defaultDoubleNan = 1.0;
return message;
}
@ -2915,7 +2908,6 @@ public class NanoTest extends TestCase {
TestAllTypesNano.BAR,
TestAllTypesNano.BAZ
};
message.defaultFloatNan = 1.0f;
return message;
}
@ -2924,8 +2916,7 @@ public class NanoTest extends TestCase {
.setOptionalInt32(5)
.setOptionalString("Hello")
.setOptionalBytes(new byte[] {1, 2, 3})
.setOptionalNestedEnum(TestNanoAccessors.BAR)
.setDefaultFloatNan(1.0f);
.setOptionalNestedEnum(TestNanoAccessors.BAR);
message.optionalNestedMessage = new TestNanoAccessors.NestedMessage().setBb(27);
message.repeatedInt32 = new int[] { 5, 6, 7, 8 };
message.repeatedString = new String[] { "One", "Two" };
@ -2973,6 +2964,126 @@ public class NanoTest extends TestCase {
return message;
}
public void testEqualsWithSpecialFloatingPointValues() throws Exception {
// Checks that the nano implementation complies with Object.equals() when treating
// floating point numbers, i.e. NaN == NaN and +0.0 != -0.0.
// This test assumes that the generated equals() implementations are symmetric, so
// there will only be one direction for each equality check.
TestAllTypesNano m1 = new TestAllTypesNano();
m1.optionalFloat = Float.NaN;
m1.optionalDouble = Double.NaN;
TestAllTypesNano m2 = new TestAllTypesNano();
m2.optionalFloat = Float.NaN;
m2.optionalDouble = Double.NaN;
assertTrue(m1.equals(m2));
assertTrue(m1.equals(
MessageNano.mergeFrom(new TestAllTypesNano(), MessageNano.toByteArray(m1))));
m1.optionalFloat = +0f;
m2.optionalFloat = -0f;
assertFalse(m1.equals(m2));
m1.optionalFloat = -0f;
m1.optionalDouble = +0d;
m2.optionalDouble = -0d;
assertFalse(m1.equals(m2));
m1.optionalDouble = -0d;
assertTrue(m1.equals(m2));
assertFalse(m1.equals(new TestAllTypesNano())); // -0 does not equals() the default +0
assertTrue(m1.equals(
MessageNano.mergeFrom(new TestAllTypesNano(), MessageNano.toByteArray(m1))));
// -------
TestAllTypesNanoHas m3 = new TestAllTypesNanoHas();
m3.optionalFloat = Float.NaN;
m3.hasOptionalFloat = true;
m3.optionalDouble = Double.NaN;
m3.hasOptionalDouble = true;
TestAllTypesNanoHas m4 = new TestAllTypesNanoHas();
m4.optionalFloat = Float.NaN;
m4.hasOptionalFloat = true;
m4.optionalDouble = Double.NaN;
m4.hasOptionalDouble = true;
assertTrue(m3.equals(m4));
assertTrue(m3.equals(
MessageNano.mergeFrom(new TestAllTypesNanoHas(), MessageNano.toByteArray(m3))));
m3.optionalFloat = +0f;
m4.optionalFloat = -0f;
assertFalse(m3.equals(m4));
m3.optionalFloat = -0f;
m3.optionalDouble = +0d;
m4.optionalDouble = -0d;
assertFalse(m3.equals(m4));
m3.optionalDouble = -0d;
m3.hasOptionalFloat = false; // -0 does not equals() the default +0,
m3.hasOptionalDouble = false; // so these incorrect 'has' flags should be disregarded.
assertTrue(m3.equals(m4)); // note: m4 has the 'has' flags set.
assertFalse(m3.equals(new TestAllTypesNanoHas())); // note: the new message has +0 defaults
assertTrue(m3.equals(
MessageNano.mergeFrom(new TestAllTypesNanoHas(), MessageNano.toByteArray(m3))));
// note: the deserialized message has the 'has' flags set.
// -------
TestNanoAccessors m5 = new TestNanoAccessors();
m5.setOptionalFloat(Float.NaN);
m5.setOptionalDouble(Double.NaN);
TestNanoAccessors m6 = new TestNanoAccessors();
m6.setOptionalFloat(Float.NaN);
m6.setOptionalDouble(Double.NaN);
assertTrue(m5.equals(m6));
assertTrue(m5.equals(
MessageNano.mergeFrom(new TestNanoAccessors(), MessageNano.toByteArray(m6))));
m5.setOptionalFloat(+0f);
m6.setOptionalFloat(-0f);
assertFalse(m5.equals(m6));
m5.setOptionalFloat(-0f);
m5.setOptionalDouble(+0d);
m6.setOptionalDouble(-0d);
assertFalse(m5.equals(m6));
m5.setOptionalDouble(-0d);
assertTrue(m5.equals(m6));
assertFalse(m5.equals(new TestNanoAccessors()));
assertTrue(m5.equals(
MessageNano.mergeFrom(new TestNanoAccessors(), MessageNano.toByteArray(m6))));
// -------
NanoReferenceTypes.TestAllTypesNano m7 = new NanoReferenceTypes.TestAllTypesNano();
m7.optionalFloat = Float.NaN;
m7.optionalDouble = Double.NaN;
NanoReferenceTypes.TestAllTypesNano m8 = new NanoReferenceTypes.TestAllTypesNano();
m8.optionalFloat = Float.NaN;
m8.optionalDouble = Double.NaN;
assertTrue(m7.equals(m8));
assertTrue(m7.equals(MessageNano.mergeFrom(
new NanoReferenceTypes.TestAllTypesNano(), MessageNano.toByteArray(m7))));
m7.optionalFloat = +0f;
m8.optionalFloat = -0f;
assertFalse(m7.equals(m8));
m7.optionalFloat = -0f;
m7.optionalDouble = +0d;
m8.optionalDouble = -0d;
assertFalse(m7.equals(m8));
m7.optionalDouble = -0d;
assertTrue(m7.equals(m8));
assertFalse(m7.equals(new NanoReferenceTypes.TestAllTypesNano()));
assertTrue(m7.equals(MessageNano.mergeFrom(
new NanoReferenceTypes.TestAllTypesNano(), MessageNano.toByteArray(m7))));
}
public void testNullRepeatedFields() throws Exception {
// Check that serialization after explicitly setting a repeated field
// to null doesn't NPE.

@ -175,38 +175,6 @@ int FixedSize(FieldDescriptor::Type type) {
return -1;
}
// Returns true if the field has a default value equal to NaN.
bool IsDefaultNaN(const FieldDescriptor* field) {
switch (field->type()) {
case FieldDescriptor::TYPE_INT32 : return false;
case FieldDescriptor::TYPE_UINT32 : return false;
case FieldDescriptor::TYPE_SINT32 : return false;
case FieldDescriptor::TYPE_FIXED32 : return false;
case FieldDescriptor::TYPE_SFIXED32: return false;
case FieldDescriptor::TYPE_INT64 : return false;
case FieldDescriptor::TYPE_UINT64 : return false;
case FieldDescriptor::TYPE_SINT64 : return false;
case FieldDescriptor::TYPE_FIXED64 : return false;
case FieldDescriptor::TYPE_SFIXED64: return false;
case FieldDescriptor::TYPE_FLOAT :
return isnan(field->default_value_float());
case FieldDescriptor::TYPE_DOUBLE :
return isnan(field->default_value_double());
case FieldDescriptor::TYPE_BOOL : return false;
case FieldDescriptor::TYPE_STRING : return false;
case FieldDescriptor::TYPE_BYTES : return false;
case FieldDescriptor::TYPE_ENUM : return false;
case FieldDescriptor::TYPE_GROUP : return false;
case FieldDescriptor::TYPE_MESSAGE : return false;
// No default because we want the compiler to complain if any new
// types are added.
}
GOOGLE_LOG(FATAL) << "Can't get here.";
return false;
}
// Return true if the type is a that has variable length
// for instance String's.
bool IsVariableLenType(JavaType type) {
@ -384,15 +352,21 @@ GenerateSerializationConditional(io::Printer* printer) const {
printer->Print(variables_,
"if (");
}
if (IsArrayType(GetJavaType(descriptor_))) {
JavaType java_type = GetJavaType(descriptor_);
if (IsArrayType(java_type)) {
printer->Print(variables_,
"!java.util.Arrays.equals(this.$name$, $default$)) {\n");
} else if (IsReferenceType(GetJavaType(descriptor_))) {
} else if (IsReferenceType(java_type)) {
printer->Print(variables_,
"!this.$name$.equals($default$)) {\n");
} else if (IsDefaultNaN(descriptor_)) {
} else if (java_type == JAVATYPE_FLOAT) {
printer->Print(variables_,
"!$capitalized_type$.isNaN(this.$name$)) {\n");
"java.lang.Float.floatToIntBits(this.$name$)\n"
" != java.lang.Float.floatToIntBits($default$)) {\n");
} else if (java_type == JAVATYPE_DOUBLE) {
printer->Print(variables_,
"java.lang.Double.doubleToLongBits(this.$name$)\n"
" != java.lang.Double.doubleToLongBits($default$)) {\n");
} else {
printer->Print(variables_,
"this.$name$ != $default$) {\n");
@ -464,6 +438,36 @@ GenerateEqualsCode(io::Printer* printer) const {
printer->Print(") {\n"
" return false;\n"
"}\n");
} else if (java_type == JAVATYPE_FLOAT) {
printer->Print(variables_,
"{\n"
" int bits = java.lang.Float.floatToIntBits(this.$name$);\n"
" if (bits != java.lang.Float.floatToIntBits(other.$name$)");
if (params_.generate_has()) {
printer->Print(variables_,
"\n"
" || (bits == java.lang.Float.floatToIntBits($default$)\n"
" && this.has$capitalized_name$ != other.has$capitalized_name$)");
}
printer->Print(") {\n"
" return false;\n"
" }\n"
"}\n");
} else if (java_type == JAVATYPE_DOUBLE) {
printer->Print(variables_,
"{\n"
" long bits = java.lang.Double.doubleToLongBits(this.$name$);\n"
" if (bits != java.lang.Double.doubleToLongBits(other.$name$)");
if (params_.generate_has()) {
printer->Print(variables_,
"\n"
" || (bits == java.lang.Double.doubleToLongBits($default$)\n"
" && this.has$capitalized_name$ != other.has$capitalized_name$)");
}
printer->Print(") {\n"
" return false;\n"
" }\n"
"}\n");
} else {
printer->Print(variables_,
"if (this.$name$ != other.$name$");
@ -623,12 +627,26 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
void AccessorPrimitiveFieldGenerator::
GenerateEqualsCode(io::Printer* printer) const {
switch (GetJavaType(descriptor_)) {
// For all Java primitive types below, the hash codes match the
// results of BoxedType.valueOf(primitiveValue).hashCode().
case JAVATYPE_INT:
case JAVATYPE_LONG:
// For all Java primitive types below, the equality checks match the
// results of BoxedType.valueOf(primitiveValue).equals(otherValue).
case JAVATYPE_FLOAT:
printer->Print(variables_,
"if ($different_has$\n"
" || java.lang.Float.floatToIntBits($name$_)\n"
" != java.lang.Float.floatToIntBits(other.$name$_)) {\n"
" return false;\n"
"}\n");
break;
case JAVATYPE_DOUBLE:
printer->Print(variables_,
"if ($different_has$\n"
" || java.lang.Double.doubleToLongBits($name$_)\n"
" != java.lang.Double.doubleToLongBits(other.$name$_)) {\n"
" return false;\n"
"}\n");
break;
case JAVATYPE_INT:
case JAVATYPE_LONG:
case JAVATYPE_BOOLEAN:
printer->Print(variables_,
"if ($different_has$\n"

@ -49,6 +49,8 @@ message TestNanoAccessors {
// Singular
optional int32 optional_int32 = 1;
optional float optional_float = 11;
optional double optional_double = 12;
optional string optional_string = 14;
optional bytes optional_bytes = 15;

@ -49,6 +49,8 @@ message TestAllTypesNanoHas {
// Singular
optional int32 optional_int32 = 1;
optional float optional_float = 11;
optional double optional_double = 12;
optional string optional_string = 14;
optional bytes optional_bytes = 15;

Loading…
Cancel
Save