From f65108072bbedad2e590f038eb23b8ef2235c329 Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Fri, 24 May 2024 21:54:13 -0700 Subject: [PATCH] Fixed printing of nan floats/doubles in Python. The second assert in _upb_EncodeRoundTripFloat is raised if val is a nan. This fix just returns the output of first spnprintf. I am not sure how changes to this repo are made so feel free to ignore this CL. To test this, you could 1. Define a proto with a float field message Test { float val = 1; } 2. In a python script, import the library and then set the val to nan and try to print it. proto = Test(val=float('nan')) print(proto) This will cause a coredump due to assertion error: assert.h assertion failed at third_party/upb/upb/lex/round_trip.c:46 in void _upb_EncodeRoundTripFloat(float, char *, size_t): strtof(buf, NULL) == val Added the corresponding change to double too PiperOrigin-RevId: 637127851 --- .../google/protobuf/internal/message_test.py | 10 ++++++ upb/lex/BUILD | 10 ++++++ upb/lex/round_trip.c | 10 ++++++ upb/lex/round_trip_test.cc | 35 +++++++++++++++++++ 4 files changed, 65 insertions(+) create mode 100644 upb/lex/round_trip_test.cc diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index e42538e1a7..f25cf2ad4b 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -382,6 +382,11 @@ class MessageTest(unittest.TestCase): message.optional_float = 2.0 self.assertEqual(str(message), 'optional_float: 2.0\n') + def testFloatNanPrinting(self, message_module): + message = message_module.TestAllTypes() + message.optional_float = float('nan') + self.assertEqual(str(message), 'optional_float: nan\n') + def testHighPrecisionFloatPrinting(self, message_module): msg = message_module.TestAllTypes() msg.optional_float = 0.12345678912345678 @@ -389,6 +394,11 @@ class MessageTest(unittest.TestCase): msg.ParseFromString(msg.SerializeToString()) self.assertEqual(old_float, msg.optional_float) + def testDoubleNanPrinting(self, message_module): + message = message_module.TestAllTypes() + message.optional_double = float('nan') + self.assertEqual(str(message), 'optional_double: nan\n') + def testHighPrecisionDoublePrinting(self, message_module): msg = message_module.TestAllTypes() msg.optional_double = 0.12345678912345678 diff --git a/upb/lex/BUILD b/upb/lex/BUILD index b12e37d041..efcf9e9e40 100644 --- a/upb/lex/BUILD +++ b/upb/lex/BUILD @@ -41,6 +41,16 @@ cc_test( ], ) +cc_test( + name = "round_trip_test", + srcs = ["round_trip_test.cc"], + deps = [ + ":lex", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + # begin:github_only filegroup( name = "source_files", diff --git a/upb/lex/round_trip.c b/upb/lex/round_trip.c index 13104341b1..82818eea23 100644 --- a/upb/lex/round_trip.c +++ b/upb/lex/round_trip.c @@ -8,6 +8,8 @@ #include "upb/lex/round_trip.h" #include +#include +#include #include // Must be last. @@ -28,6 +30,10 @@ static void upb_FixLocale(char* p) { void _upb_EncodeRoundTripDouble(double val, char* buf, size_t size) { assert(size >= kUpb_RoundTripBufferSize); + if (isnan(val)) { + snprintf(buf, size, "%s", "nan"); + return; + } snprintf(buf, size, "%.*g", DBL_DIG, val); if (strtod(buf, NULL) != val) { snprintf(buf, size, "%.*g", DBL_DIG + 2, val); @@ -38,6 +44,10 @@ void _upb_EncodeRoundTripDouble(double val, char* buf, size_t size) { void _upb_EncodeRoundTripFloat(float val, char* buf, size_t size) { assert(size >= kUpb_RoundTripBufferSize); + if (isnan(val)) { + snprintf(buf, size, "%s", "nan"); + return; + } snprintf(buf, size, "%.*g", FLT_DIG, val); if (strtof(buf, NULL) != val) { snprintf(buf, size, "%.*g", FLT_DIG + 3, val); diff --git a/upb/lex/round_trip_test.cc b/upb/lex/round_trip_test.cc new file mode 100644 index 0000000000..c6fc718522 --- /dev/null +++ b/upb/lex/round_trip_test.cc @@ -0,0 +1,35 @@ +#include "upb/lex/round_trip.h" + +#include + +#include + +namespace { + +TEST(RoundTripTest, Double) { + char buf[32]; + + _upb_EncodeRoundTripDouble(0.123456789, buf, sizeof(buf)); + EXPECT_STREQ(buf, "0.123456789"); + + _upb_EncodeRoundTripDouble(0.0, buf, sizeof(buf)); + EXPECT_STREQ(buf, "0"); + + _upb_EncodeRoundTripDouble(nan(""), buf, sizeof(buf)); + EXPECT_STREQ(buf, "nan"); +} + +TEST(RoundTripTest, Float) { + char buf[32]; + + _upb_EncodeRoundTripFloat(0.123456, buf, sizeof(buf)); + EXPECT_STREQ(buf, "0.123456"); + + _upb_EncodeRoundTripFloat(0.0, buf, sizeof(buf)); + EXPECT_STREQ(buf, "0"); + + _upb_EncodeRoundTripFloat(nan(""), buf, sizeof(buf)); + EXPECT_STREQ(buf, "nan"); +} + +} // namespace