Fixed printing of nan floats/doubles in Python.

The second assert in _upb_EncodeRoundTripFloat is raised if val is a nan. This fix just returns the output of first spnprintf.

I am not sure how changes to this repo are made so feel free to ignore this CL.

To test this, you could
1. Define a proto with a float field
message Test {
  float val = 1;
}
2. In a python script, import the library and then set the val to nan and try to print it.
proto = Test(val=float('nan'))
print(proto)

This will cause a coredump due to assertion error:

assert.h assertion failed at third_party/upb/upb/lex/round_trip.c:46 in void _upb_EncodeRoundTripFloat(float, char *, size_t): strtof(buf, NULL) == val

Added the corresponding change to double too

PiperOrigin-RevId: 637127851
pull/16941/head
Protobuf Team Bot 10 months ago committed by Copybara-Service
parent ee98ba2c18
commit f65108072b
  1. 10
      python/google/protobuf/internal/message_test.py
  2. 10
      upb/lex/BUILD
  3. 10
      upb/lex/round_trip.c
  4. 35
      upb/lex/round_trip_test.cc

@ -382,6 +382,11 @@ class MessageTest(unittest.TestCase):
message.optional_float = 2.0
self.assertEqual(str(message), 'optional_float: 2.0\n')
def testFloatNanPrinting(self, message_module):
message = message_module.TestAllTypes()
message.optional_float = float('nan')
self.assertEqual(str(message), 'optional_float: nan\n')
def testHighPrecisionFloatPrinting(self, message_module):
msg = message_module.TestAllTypes()
msg.optional_float = 0.12345678912345678
@ -389,6 +394,11 @@ class MessageTest(unittest.TestCase):
msg.ParseFromString(msg.SerializeToString())
self.assertEqual(old_float, msg.optional_float)
def testDoubleNanPrinting(self, message_module):
message = message_module.TestAllTypes()
message.optional_double = float('nan')
self.assertEqual(str(message), 'optional_double: nan\n')
def testHighPrecisionDoublePrinting(self, message_module):
msg = message_module.TestAllTypes()
msg.optional_double = 0.12345678912345678

@ -41,6 +41,16 @@ cc_test(
],
)
cc_test(
name = "round_trip_test",
srcs = ["round_trip_test.cc"],
deps = [
":lex",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
],
)
# begin:github_only
filegroup(
name = "source_files",

@ -8,6 +8,8 @@
#include "upb/lex/round_trip.h"
#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
// Must be last.
@ -28,6 +30,10 @@ static void upb_FixLocale(char* p) {
void _upb_EncodeRoundTripDouble(double val, char* buf, size_t size) {
assert(size >= kUpb_RoundTripBufferSize);
if (isnan(val)) {
snprintf(buf, size, "%s", "nan");
return;
}
snprintf(buf, size, "%.*g", DBL_DIG, val);
if (strtod(buf, NULL) != val) {
snprintf(buf, size, "%.*g", DBL_DIG + 2, val);
@ -38,6 +44,10 @@ void _upb_EncodeRoundTripDouble(double val, char* buf, size_t size) {
void _upb_EncodeRoundTripFloat(float val, char* buf, size_t size) {
assert(size >= kUpb_RoundTripBufferSize);
if (isnan(val)) {
snprintf(buf, size, "%s", "nan");
return;
}
snprintf(buf, size, "%.*g", FLT_DIG, val);
if (strtof(buf, NULL) != val) {
snprintf(buf, size, "%.*g", FLT_DIG + 3, val);

@ -0,0 +1,35 @@
#include "upb/lex/round_trip.h"
#include <math.h>
#include <gtest/gtest.h>
namespace {
TEST(RoundTripTest, Double) {
char buf[32];
_upb_EncodeRoundTripDouble(0.123456789, buf, sizeof(buf));
EXPECT_STREQ(buf, "0.123456789");
_upb_EncodeRoundTripDouble(0.0, buf, sizeof(buf));
EXPECT_STREQ(buf, "0");
_upb_EncodeRoundTripDouble(nan(""), buf, sizeof(buf));
EXPECT_STREQ(buf, "nan");
}
TEST(RoundTripTest, Float) {
char buf[32];
_upb_EncodeRoundTripFloat(0.123456, buf, sizeof(buf));
EXPECT_STREQ(buf, "0.123456");
_upb_EncodeRoundTripFloat(0.0, buf, sizeof(buf));
EXPECT_STREQ(buf, "0");
_upb_EncodeRoundTripFloat(nan(""), buf, sizeof(buf));
EXPECT_STREQ(buf, "nan");
}
} // namespace
Loading…
Cancel
Save