diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index e42538e1a7..f25cf2ad4b 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -382,6 +382,11 @@ class MessageTest(unittest.TestCase): message.optional_float = 2.0 self.assertEqual(str(message), 'optional_float: 2.0\n') + def testFloatNanPrinting(self, message_module): + message = message_module.TestAllTypes() + message.optional_float = float('nan') + self.assertEqual(str(message), 'optional_float: nan\n') + def testHighPrecisionFloatPrinting(self, message_module): msg = message_module.TestAllTypes() msg.optional_float = 0.12345678912345678 @@ -389,6 +394,11 @@ class MessageTest(unittest.TestCase): msg.ParseFromString(msg.SerializeToString()) self.assertEqual(old_float, msg.optional_float) + def testDoubleNanPrinting(self, message_module): + message = message_module.TestAllTypes() + message.optional_double = float('nan') + self.assertEqual(str(message), 'optional_double: nan\n') + def testHighPrecisionDoublePrinting(self, message_module): msg = message_module.TestAllTypes() msg.optional_double = 0.12345678912345678 diff --git a/upb/lex/BUILD b/upb/lex/BUILD index b12e37d041..efcf9e9e40 100644 --- a/upb/lex/BUILD +++ b/upb/lex/BUILD @@ -41,6 +41,16 @@ cc_test( ], ) +cc_test( + name = "round_trip_test", + srcs = ["round_trip_test.cc"], + deps = [ + ":lex", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + # begin:github_only filegroup( name = "source_files", diff --git a/upb/lex/round_trip.c b/upb/lex/round_trip.c index 13104341b1..82818eea23 100644 --- a/upb/lex/round_trip.c +++ b/upb/lex/round_trip.c @@ -8,6 +8,8 @@ #include "upb/lex/round_trip.h" #include +#include +#include #include // Must be last. @@ -28,6 +30,10 @@ static void upb_FixLocale(char* p) { void _upb_EncodeRoundTripDouble(double val, char* buf, size_t size) { assert(size >= kUpb_RoundTripBufferSize); + if (isnan(val)) { + snprintf(buf, size, "%s", "nan"); + return; + } snprintf(buf, size, "%.*g", DBL_DIG, val); if (strtod(buf, NULL) != val) { snprintf(buf, size, "%.*g", DBL_DIG + 2, val); @@ -38,6 +44,10 @@ void _upb_EncodeRoundTripDouble(double val, char* buf, size_t size) { void _upb_EncodeRoundTripFloat(float val, char* buf, size_t size) { assert(size >= kUpb_RoundTripBufferSize); + if (isnan(val)) { + snprintf(buf, size, "%s", "nan"); + return; + } snprintf(buf, size, "%.*g", FLT_DIG, val); if (strtof(buf, NULL) != val) { snprintf(buf, size, "%.*g", FLT_DIG + 3, val); diff --git a/upb/lex/round_trip_test.cc b/upb/lex/round_trip_test.cc new file mode 100644 index 0000000000..c6fc718522 --- /dev/null +++ b/upb/lex/round_trip_test.cc @@ -0,0 +1,35 @@ +#include "upb/lex/round_trip.h" + +#include + +#include + +namespace { + +TEST(RoundTripTest, Double) { + char buf[32]; + + _upb_EncodeRoundTripDouble(0.123456789, buf, sizeof(buf)); + EXPECT_STREQ(buf, "0.123456789"); + + _upb_EncodeRoundTripDouble(0.0, buf, sizeof(buf)); + EXPECT_STREQ(buf, "0"); + + _upb_EncodeRoundTripDouble(nan(""), buf, sizeof(buf)); + EXPECT_STREQ(buf, "nan"); +} + +TEST(RoundTripTest, Float) { + char buf[32]; + + _upb_EncodeRoundTripFloat(0.123456, buf, sizeof(buf)); + EXPECT_STREQ(buf, "0.123456"); + + _upb_EncodeRoundTripFloat(0.0, buf, sizeof(buf)); + EXPECT_STREQ(buf, "0"); + + _upb_EncodeRoundTripFloat(nan(""), buf, sizeof(buf)); + EXPECT_STREQ(buf, "nan"); +} + +} // namespace