Make TDP accept and discard garbage non-continuation bits on the 10th byte of a varint.

This is the behavior of the codegen parser and the reflection parser.

PiperOrigin-RevId: 504022814
pull/11632/head
Protobuf Team Bot 2 years ago committed by Copybara-Service
parent 46656ed080
commit 092e447280
  1. 22
      src/google/protobuf/generated_message_tctable_impl.h
  2. 4
      src/google/protobuf/generated_message_tctable_lite.cc
  3. 33
      src/google/protobuf/generated_message_tctable_lite_test.cc
  4. 188
      src/google/protobuf/message_unittest.inc
  5. 30
      src/google/protobuf/unittest.proto

@ -805,22 +805,24 @@ Parse64FallbackPair(const char* p, int64_t res1) {
// correctly, so all we have to do is check that the expected case is true.
if (PROTOBUF_PREDICT_TRUE(ptr[9] == 1)) goto done10;
// A value of 0, however, represents an over-serialized varint. This case
// should not happen, but if does (say, due to a nonconforming serializer),
// deassert the continuation bit that came from ptr[8].
if (ptr[9] == 0) {
if (PROTOBUF_PREDICT_FALSE(ptr[9] & 0x80)) {
// If the continue bit is set, it is an unterminated varint.
return {nullptr, 0};
}
// A zero value of the first bit of the 10th byte represents an
// over-serialized varint. This case should not happen, but if does (say, due
// to a nonconforming serializer), deassert the continuation bit that came
// from ptr[8].
if ((ptr[9] & 1) == 0) {
#if defined(__GCC_ASM_FLAG_OUTPUTS__) && defined(__x86_64__)
// Use a small instruction since this is an uncommon code path.
asm("btcq $63,%0" : "+r"(res3));
#else
res3 ^= static_cast<uint64_t>(1) << 63;
#endif
goto done10;
}
// If the 10th byte/ptr[9] itself has any other value, then it is too big to
// fit in 64 bits. If the continue bit is set, it is an unterminated varint.
return {nullptr, 0};
goto done10;
done2:
return {p + 2, res1 & res2};
@ -963,7 +965,7 @@ PROTOBUF_NOINLINE const char* TcParser::FastTV32S1(PROTOBUF_TC_PARAM_DECL) {
if (PROTOBUF_PREDICT_FALSE(ptr[6] & 0x80)) {
if (PROTOBUF_PREDICT_FALSE(ptr[7] & 0x80)) {
if (PROTOBUF_PREDICT_FALSE(ptr[8] & 0x80)) {
if (ptr[9] & 0xFE) return Error(PROTOBUF_TC_PARAM_PASS);
if (ptr[9] & 0x80) return Error(PROTOBUF_TC_PARAM_PASS);
*out = RotateLeft(res, 28);
ptr += 10;
PROTOBUF_MUSTTAIL return ToTagDispatch(

@ -751,7 +751,9 @@ inline PROTOBUF_ALWAYS_INLINE const char* ParseVarint(const char* p,
if (PROTOBUF_PREDICT_FALSE(byte & 0x80)) {
byte = (byte - 0x80) | *p++;
if (PROTOBUF_PREDICT_FALSE(byte & 0x80)) {
byte = (byte - 0x80) | *p++;
// We only care about the continuation bit and the first bit
// of the 10th byte.
byte = (byte - 0x80) | (*p++ & 0x81);
if (PROTOBUF_PREDICT_FALSE(byte & 0x80)) {
return nullptr;
}

@ -126,6 +126,7 @@ TEST(FastVarints, NameHere) {
uint8_t serialize_buffer[64];
for (int size : {8, 32, 64, -8, -32, -64}) {
SCOPED_TRACE(size);
auto next_i = [](uint64_t i) {
// if i + 1 is a power of two, return that.
// (This will also match when i == -1, but for this loop we know that will
@ -136,7 +137,12 @@ TEST(FastVarints, NameHere) {
return i + (i - 1);
};
for (uint64_t i = 0; i + 1 != 0; i = next_i(i)) {
char fake_msg[64] = {
SCOPED_TRACE(i);
enum OverlongKind { kNotOverlong, kOverlong, kOverlongWithDroppedBits };
for (OverlongKind overlong :
{kNotOverlong, kOverlong, kOverlongWithDroppedBits}) {
SCOPED_TRACE(overlong);
alignas(16) char fake_msg[64] = {
kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, //
kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, //
kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, //
@ -150,6 +156,21 @@ TEST(FastVarints, NameHere) {
auto serialize_ptr = WireFormatLite::WriteUInt64ToArray(
/* field_number= */ 1, i, serialize_buffer);
if (overlong == kOverlong || overlong == kOverlongWithDroppedBits) {
// 1 for the tag plus 10 for the value
while (serialize_ptr - serialize_buffer < 11) {
serialize_ptr[-1] |= 0x80;
*serialize_ptr++ = 0;
}
if (overlong == kOverlongWithDroppedBits) {
// For this one we add some unused bits to the last byte.
// They should be dropped. Bits 1-6 are dropped. Bit 0 is used and
// bit 7 is checked for continuation.
serialize_ptr[-1] |= 0b0111'1110;
}
}
absl::string_view serialized{
reinterpret_cast<char*>(&serialize_buffer[0]),
static_cast<size_t>(serialize_ptr - serialize_buffer)};
@ -198,11 +219,13 @@ TEST(FastVarints, NameHere) {
case 8: {
if (end_ptr == nullptr) {
// If end_ptr is nullptr, that means the FastParser gave up and
// tried to pass control to MiniParse.... which is expected anytime
// we encounter something other than 0 or 1 encodings. (Since
// FastV8S1 is only used for `bool` fields.)
// tried to pass control to MiniParse.... which is expected
// anytime we encounter something other than 0 or 1 encodings.
// (Since FastV8S1 is only used for `bool` fields.)
if (overlong == kNotOverlong) {
EXPECT_NE(i, true);
EXPECT_NE(i, false);
}
EXPECT_THAT(fallback_hasbits_received, Optional(0));
// Like the mini-parser functions, and unlike the fast-parser
// functions, the fallback receives a ptr already incremented past
@ -219,6 +242,7 @@ TEST(FastVarints, NameHere) {
}; break;
case -32:
case 32: {
ASSERT_TRUE(end_ptr);
ASSERT_EQ(end_ptr - ptr, serialized.size());
auto actual_field = ReadAndReset<uint32_t>(&fake_msg[kFieldOffset]);
@ -247,6 +271,7 @@ TEST(FastVarints, NameHere) {
}
}
}
}
}
MATCHER_P3(IsEntryForFieldNum, table, field_num, field_numbers_table,

@ -1176,6 +1176,12 @@ TEST(MESSAGE_TEST_NAME, PreservesFloatingPointNegative0) {
std::signbit(out_message.optional_double()));
}
const uint8_t* SkipTag(const uint8_t* buf) {
while (*buf & 0x80) ++buf;
++buf;
return buf;
}
// Adds `non_canonical_bytes` bytes to the varint representation at the tail of
// the buffer.
// `buf` points to the start of the buffer, `p` points to one-past-the-end.
@ -1208,7 +1214,7 @@ std::string EncodeEnumValue(int number, int value, int non_canonical_bytes,
} else {
p = internal::WireFormatLite::WriteEnumToArray(number, value, p);
p = AddNonCanonicalBytes(buf, p, non_canonical_bytes);
p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes);
return std::string(buf, p);
}
}
@ -1257,9 +1263,15 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) {
SCOPED_TRACE(use_packed);
for (bool use_tail_field : {false, true}) {
SCOPED_TRACE(use_tail_field);
for (int non_canonical_bytes = 0; non_canonical_bytes < 5;
for (int non_canonical_bytes = 0; non_canonical_bytes < 9;
++non_canonical_bytes) {
SCOPED_TRACE(non_canonical_bytes);
for (bool add_garbage_bits : {false, true}) {
if (add_garbage_bits && non_canonical_bytes != 9) {
// We only add garbage on the 10th byte.
continue;
}
SCOPED_TRACE(add_garbage_bits);
for (int i = 0; i < descriptor->field_count(); ++i) {
const auto* field = descriptor->field(i);
if (field->name() == "other_field") continue;
@ -1278,9 +1290,13 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) {
auto encoded =
EncodeEnumValue(field->number(), value_desc->number(),
non_canonical_bytes, use_packed);
if (add_garbage_bits) {
// These bits should be discarded even in the `false` case.
encoded.back() |= 0b0111'1110;
}
if (use_tail_field) {
// Make sure that fields after this one can be parsed too. ie test
// that the "next" jump is correct too.
// Make sure that fields after this one can be parsed too. ie
// test that the "next" jump is correct too.
encoded += other_field;
}
@ -1327,6 +1343,7 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) {
}
}
}
}
}
std::string EncodeBoolValue(int number, bool value, int non_canonical_bytes) {
@ -1334,7 +1351,7 @@ std::string EncodeBoolValue(int number, bool value, int non_canonical_bytes) {
uint8_t* p = buf;
p = internal::WireFormatLite::WriteBoolToArray(number, value, p);
p = AddNonCanonicalBytes(buf, p, non_canonical_bytes);
p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes);
return std::string(buf, p);
}
@ -1358,6 +1375,12 @@ TEST(MESSAGE_TEST_NAME, TestBoolParsers) {
for (int non_canonical_bytes = 0; non_canonical_bytes < 10;
++non_canonical_bytes) {
SCOPED_TRACE(non_canonical_bytes);
for (bool add_garbage_bits : {false, true}) {
if (add_garbage_bits && non_canonical_bytes != 9) {
// We only add garbage on the 10th byte.
continue;
}
SCOPED_TRACE(add_garbage_bits);
for (int i = 0; i < descriptor->field_count(); ++i) {
const auto* field = descriptor->field(i);
if (field->name() == "other_field") continue;
@ -1366,6 +1389,10 @@ TEST(MESSAGE_TEST_NAME, TestBoolParsers) {
SCOPED_TRACE(value);
auto encoded =
EncodeBoolValue(field->number(), value, non_canonical_bytes);
if (add_garbage_bits) {
// These bits should be discarded even in the `false` case.
encoded.back() |= 0b0111'1110;
}
if (use_tail_field) {
// Make sure that fields after this one can be parsed too. ie test
// that the "next" jump is correct too.
@ -1378,7 +1405,81 @@ TEST(MESSAGE_TEST_NAME, TestBoolParsers) {
EXPECT_EQ(ref->GetRepeatedBool(obj, field, 0), value);
} else {
EXPECT_TRUE(ref->HasField(obj, field));
EXPECT_EQ(ref->GetBool(obj, field), value);
EXPECT_EQ(ref->GetBool(obj, field), value)
<< testing::PrintToString(encoded);
}
auto& unknown = ref->GetUnknownFields(obj);
ASSERT_EQ(unknown.field_count(), 0);
}
}
}
}
}
}
std::string EncodeInt32Value(int number, int32_t value,
int non_canonical_bytes) {
uint8_t buf[100];
uint8_t* p = buf;
p = internal::WireFormatLite::WriteInt32ToArray(number, value, p);
p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes);
return std::string(buf, p);
}
TEST(MESSAGE_TEST_NAME, TestInt32Parsers) {
UNITTEST::Int32ParseTester obj;
const auto other_field = EncodeOtherField();
// Encode an int32 field for many different cases and verify that it can be
// parsed as expected.
// There are:
// - optional/repeated/packed fields
// - field tags that encode in 1/2/3 bytes
// - canonical and non-canonical encodings of the varint
// - last vs not last field
auto* ref = obj.GetReflection();
auto* descriptor = obj.descriptor();
for (bool use_tail_field : {false, true}) {
SCOPED_TRACE(use_tail_field);
for (int non_canonical_bytes = 0; non_canonical_bytes < 10;
++non_canonical_bytes) {
SCOPED_TRACE(non_canonical_bytes);
for (bool add_garbage_bits : {false, true}) {
if (add_garbage_bits && non_canonical_bytes != 9) {
// We only add garbage on the 10th byte.
continue;
}
SCOPED_TRACE(add_garbage_bits);
for (int i = 0; i < descriptor->field_count(); ++i) {
const auto* field = descriptor->field(i);
if (field->name() == "other_field") continue;
SCOPED_TRACE(field->full_name());
for (int32_t value : {1, 0, -1, (std::numeric_limits<int32_t>::min)(),
(std::numeric_limits<int32_t>::max)()}) {
SCOPED_TRACE(value);
auto encoded =
EncodeInt32Value(field->number(), value, non_canonical_bytes);
if (add_garbage_bits) {
// These bits should be discarded even in the `false` case.
encoded.back() |= 0b0111'1110;
}
if (use_tail_field) {
// Make sure that fields after this one can be parsed too. ie test
// that the "next" jump is correct too.
encoded += other_field;
}
EXPECT_TRUE(obj.ParseFromString(encoded));
if (field->is_repeated()) {
ASSERT_EQ(ref->FieldSize(obj, field), 1);
EXPECT_EQ(ref->GetRepeatedInt32(obj, field, 0), value);
} else {
EXPECT_TRUE(ref->HasField(obj, field));
EXPECT_EQ(ref->GetInt32(obj, field), value)
<< testing::PrintToString(encoded);
}
auto& unknown = ref->GetUnknownFields(obj);
ASSERT_EQ(unknown.field_count(), 0);
@ -1386,6 +1487,81 @@ TEST(MESSAGE_TEST_NAME, TestBoolParsers) {
}
}
}
}
}
std::string EncodeInt64Value(int number, int64_t value,
int non_canonical_bytes) {
uint8_t buf[100];
uint8_t* p = buf;
p = internal::WireFormatLite::WriteInt64ToArray(number, value, p);
p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes);
return std::string(buf, p);
}
TEST(MESSAGE_TEST_NAME, TestInt64Parsers) {
UNITTEST::Int64ParseTester obj;
const auto other_field = EncodeOtherField();
// Encode an int64 field for many different cases and verify that it can be
// parsed as expected.
// There are:
// - optional/repeated/packed fields
// - field tags that encode in 1/2/3 bytes
// - canonical and non-canonical encodings of the varint
// - last vs not last field
auto* ref = obj.GetReflection();
auto* descriptor = obj.descriptor();
for (bool use_tail_field : {false, true}) {
SCOPED_TRACE(use_tail_field);
for (int non_canonical_bytes = 0; non_canonical_bytes < 10;
++non_canonical_bytes) {
SCOPED_TRACE(non_canonical_bytes);
for (bool add_garbage_bits : {false, true}) {
if (add_garbage_bits && non_canonical_bytes != 9) {
// We only add garbage on the 10th byte.
continue;
}
SCOPED_TRACE(add_garbage_bits);
for (int i = 0; i < descriptor->field_count(); ++i) {
const auto* field = descriptor->field(i);
if (field->name() == "other_field") continue;
SCOPED_TRACE(field->full_name());
for (int64_t value : {int64_t{1}, int64_t{0}, int64_t{-1},
(std::numeric_limits<int64_t>::min)(),
(std::numeric_limits<int64_t>::max)()}) {
SCOPED_TRACE(value);
auto encoded =
EncodeInt64Value(field->number(), value, non_canonical_bytes);
if (add_garbage_bits) {
// These bits should be discarded even in the `false` case.
encoded.back() |= 0b0111'1110;
}
if (use_tail_field) {
// Make sure that fields after this one can be parsed too. ie test
// that the "next" jump is correct too.
encoded += other_field;
}
EXPECT_TRUE(obj.ParseFromString(encoded));
if (field->is_repeated()) {
ASSERT_EQ(ref->FieldSize(obj, field), 1);
EXPECT_EQ(ref->GetRepeatedInt64(obj, field, 0), value);
} else {
EXPECT_TRUE(ref->HasField(obj, field));
EXPECT_EQ(ref->GetInt64(obj, field), value)
<< testing::PrintToString(encoded);
}
auto& unknown = ref->GetUnknownFields(obj);
ASSERT_EQ(unknown.field_count(), 0);
}
}
}
}
}
}
TEST(MESSAGE_TEST_NAME, IsDefaultInstance) {

@ -1565,6 +1565,36 @@ message BoolParseTester {
optional int32 other_field = 99;
};
message Int32ParseTester {
optional int32 optional_int32_lowfield = 1;
optional int32 optional_int32_midfield = 1001;
optional int32 optional_int32_hifield = 1000001;
repeated int32 repeated_int32_lowfield = 2;
repeated int32 repeated_int32_midfield = 1002;
repeated int32 repeated_int32_hifield = 1000002;
repeated int32 packed_int32_lowfield = 3 [packed = true];
repeated int32 packed_int32_midfield = 1003 [packed = true];
repeated int32 packed_int32_hifield = 1000003 [packed = true];
// An arbitrary field we can append to to break the runs of repeated fields.
optional int32 other_field = 99;
};
message Int64ParseTester {
optional int64 optional_int64_lowfield = 1;
optional int64 optional_int64_midfield = 1001;
optional int64 optional_int64_hifield = 1000001;
repeated int64 repeated_int64_lowfield = 2;
repeated int64 repeated_int64_midfield = 1002;
repeated int64 repeated_int64_hifield = 1000002;
repeated int64 packed_int64_lowfield = 3 [packed = true];
repeated int64 packed_int64_midfield = 1003 [packed = true];
repeated int64 packed_int64_hifield = 1000003 [packed = true];
// An arbitrary field we can append to to break the runs of repeated fields.
optional int32 other_field = 99;
};
message StringParseTester {
optional string optional_string_lowfield = 1;
optional string optional_string_midfield = 1001;

Loading…
Cancel
Save